Skip to content

Commit ea43500

Browse files
authored
Add test utility that verifies an AsyncWrite is closed correctly (#2159)
* Add test utility that verifies an AsyncWrite is closed correctly * Add track_closed for sinks too
1 parent cb696f9 commit ea43500

File tree

5 files changed

+247
-0
lines changed

5 files changed

+247
-0
lines changed

futures-test/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ futures-task = { version = "0.3.5", path = "../futures-task", default-features =
1717
futures-io = { version = "0.3.5", path = "../futures-io", default-features = false }
1818
futures-util = { version = "0.3.5", path = "../futures-util", default-features = false }
1919
futures-executor = { version = "0.3.5", path = "../futures-executor", default-features = false }
20+
futures-sink = { version = "0.3.5", path = "../futures-sink", default-features = false }
2021
pin-utils = { version = "0.1.0", default-features = false }
2122
once_cell = { version = "1.3.1", default-features = false, features = ["std"], optional = true }
2223
pin-project = "0.4.20"

futures-test/src/io/write/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use futures_io::AsyncWrite;
44

55
pub use super::limited::Limited;
66
pub use crate::interleave_pending::InterleavePending;
7+
pub use crate::track_closed::TrackClosed;
78

89
/// Additional combinators for testing async writers.
910
pub trait AsyncWriteTestExt: AsyncWrite {
@@ -80,6 +81,45 @@ pub trait AsyncWriteTestExt: AsyncWrite {
8081
{
8182
Limited::new(self, limit)
8283
}
84+
85+
/// Track whether this stream has been closed and errors if it is used after closing.
86+
///
87+
/// # Examples
88+
///
89+
/// ```
90+
/// # futures::executor::block_on(async {
91+
/// use futures::io::{AsyncWriteExt, Cursor};
92+
/// use futures_test::io::AsyncWriteTestExt;
93+
///
94+
/// let mut writer = Cursor::new(vec![0u8; 4]).track_closed();
95+
///
96+
/// writer.write_all(&[1, 2]).await?;
97+
/// assert!(!writer.is_closed());
98+
/// writer.close().await?;
99+
/// assert!(writer.is_closed());
100+
///
101+
/// # Ok::<(), std::io::Error>(()) })?;
102+
/// # Ok::<(), std::io::Error>(())
103+
/// ```
104+
///
105+
/// ```
106+
/// # futures::executor::block_on(async {
107+
/// use futures::io::{AsyncWriteExt, Cursor};
108+
/// use futures_test::io::AsyncWriteTestExt;
109+
///
110+
/// let mut writer = Cursor::new(vec![0u8; 4]).track_closed();
111+
///
112+
/// writer.close().await?;
113+
/// assert!(writer.write_all(&[1, 2]).await.is_err());
114+
/// # Ok::<(), std::io::Error>(()) })?;
115+
/// # Ok::<(), std::io::Error>(())
116+
/// ```
117+
fn track_closed(self) -> TrackClosed<Self>
118+
where
119+
Self: Sized,
120+
{
121+
TrackClosed::new(self)
122+
}
83123
}
84124

85125
impl<W> AsyncWriteTestExt for W where W: AsyncWrite {}

futures-test/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ pub mod future;
3939
#[cfg(feature = "std")]
4040
pub mod stream;
4141

42+
#[cfg(feature = "std")]
43+
pub mod sink;
44+
4245
#[cfg(feature = "std")]
4346
pub mod io;
4447

4548
mod interleave_pending;
49+
mod track_closed;

futures-test/src/sink/mod.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//! Additional combinators for testing sinks.
2+
3+
use futures_sink::Sink;
4+
5+
pub use crate::track_closed::TrackClosed;
6+
7+
/// Additional combinators for testing sinks.
8+
pub trait SinkTestExt<Item>: Sink<Item> {
9+
/// Track whether this sink has been closed and panics if it is used after closing.
10+
///
11+
/// # Examples
12+
///
13+
/// ```
14+
/// # futures::executor::block_on(async {
15+
/// use futures::sink::{SinkExt, drain};
16+
/// use futures_test::sink::SinkTestExt;
17+
///
18+
/// let mut sink = drain::<i32>().track_closed();
19+
///
20+
/// sink.send(1).await?;
21+
/// assert!(!sink.is_closed());
22+
/// sink.close().await?;
23+
/// assert!(sink.is_closed());
24+
///
25+
/// # Ok::<(), std::convert::Infallible>(()) })?;
26+
/// # Ok::<(), std::convert::Infallible>(())
27+
/// ```
28+
///
29+
/// Note: Unlike [`AsyncWriteTestExt::track_closed`] when
30+
/// used as a sink the adaptor will panic if closed too early as there's no easy way to
31+
/// integrate as an error.
32+
///
33+
/// [`AsyncWriteTestExt::track_closed`]: crate::io::AsyncWriteTestExt::track_closed
34+
///
35+
/// ```
36+
/// # futures::executor::block_on(async {
37+
/// use std::panic::AssertUnwindSafe;
38+
/// use futures::{sink::{SinkExt, drain}, future::FutureExt};
39+
/// use futures_test::sink::SinkTestExt;
40+
///
41+
/// let mut sink = drain::<i32>().track_closed();
42+
///
43+
/// sink.close().await?;
44+
/// assert!(AssertUnwindSafe(sink.send(1)).catch_unwind().await.is_err());
45+
/// # Ok::<(), std::convert::Infallible>(()) })?;
46+
/// # Ok::<(), std::convert::Infallible>(())
47+
/// ```
48+
fn track_closed(self) -> TrackClosed<Self>
49+
where
50+
Self: Sized,
51+
{
52+
TrackClosed::new(self)
53+
}
54+
}
55+
56+
impl<Item, W> SinkTestExt<Item> for W where W: Sink<Item> {}

futures-test/src/track_closed.rs

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
use futures_io::AsyncWrite;
2+
use futures_sink::Sink;
3+
use std::{
4+
io::{self, IoSlice},
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
/// Async wrapper that tracks whether it has been closed.
10+
///
11+
/// See the `track_closed` methods on:
12+
/// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed)
13+
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed)
14+
#[pin_project::pin_project]
15+
#[derive(Debug)]
16+
pub struct TrackClosed<T> {
17+
#[pin]
18+
inner: T,
19+
closed: bool,
20+
}
21+
22+
impl<T> TrackClosed<T> {
23+
pub(crate) fn new(inner: T) -> TrackClosed<T> {
24+
TrackClosed {
25+
inner,
26+
closed: false,
27+
}
28+
}
29+
30+
/// Check whether this object has been closed.
31+
pub fn is_closed(&self) -> bool {
32+
self.closed
33+
}
34+
35+
/// Acquires a reference to the underlying object that this adaptor is
36+
/// wrapping.
37+
pub fn get_ref(&self) -> &T {
38+
&self.inner
39+
}
40+
41+
/// Acquires a mutable reference to the underlying object that this
42+
/// adaptor is wrapping.
43+
pub fn get_mut(&mut self) -> &mut T {
44+
&mut self.inner
45+
}
46+
47+
/// Acquires a pinned mutable reference to the underlying object that
48+
/// this adaptor is wrapping.
49+
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
50+
self.project().inner
51+
}
52+
53+
/// Consumes this adaptor returning the underlying object.
54+
pub fn into_inner(self) -> T {
55+
self.inner
56+
}
57+
}
58+
59+
impl<T: AsyncWrite> AsyncWrite for TrackClosed<T> {
60+
fn poll_write(
61+
self: Pin<&mut Self>,
62+
cx: &mut Context<'_>,
63+
buf: &[u8],
64+
) -> Poll<io::Result<usize>> {
65+
if self.is_closed() {
66+
return Poll::Ready(Err(io::Error::new(
67+
io::ErrorKind::Other,
68+
"Attempted to write after stream was closed",
69+
)));
70+
}
71+
self.project().inner.poll_write(cx, buf)
72+
}
73+
74+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75+
if self.is_closed() {
76+
return Poll::Ready(Err(io::Error::new(
77+
io::ErrorKind::Other,
78+
"Attempted to flush after stream was closed",
79+
)));
80+
}
81+
assert!(!self.is_closed());
82+
self.project().inner.poll_flush(cx)
83+
}
84+
85+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
86+
if self.is_closed() {
87+
return Poll::Ready(Err(io::Error::new(
88+
io::ErrorKind::Other,
89+
"Attempted to close after stream was closed",
90+
)));
91+
}
92+
let this = self.project();
93+
match this.inner.poll_close(cx) {
94+
Poll::Ready(Ok(())) => {
95+
*this.closed = true;
96+
Poll::Ready(Ok(()))
97+
}
98+
other => other,
99+
}
100+
}
101+
102+
fn poll_write_vectored(
103+
self: Pin<&mut Self>,
104+
cx: &mut Context<'_>,
105+
bufs: &[IoSlice<'_>],
106+
) -> Poll<io::Result<usize>> {
107+
if self.is_closed() {
108+
return Poll::Ready(Err(io::Error::new(
109+
io::ErrorKind::Other,
110+
"Attempted to write after stream was closed",
111+
)));
112+
}
113+
self.project().inner.poll_write_vectored(cx, bufs)
114+
}
115+
}
116+
117+
impl<Item, T: Sink<Item>> Sink<Item> for TrackClosed<T> {
118+
type Error = T::Error;
119+
120+
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121+
assert!(!self.is_closed());
122+
self.project().inner.poll_ready(cx)
123+
}
124+
125+
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
126+
assert!(!self.is_closed());
127+
self.project().inner.start_send(item)
128+
}
129+
130+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131+
assert!(!self.is_closed());
132+
self.project().inner.poll_flush(cx)
133+
}
134+
135+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136+
assert!(!self.is_closed());
137+
let this = self.project();
138+
match this.inner.poll_close(cx) {
139+
Poll::Ready(Ok(())) => {
140+
*this.closed = true;
141+
Poll::Ready(Ok(()))
142+
}
143+
other => other,
144+
}
145+
}
146+
}

0 commit comments

Comments
 (0)