Skip to content

Commit 9dc9b53

Browse files
authored
io: implement vectored writes for write_buf (#7871)
1 parent c3b31ba commit 9dc9b53

2 files changed

Lines changed: 67 additions & 2 deletions

File tree

tokio/src/io/util/write_buf.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::io::AsyncWrite;
33
use bytes::Buf;
44
use pin_project_lite::pin_project;
55
use std::future::Future;
6-
use std::io;
6+
use std::io::{self, IoSlice};
77
use std::marker::PhantomPinned;
88
use std::pin::Pin;
99
use std::task::{ready, Context, Poll};
@@ -42,13 +42,22 @@ where
4242
type Output = io::Result<usize>;
4343

4444
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
45+
const MAX_VECTOR_ELEMENTS: usize = 64;
46+
4547
let me = self.project();
4648

4749
if !me.buf.has_remaining() {
4850
return Poll::Ready(Ok(0));
4951
}
5052

51-
let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.chunk()))?;
53+
let n = if me.writer.is_write_vectored() {
54+
let mut slices = [IoSlice::new(&[]); MAX_VECTOR_ELEMENTS];
55+
let cnt = me.buf.chunks_vectored(&mut slices);
56+
ready!(Pin::new(&mut *me.writer).poll_write_vectored(cx, &slices[..cnt]))?
57+
} else {
58+
ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk()))?
59+
};
60+
5261
me.buf.advance(n);
5362
Poll::Ready(Ok(n))
5463
}

tokio/tests/io_write_buf.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,59 @@ async fn write_all() {
5454
assert_eq!(wr.cnt, 1);
5555
assert_eq!(buf.position(), 4);
5656
}
57+
58+
#[tokio::test]
59+
async fn write_buf_vectored() {
60+
struct Wr {
61+
buf: BytesMut,
62+
cnt: usize,
63+
}
64+
65+
impl AsyncWrite for Wr {
66+
fn poll_write(
67+
self: Pin<&mut Self>,
68+
_cx: &mut Context<'_>,
69+
_buf: &[u8],
70+
) -> Poll<io::Result<usize>> {
71+
panic!("shouldn't be called")
72+
}
73+
74+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75+
Ok(()).into()
76+
}
77+
78+
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79+
Ok(()).into()
80+
}
81+
82+
fn poll_write_vectored(
83+
mut self: Pin<&mut Self>,
84+
_cx: &mut Context<'_>,
85+
bufs: &[io::IoSlice<'_>],
86+
) -> Poll<Result<usize, io::Error>> {
87+
let mut n = 0;
88+
for buf in bufs {
89+
self.buf.extend_from_slice(buf);
90+
n += buf.len();
91+
}
92+
self.cnt += 1;
93+
Ok(n).into()
94+
}
95+
96+
fn is_write_vectored(&self) -> bool {
97+
true
98+
}
99+
}
100+
101+
let mut wr = Wr {
102+
buf: BytesMut::with_capacity(64),
103+
cnt: 0,
104+
};
105+
106+
let mut buf = Cursor::new(&b"hello world"[..]);
107+
108+
assert_ok!(wr.write_buf(&mut buf).await);
109+
assert_eq!(wr.buf, b"hello world"[..]);
110+
assert_eq!(wr.cnt, 1);
111+
assert_eq!(buf.position(), 11);
112+
}

0 commit comments

Comments
 (0)