Skip to content

Commit 4601c84

Browse files
authored
stream: add next_many and poll_next_many to StreamMap (#6409)
1 parent deff252 commit 4601c84

File tree

4 files changed

+378
-3
lines changed

4 files changed

+378
-3
lines changed

tokio-stream/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@
7373
#[macro_use]
7474
mod macros;
7575

76+
mod poll_fn;
77+
pub(crate) use poll_fn::poll_fn;
78+
7679
pub mod wrappers;
7780

7881
mod stream_ext;

tokio-stream/src/poll_fn.rs

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
pub(crate) struct PollFn<F> {
6+
f: F,
7+
}
8+
9+
pub(crate) fn poll_fn<T, F>(f: F) -> PollFn<F>
10+
where
11+
F: FnMut(&mut Context<'_>) -> Poll<T>,
12+
{
13+
PollFn { f }
14+
}
15+
16+
impl<T, F> Future for PollFn<F>
17+
where
18+
F: FnMut(&mut Context<'_>) -> Poll<T>,
19+
{
20+
type Output = T;
21+
22+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
23+
// Safety: We never construct a `Pin<&mut F>` anywhere, so accessing `f`
24+
// mutably in an unpinned way is sound.
25+
//
26+
// This use of unsafe cannot be replaced with the pin-project macro
27+
// because:
28+
// * If we put `#[pin]` on the field, then it gives us a `Pin<&mut F>`,
29+
// which we can't use to call the closure.
30+
// * If we don't put `#[pin]` on the field, then it makes `PollFn` be
31+
// unconditionally `Unpin`, which we also don't want.
32+
let me = unsafe { Pin::into_inner_unchecked(self) };
33+
(me.f)(cx)
34+
}
35+
}

tokio-stream/src/stream_map.rs

+105-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::Stream;
1+
use crate::{poll_fn, Stream};
22

33
use std::borrow::Borrow;
44
use std::hash::Hash;
@@ -561,6 +561,110 @@ impl<K, V> Default for StreamMap<K, V> {
561561
}
562562
}
563563

564+
impl<K, V> StreamMap<K, V>
565+
where
566+
K: Clone + Unpin,
567+
V: Stream + Unpin,
568+
{
569+
/// Receives multiple items on this [`StreamMap`], extending the provided `buffer`.
570+
///
571+
/// This method returns the number of items that is appended to the `buffer`.
572+
///
573+
/// Note that this method does not guarantee that exactly `limit` items
574+
/// are received. Rather, if at least one item is available, it returns
575+
/// as many items as it can up to the given limit. This method returns
576+
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
577+
///
578+
/// # Cancel safety
579+
///
580+
/// This method is cancel safe. If `next_many` is used as the event in a
581+
/// [`tokio::select!`](tokio::select) statement and some other branch
582+
/// completes first, it is guaranteed that no items were received on any of
583+
/// the underlying streams.
584+
pub async fn next_many(&mut self, buffer: &mut Vec<(K, V::Item)>, limit: usize) -> usize {
585+
poll_fn(|cx| self.poll_next_many(cx, buffer, limit)).await
586+
}
587+
588+
/// Polls to receive multiple items on this `StreamMap`, extending the provided `buffer`.
589+
///
590+
/// This method returns:
591+
/// * `Poll::Pending` if no items are available but the `StreamMap` is not empty.
592+
/// * `Poll::Ready(count)` where `count` is the number of items successfully received and
593+
/// stored in `buffer`. This can be less than, or equal to, `limit`.
594+
/// * `Poll::Ready(0)` if `limit` is set to zero or when the `StreamMap` is empty.
595+
///
596+
/// Note that this method does not guarantee that exactly `limit` items
597+
/// are received. Rather, if at least one item is available, it returns
598+
/// as many items as it can up to the given limit. This method returns
599+
/// zero only if the `StreamMap` is empty (or if `limit` is zero).
600+
pub fn poll_next_many(
601+
&mut self,
602+
cx: &mut Context<'_>,
603+
buffer: &mut Vec<(K, V::Item)>,
604+
limit: usize,
605+
) -> Poll<usize> {
606+
if limit == 0 || self.entries.is_empty() {
607+
return Poll::Ready(0);
608+
}
609+
610+
let mut added = 0;
611+
612+
let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;
613+
let mut idx = start;
614+
615+
while added < limit {
616+
// Indicates whether at least one stream returned a value when polled or not
617+
let mut should_loop = false;
618+
619+
for _ in 0..self.entries.len() {
620+
let (_, stream) = &mut self.entries[idx];
621+
622+
match Pin::new(stream).poll_next(cx) {
623+
Poll::Ready(Some(val)) => {
624+
added += 1;
625+
626+
let key = self.entries[idx].0.clone();
627+
buffer.push((key, val));
628+
629+
should_loop = true;
630+
631+
idx = idx.wrapping_add(1) % self.entries.len();
632+
}
633+
Poll::Ready(None) => {
634+
// Remove the entry
635+
self.entries.swap_remove(idx);
636+
637+
// Check if this was the last entry, if so the cursor needs
638+
// to wrap
639+
if idx == self.entries.len() {
640+
idx = 0;
641+
} else if idx < start && start <= self.entries.len() {
642+
// The stream being swapped into the current index has
643+
// already been polled, so skip it.
644+
idx = idx.wrapping_add(1) % self.entries.len();
645+
}
646+
}
647+
Poll::Pending => {
648+
idx = idx.wrapping_add(1) % self.entries.len();
649+
}
650+
}
651+
}
652+
653+
if !should_loop {
654+
break;
655+
}
656+
}
657+
658+
if added > 0 {
659+
Poll::Ready(added)
660+
} else if self.entries.is_empty() {
661+
Poll::Ready(0)
662+
} else {
663+
Poll::Pending
664+
}
665+
}
666+
}
667+
564668
impl<K, V> Stream for StreamMap<K, V>
565669
where
566670
K: Clone + Unpin,

0 commit comments

Comments
 (0)