Skip to content

Commit 538f70f

Browse files
committed
Adapt AsyncRead, AsynWrite
1 parent 842ad6e commit 538f70f

File tree

3 files changed

+20
-83
lines changed

3 files changed

+20
-83
lines changed

src/common/io/rewind.rs

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::marker::Unpin;
22
use std::{cmp, io};
33

44
use bytes::{Buf, Bytes};
5-
use tokio::io::{AsyncRead, AsyncWrite};
5+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
66

77
use crate::common::{task, Pin, Poll};
88

@@ -46,27 +46,22 @@ impl<T> AsyncRead for Rewind<T>
4646
where
4747
T: AsyncRead + Unpin,
4848
{
49-
#[inline]
50-
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
51-
self.inner.prepare_uninitialized_buffer(buf)
52-
}
53-
5449
fn poll_read(
55-
mut self: Pin<&mut Self>,
50+
self: Pin<&mut Self>,
5651
cx: &mut task::Context<'_>,
57-
buf: &mut [u8],
58-
) -> Poll<io::Result<usize>> {
52+
buf: &mut ReadBuf<'_>,
53+
) -> Poll<io::Result<()>> {
5954
if let Some(mut prefix) = self.pre.take() {
6055
// If there are no remaining bytes, let the bytes get dropped.
6156
if !prefix.is_empty() {
62-
let copy_len = cmp::min(prefix.len(), buf.len());
63-
prefix.copy_to_slice(&mut buf[..copy_len]);
57+
let copy_len = cmp::min(prefix.len(), buf.remaining());
58+
// TODO: There should be a way to do following two lines cleaner...
59+
buf.put_slice(prefix.to_vec().as_slice());
60+
prefix.advance(copy_len);
6461
// Put back whats left
6562
if !prefix.is_empty() {
6663
self.pre = Some(prefix);
6764
}
68-
69-
return Poll::Ready(Ok(copy_len));
7065
}
7166
}
7267
Pin::new(&mut self.inner).poll_read(cx, buf)
@@ -92,15 +87,6 @@ where
9287
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
9388
Pin::new(&mut self.inner).poll_shutdown(cx)
9489
}
95-
96-
#[inline]
97-
fn poll_write_buf<B: Buf>(
98-
mut self: Pin<&mut Self>,
99-
cx: &mut task::Context<'_>,
100-
buf: &mut B,
101-
) -> Poll<io::Result<usize>> {
102-
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
103-
}
10490
}
10591

10692
#[cfg(test)]

src/server/tcp.rs

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ mod addr_stream {
186186
use std::net::SocketAddr;
187187
#[cfg(unix)]
188188
use std::os::unix::io::{AsRawFd, RawFd};
189-
use tokio::io::{AsyncRead, AsyncWrite};
189+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
190190
use tokio::net::TcpStream;
191191

192192
use crate::common::{task, Pin, Poll};
@@ -231,30 +231,14 @@ mod addr_stream {
231231
}
232232

233233
impl AsyncRead for AddrStream {
234-
unsafe fn prepare_uninitialized_buffer(
235-
&self,
236-
buf: &mut [std::mem::MaybeUninit<u8>],
237-
) -> bool {
238-
self.inner.prepare_uninitialized_buffer(buf)
239-
}
240-
241234
#[inline]
242235
fn poll_read(
243-
mut self: Pin<&mut Self>,
236+
self: Pin<&mut Self>,
244237
cx: &mut task::Context<'_>,
245-
buf: &mut [u8],
246-
) -> Poll<io::Result<usize>> {
238+
buf: &mut ReadBuf<'_>,
239+
) -> Poll<io::Result<()>> {
247240
Pin::new(&mut self.inner).poll_read(cx, buf)
248241
}
249-
250-
#[inline]
251-
fn poll_read_buf<B: BufMut>(
252-
mut self: Pin<&mut Self>,
253-
cx: &mut task::Context<'_>,
254-
buf: &mut B,
255-
) -> Poll<io::Result<usize>> {
256-
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
257-
}
258242
}
259243

260244
impl AsyncWrite for AddrStream {
@@ -267,15 +251,6 @@ mod addr_stream {
267251
Pin::new(&mut self.inner).poll_write(cx, buf)
268252
}
269253

270-
#[inline]
271-
fn poll_write_buf<B: Buf>(
272-
mut self: Pin<&mut Self>,
273-
cx: &mut task::Context<'_>,
274-
buf: &mut B,
275-
) -> Poll<io::Result<usize>> {
276-
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
277-
}
278-
279254
#[inline]
280255
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
281256
// TCP flush is a noop

src/upgrade.rs

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::io;
1212
use std::marker::Unpin;
1313

1414
use bytes::{Buf, Bytes};
15-
use tokio::io::{AsyncRead, AsyncWrite};
15+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1616
use tokio::sync::oneshot;
1717

1818
use crate::common::io::Rewind;
@@ -105,15 +105,11 @@ impl Upgraded {
105105
}
106106

107107
impl AsyncRead for Upgraded {
108-
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
109-
self.io.prepare_uninitialized_buffer(buf)
110-
}
111-
112108
fn poll_read(
113-
mut self: Pin<&mut Self>,
109+
self: Pin<&mut Self>,
114110
cx: &mut task::Context<'_>,
115-
buf: &mut [u8],
116-
) -> Poll<io::Result<usize>> {
111+
buf: &mut ReadBuf<'_>,
112+
) -> Poll<io::Result<()>> {
117113
Pin::new(&mut self.io).poll_read(cx, buf)
118114
}
119115
}
@@ -127,14 +123,6 @@ impl AsyncWrite for Upgraded {
127123
Pin::new(&mut self.io).poll_write(cx, buf)
128124
}
129125

130-
fn poll_write_buf<B: Buf>(
131-
mut self: Pin<&mut Self>,
132-
cx: &mut task::Context<'_>,
133-
buf: &mut B,
134-
) -> Poll<io::Result<usize>> {
135-
Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf)
136-
}
137-
138126
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
139127
Pin::new(&mut self.io).poll_flush(cx)
140128
}
@@ -247,15 +235,11 @@ impl dyn Io + Send {
247235
}
248236

249237
impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
250-
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
251-
self.0.prepare_uninitialized_buffer(buf)
252-
}
253-
254238
fn poll_read(
255-
mut self: Pin<&mut Self>,
239+
self: Pin<&mut Self>,
256240
cx: &mut task::Context<'_>,
257-
buf: &mut [u8],
258-
) -> Poll<io::Result<usize>> {
241+
buf: &mut ReadBuf<'_>,
242+
) -> Poll<io::Result<()>> {
259243
Pin::new(&mut self.0).poll_read(cx, buf)
260244
}
261245
}
@@ -269,14 +253,6 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
269253
Pin::new(&mut self.0).poll_write(cx, buf)
270254
}
271255

272-
fn poll_write_buf<B: Buf>(
273-
mut self: Pin<&mut Self>,
274-
cx: &mut task::Context<'_>,
275-
buf: &mut B,
276-
) -> Poll<io::Result<usize>> {
277-
Pin::new(&mut self.0).poll_write_buf(cx, buf)
278-
}
279-
280256
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
281257
Pin::new(&mut self.0).poll_flush(cx)
282258
}
@@ -292,7 +268,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
292268
cx: &mut task::Context<'_>,
293269
mut buf: &mut dyn Buf,
294270
) -> Poll<io::Result<usize>> {
295-
Pin::new(&mut self.0).poll_write_buf(cx, &mut buf)
271+
Pin::new(&mut self.0).poll_write(cx, buf.bytes())
296272
}
297273
}
298274

0 commit comments

Comments
 (0)