diff --git a/futures-util/src/stream/mod.rs b/futures-util/src/stream/mod.rs index 06bfd14cf1..b95edd51be 100644 --- a/futures-util/src/stream/mod.rs +++ b/futures-util/src/stream/mod.rs @@ -45,7 +45,7 @@ mod try_stream; pub use self::try_stream::{ try_unfold, AndThen, ErrInto, InspectErr, InspectOk, IntoStream, MapErr, MapOk, OrElse, TryCollect, TryConcat, TryFilter, TryFilterMap, TryFlatten, TryFold, TryForEach, TryNext, - TrySkipWhile, TryStreamExt, TryUnfold, + TrySkipWhile, TryStreamExt, TryTakeWhile, TryUnfold, }; #[cfg(feature = "io")] diff --git a/futures-util/src/stream/try_stream/mod.rs b/futures-util/src/stream/try_stream/mod.rs index 686f8ff5ad..4087298b7b 100644 --- a/futures-util/src/stream/try_stream/mod.rs +++ b/futures-util/src/stream/try_stream/mod.rs @@ -103,6 +103,10 @@ mod try_skip_while; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::try_skip_while::TrySkipWhile; +mod try_take_while; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::try_take_while::TryTakeWhile; + cfg_target_has_atomic! { #[cfg(feature = "alloc")] mod try_buffer_unordered; @@ -432,6 +436,36 @@ pub trait TryStreamExt: TryStream { TrySkipWhile::new(self, f) } + /// Take elements on this stream while the provided asynchronous predicate + /// resolves to `true`. + /// + /// This function is similar to + /// [`StreamExt::take_while`](crate::stream::StreamExt::take_while) but exits + /// early if an error occurs. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::future; + /// use futures::stream::{self, TryStreamExt}; + /// + /// let stream = stream::iter(vec![Ok::(1), Ok(2), Ok(3), Ok(2)]); + /// let stream = stream.try_take_while(|x| future::ready(Ok(*x < 3))); + /// + /// let output: Result, i32> = stream.try_collect().await; + /// assert_eq!(output, Ok(vec![1, 2])); + /// # }) + /// ``` + fn try_take_while(self, f: F) -> TryTakeWhile + where + F: FnMut(&Self::Ok) -> Fut, + Fut: TryFuture, + Self: Sized, + { + TryTakeWhile::new(self, f) + } + /// Attempts to run this stream to completion, executing the provided asynchronous /// closure for each element on the stream concurrently as elements become /// available, exiting as soon as an error occurs. diff --git a/futures-util/src/stream/try_stream/try_take_while.rs b/futures-util/src/stream/try_stream/try_take_while.rs new file mode 100644 index 0000000000..16bfb2047e --- /dev/null +++ b/futures-util/src/stream/try_stream/try_take_while.rs @@ -0,0 +1,132 @@ +use core::fmt; +use core::pin::Pin; +use futures_core::future::TryFuture; +use futures_core::stream::{FusedStream, Stream, TryStream}; +use futures_core::task::{Context, Poll}; +#[cfg(feature = "sink")] +use futures_sink::Sink; +use pin_project::pin_project; + +/// Stream for the [`try_take_while`](super::TryStreamExt::try_take_while) +/// method. +#[pin_project] +#[must_use = "streams do nothing unless polled"] +pub struct TryTakeWhile +where + St: TryStream, +{ + #[pin] + stream: St, + f: F, + #[pin] + pending_fut: Option, + pending_item: Option, + done_taking: bool, +} + +impl fmt::Debug for TryTakeWhile +where + St: TryStream + fmt::Debug, + St::Ok: fmt::Debug, + Fut: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryTakeWhile") + .field("stream", &self.stream) + .field("pending_fut", &self.pending_fut) + .field("pending_item", &self.pending_item) + .field("done_taking", &self.done_taking) + .finish() + } +} + +impl TryTakeWhile +where + St: TryStream, + F: FnMut(&St::Ok) -> Fut, + Fut: TryFuture, +{ + pub(super) fn new(stream: St, f: F) -> TryTakeWhile { + TryTakeWhile { + stream, + f, + pending_fut: None, + pending_item: None, + done_taking: false, + } + } + + delegate_access_inner!(stream, St, ()); +} + +impl Stream for TryTakeWhile +where + St: TryStream, + F: FnMut(&St::Ok) -> Fut, + Fut: TryFuture, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if *this.done_taking { + return Poll::Ready(None); + } + + Poll::Ready(loop { + if let Some(fut) = this.pending_fut.as_mut().as_pin_mut() { + let take = ready!(fut.try_poll(cx)?); + let item = this.pending_item.take(); + this.pending_fut.set(None); + if take { + break item.map(Ok); + } else { + *this.done_taking = true; + break None; + } + } else if let Some(item) = ready!(this.stream.as_mut().try_poll_next(cx)?) { + this.pending_fut.set(Some((this.f)(&item))); + *this.pending_item = Some(item); + } else { + break None; + } + }) + } + + fn size_hint(&self) -> (usize, Option) { + if self.done_taking { + return (0, Some(0)); + } + + let pending_len = if self.pending_item.is_some() { 1 } else { 0 }; + let (_, upper) = self.stream.size_hint(); + let upper = match upper { + Some(x) => x.checked_add(pending_len), + None => None, + }; + (0, upper) // can't know a lower bound, due to the predicate + } +} + +impl FusedStream for TryTakeWhile +where + St: TryStream + FusedStream, + F: FnMut(&St::Ok) -> Fut, + Fut: TryFuture, +{ + fn is_terminated(&self) -> bool { + self.done_taking || self.pending_item.is_none() && self.stream.is_terminated() + } +} + +// Forwarding impl of Sink from the underlying stream +#[cfg(feature = "sink")] +impl Sink for TryTakeWhile +where + S: TryStream + Sink, +{ + type Error = E; + + delegate_sink!(stream, Item); +} diff --git a/futures/src/lib.rs b/futures/src/lib.rs index 32c746d442..c85a483818 100644 --- a/futures/src/lib.rs +++ b/futures/src/lib.rs @@ -465,7 +465,7 @@ pub mod stream { AndThen, ErrInto, MapOk, MapErr, OrElse, InspectOk, InspectErr, TryNext, TryForEach, TryFilter, TryFilterMap, TryFlatten, - TryCollect, TryConcat, TryFold, TrySkipWhile, + TryCollect, TryConcat, TryFold, TrySkipWhile, TryTakeWhile, IntoStream, };