diff --git a/src/futures/write/generic/decoder.rs b/src/futures/write/generic/decoder.rs index 7edf741e..f82f8751 100644 --- a/src/futures/write/generic/decoder.rs +++ b/src/futures/write/generic/decoder.rs @@ -151,7 +151,7 @@ impl AsyncWrite for Decoder { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().do_poll_flush(cx))?; - ready!(self.project().writer.as_mut().poll_close(cx))?; + ready!(self.project().writer.as_mut().poll_flush(cx))?; Poll::Ready(Ok(())) } diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index df342a50..f73f8d21 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,5 +1,7 @@ #![allow(dead_code, unused_macros)] // Different tests use a different subset of functions +mod track_closed; + use bytes::Bytes; use futures::{ io::AsyncBufRead, @@ -61,6 +63,7 @@ impl From>> for InputStream { } pub mod prelude { + use super::track_closed::TrackClosedExt as _; pub use async_compression::Level; pub use bytes::Bytes; pub use futures::{ @@ -109,13 +112,17 @@ pub mod prelude { { let mut test_writer = (&mut output) .limited_write(limit) - .interleave_pending_write(); - let mut writer = create_writer(&mut test_writer); - for chunk in input { - block_on(writer.write_all(chunk)).unwrap(); - block_on(writer.flush()).unwrap(); + .interleave_pending_write() + .track_closed(); + { + let mut writer = create_writer(&mut test_writer); + for chunk in input { + block_on(writer.write_all(chunk)).unwrap(); + block_on(writer.flush()).unwrap(); + } + block_on(writer.close()).unwrap(); } - block_on(writer.close()).unwrap(); + assert!(test_writer.is_closed()); } output } diff --git a/tests/utils/track_closed.rs b/tests/utils/track_closed.rs new file mode 100644 index 00000000..93e12d42 --- /dev/null +++ b/tests/utils/track_closed.rs @@ -0,0 +1,63 @@ +use core::{ + pin::Pin, + task::{Context, Poll}, +}; +use futures::io::AsyncWrite; +use std::io::{IoSlice, Result}; + +pub trait TrackClosedExt: AsyncWrite { + fn track_closed(self) -> TrackClosed + where + Self: Sized + Unpin, + { + TrackClosed { + inner: self, + closed: false, + } + } +} + +impl TrackClosedExt for W {} + +pub struct TrackClosed { + inner: W, + closed: bool, +} + +impl TrackClosed { + pub fn is_closed(&self) -> bool { + self.closed + } +} + +impl AsyncWrite for TrackClosed { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + assert!(!self.closed); + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + assert!(!self.closed); + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + assert!(!self.closed); + match Pin::new(&mut self.inner).poll_close(cx) { + Poll::Ready(Ok(())) => { + self.closed = true; + Poll::Ready(Ok(())) + } + other => other, + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context, + bufs: &[IoSlice], + ) -> Poll> { + assert!(!self.closed); + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } +}