Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce mpsc::Receiver peek #7156

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions tokio/src/sync/mpsc/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,21 @@ impl<T> Block<T> {
///
/// * No concurrent access to the slot.
pub(crate) unsafe fn read(&self, slot_index: usize) -> Option<Read<T>> {
let offset = offset(slot_index);

let ready_bits = self.header.ready_slots.load(Acquire);

if !is_ready(ready_bits, offset) {
if is_tx_closed(ready_bits) {
return Some(Read::Closed);
}
self.ready_offset_read(slot_index, |offset| {
// Get the value
let value = self.values[offset].with(|ptr| ptr::read(ptr));

return None;
}
Some(Read::Value(value.assume_init()))
})
}

// Get the value
let value = self.values[offset].with(|ptr| ptr::read(ptr));
pub(crate) unsafe fn peek(&self, slot_index: usize) -> Option<Read<&T>> {
self.ready_offset_read(slot_index, |offset| {
// Get the value
let value = self.values[offset].with(|ptr| &*ptr);

Some(Read::Value(value.assume_init()))
Some(Read::Value(value.assume_init_ref()))
})
}

/// Returns true if *this* block has a value in the given slot.
Expand Down Expand Up @@ -404,6 +403,23 @@ impl<T> Block<T> {
crate::loom::thread::yield_now();
}
}

unsafe fn ready_offset_read<F, Y>(&self, slot_index: usize, operation: F) -> Option<Read<Y>>
where
F: FnOnce(usize) -> Option<Read<Y>>,
{
let offset = offset(slot_index);
let ready_bits = self.header.ready_slots.load(Acquire);

if !is_ready(ready_bits, offset) {
if is_tx_closed(ready_bits) {
return Some(Read::Closed);
}
return None;
}

operation(offset)
}
}

/// Returns `true` if the specified slot has a value ready to be consumed.
Expand Down
64 changes: 64 additions & 0 deletions tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,70 @@ impl<T> Receiver<T> {
Receiver { chan }
}

/// Peeks at the next value available for this receiver.
/// The referenced value returned by `peek` is identical to the one that will be returned by
/// the next call to `recv`.
/// Once a value is consumed by `recv`, it can no longer be peeked.
///
/// This method returns `None` if the channel has been closed and there are
/// no remaining messages in the channel's buffer. This indicates that no
/// further values can ever be received from this `Receiver`. The channel is
/// closed when all senders have been dropped, or when [`close`] is called.
///
/// If there are no messages in the channel's buffer, but the channel has
/// not yet been closed, this method will sleep until a message is sent or
/// the channel is closed. Note that if [`close`] is called, but there are
/// still outstanding [`Permits`] from before it was closed, the channel is
/// not considered closed by `peek` until the permits are released.
pub async fn peek(&self) -> Option<&T> {
use std::future::poll_fn;
poll_fn(|cx| self.chan.peek(cx)).await
}

/// Tries to peek the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// Unlike the [`poll_recv`] method, this method will never return an
/// [`Empty`] error spuriously.
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [`poll_recv`]: Self::poll_recv
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
///
/// assert_eq!(Ok(&"hello"), rx.try_peek());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_peek());
///
/// tx.send("hello").await.unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok(&"hello"), rx.try_peek());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_peek());
/// }
/// ```
pub fn try_peek(&self) -> Result<&T, TryRecvError> {
self.chan.try_peek()
}

/// Receives the next value for this receiver.
///
/// This method returns `None` if the channel has been closed and there are
Expand Down
110 changes: 86 additions & 24 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,17 @@ impl<T, S: Semaphore> Rx<T, S> {
})
}

/// Receive the next value
pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
pub(crate) fn peek(&self, cx: &mut Context<'_>) -> Poll<Option<&T>> {
use super::block::Read;

ready!(crate::trace::trace_leaf(cx));

// Keep track of task budget
let coop = ready!(crate::runtime::coop::poll_proceed(cx));

self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };

macro_rules! try_recv {
self.inner.rx_fields.with(|rx_fields_ptr| {
let rx_fields = unsafe { &*rx_fields_ptr };
macro_rules! try_peek {
() => {
match rx_fields.list.pop(&self.inner.tx) {
match rx_fields.list.peek() {
Some(Read::Value(value)) => {
self.inner.semaphore.add_permit();
coop.made_progress();
return Ready(Some(value));
}
Expand All @@ -316,19 +310,14 @@ impl<T, S: Semaphore> Rx<T, S> {
coop.made_progress();
return Ready(None);
}
None => {} // fall through
None => {}
}
};
}

try_recv!();

try_peek!();
self.inner.rx_waker.register_by_ref(cx.waker());

// It is possible that a value was pushed between attempting to read
// and registering the task, so we have to check the channel a
// second time here.
try_recv!();
try_peek!();

if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
coop.made_progress();
Expand All @@ -339,6 +328,79 @@ impl<T, S: Semaphore> Rx<T, S> {
})
}

// TODO: Consolidate with try_pop
pub(crate) fn try_peek(&self) -> Result<&T, TryRecvError> {
use super::list::TryReadResult;

self.inner.rx_fields.with(|rx_fields_ptr| {
let rx_fields = unsafe { &*rx_fields_ptr };

macro_rules! try_peek {
() => {
match rx_fields.list.try_peek(&self.inner.tx) {
TryReadResult::Ok(value) => {
return Ok(value);
}
TryReadResult::Closed => return Err(TryRecvError::Disconnected),
TryReadResult::Empty => return Err(TryRecvError::Empty),
TryReadResult::Busy => {} // fall through
}
};
}

try_peek!();

// If a previous `poll_recv` call has set a waker, we wake it here.
// This allows us to put our own CachedParkThread waker in the
// AtomicWaker slot instead.
//
// This is not a spurious wakeup to `poll_recv` since we just got a
// Busy from `try_peek`, which only happens if there are messages in
// the queue.
self.inner.rx_waker.wake();

// Park the thread until the problematic send has completed.
let mut park = CachedParkThread::new();
let waker = park.waker().unwrap();
loop {
self.inner.rx_waker.register_by_ref(&waker);
// It is possible that the problematic send has now completed,
// so we have to check for messages again.
try_peek!();
park.park();
}
})
}

/// Receive the next value by peeking and then advancing the list
pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
use std::ptr;

ready!(crate::trace::trace_leaf(cx));

// Keep track of task budget
let coop = ready!(crate::runtime::coop::poll_proceed(cx));

match self.peek(cx) {
Ready(Some(peeked_value)) => {
let val = unsafe { ptr::read(peeked_value) };
self.advance_rx();
self.inner.semaphore.add_permit();
coop.made_progress();
Poll::Ready(Some(val))
}
Ready(None) => Poll::Ready(None),
Pending => Poll::Pending,
}
}

fn advance_rx(&mut self) {
self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };
rx_fields.list.advance(&self.inner.tx);
});
}

/// Receives up to `limit` values into `buffer`
///
/// For `limit > 0`, receives up to limit values into `buffer`.
Expand Down Expand Up @@ -426,21 +488,21 @@ impl<T, S: Semaphore> Rx<T, S> {

/// Try to receive the next value.
pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
use super::list::TryPopResult;
use super::list::TryReadResult;

self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };

macro_rules! try_recv {
() => {
match rx_fields.list.try_pop(&self.inner.tx) {
TryPopResult::Ok(value) => {
TryReadResult::Ok(value) => {
self.inner.semaphore.add_permit();
return Ok(value);
}
TryPopResult::Closed => return Err(TryRecvError::Disconnected),
TryPopResult::Empty => return Err(TryRecvError::Empty),
TryPopResult::Busy => {} // fall through
TryReadResult::Closed => return Err(TryRecvError::Disconnected),
TryReadResult::Empty => return Err(TryRecvError::Empty),
TryReadResult::Busy => {} // fall through
}
};
}
Expand Down
Loading
Loading