Skip to content

Commit 81ebfcd

Browse files
Merge f69d70d into 7c377ca
2 parents 7c377ca + f69d70d commit 81ebfcd

File tree

6 files changed

+201
-59
lines changed

6 files changed

+201
-59
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ pin-project-lite = "0.2.9"
2222
tokio = { version = "1.28.2", features = ["net", "macros", "io-util"] }
2323
futures = "0.3.28"
2424
ktls-sys = "1.0.0"
25+
nix = { version = "0.26.1", features = ["socket", "uio", "net"], default-features = false }
2526

2627
[dev-dependencies]
2728
const-random = "0.1.15"
2829
rcgen = "0.10.0"
2930
socket2 = "0.5.3"
3031
tokio = { version = "1.28.2", features = ["full"] }
3132
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
33+
34+
[patch.crates-io]
35+
nix = { git = "https://github.com/fasterthanlime/nix", rev = "004d31c" } # on branch 'sol_tls'
36+

src/async_read_ready.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use std::{io, task};
2+
3+
pub trait AsyncReadReady {
4+
/// cf. https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.poll_read_ready
5+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>>;
6+
}
7+
8+
impl AsyncReadReady for tokio::net::TcpStream {
9+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
10+
tokio::net::TcpStream::poll_read_ready(self, cx)
11+
}
12+
}

src/cork_stream.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
use std::{
2-
io,
3-
pin::Pin,
4-
task::{Context, Poll},
5-
};
1+
use std::{io, pin::Pin, task};
62

73
use rustls::internal::msgs::codec::Codec;
84
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
95

6+
use crate::AsyncReadReady;
7+
108
enum State {
119
ReadHeader { header_buf: [u8; 5], offset: usize },
1210
ReadPayload { msg_size: usize, offset: usize },
@@ -59,9 +57,9 @@ where
5957
#[inline]
6058
fn poll_read(
6159
self: Pin<&mut Self>,
62-
cx: &mut Context<'_>,
60+
cx: &mut task::Context<'_>,
6361
buf: &mut ReadBuf<'_>,
64-
) -> Poll<io::Result<()>> {
62+
) -> task::Poll<io::Result<()>> {
6563
let this = unsafe { self.get_unchecked_mut() };
6664
let mut io = unsafe { Pin::new_unchecked(&mut this.io) };
6765

@@ -75,7 +73,7 @@ where
7573
"corked, returning empty read (but waking to prevent stalls)"
7674
);
7775
cx.waker().wake_by_ref();
78-
return Poll::Ready(Ok(()));
76+
return task::Poll::Ready(Ok(()));
7977
}
8078

8179
let left = header_buf.len() - *offset;
@@ -97,7 +95,7 @@ where
9795
buf.put_slice(&header_buf[..*offset]);
9896
*state = State::Passthrough;
9997

100-
return Poll::Ready(Ok(()));
98+
return task::Poll::Ready(Ok(()));
10199
}
102100
tracing::trace!("read {} bytes off of header", rest.filled().len());
103101
}
@@ -127,7 +125,7 @@ where
127125
}
128126
}
129127

130-
return Poll::Ready(Ok(()));
128+
return task::Poll::Ready(Ok(()));
131129
} else {
132130
// keep trying
133131
}
@@ -156,7 +154,7 @@ where
156154
let new_filled = buf.filled().len() + just_read;
157155
buf.set_filled(new_filled);
158156

159-
return Poll::Ready(Ok(()));
157+
return task::Poll::Ready(Ok(()));
160158
}
161159
State::Passthrough => {
162160
// we encountered EOF while reading, or saw an invalid header and we're just
@@ -168,28 +166,40 @@ where
168166
}
169167
}
170168

169+
impl<IO> AsyncReadReady for CorkStream<IO>
170+
where
171+
IO: AsyncReadReady,
172+
{
173+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
174+
self.io.poll_read_ready(cx)
175+
}
176+
}
177+
171178
impl<IO> AsyncWrite for CorkStream<IO>
172179
where
173180
IO: AsyncWrite,
174181
{
175182
#[inline]
176183
fn poll_write(
177184
self: Pin<&mut Self>,
178-
cx: &mut Context<'_>,
185+
cx: &mut task::Context<'_>,
179186
buf: &[u8],
180-
) -> Poll<io::Result<usize>> {
187+
) -> task::Poll<io::Result<usize>> {
181188
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
182189
io.poll_write(cx, buf)
183190
}
184191

185192
#[inline]
186-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193+
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
187194
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
188195
io.poll_flush(cx)
189196
}
190197

191198
#[inline]
192-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199+
fn poll_shutdown(
200+
self: Pin<&mut Self>,
201+
cx: &mut task::Context<'_>,
202+
) -> task::Poll<io::Result<()>> {
193203
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
194204
io.poll_shutdown(cx)
195205
}

src/ktls_stream.rs

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1-
use std::{io, os::unix::prelude::AsRawFd, pin::Pin, task};
2-
1+
use std::{
2+
io::{self, IoSliceMut},
3+
os::unix::prelude::AsRawFd,
4+
pin::Pin,
5+
task,
6+
};
7+
8+
use nix::{
9+
cmsg_space,
10+
sys::socket::{ControlMessageOwned, MsgFlags, SockaddrIn},
11+
};
312
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
413

14+
use crate::AsyncReadReady;
15+
516
// A wrapper around `IO` that sends a `close_notify` when shut down or dropped.
617
pin_project_lite::pin_project! {
718
pub struct KtlsStream<IO>
@@ -54,16 +65,16 @@ where
5465

5566
impl<IO> AsyncRead for KtlsStream<IO>
5667
where
57-
IO: AsRawFd + AsyncRead,
68+
IO: AsRawFd + AsyncRead + AsyncReadReady,
5869
{
5970
fn poll_read(
6071
self: Pin<&mut Self>,
6172
cx: &mut task::Context<'_>,
6273
buf: &mut ReadBuf<'_>,
6374
) -> task::Poll<io::Result<()>> {
64-
tracing::trace!(remaining = %buf.remaining(), "KtlsStream::poll_read");
75+
tracing::trace!(buf.remaining = %buf.remaining(), "KtlsStream::poll_read");
6576

66-
let this = self.project();
77+
let mut this = self.project();
6778

6879
if let Some((drain_index, drained)) = this.drained.as_mut() {
6980
let drained = &drained[*drain_index..];
@@ -79,11 +90,95 @@ where
7990
}
8091
cx.waker().wake_by_ref();
8192

93+
tracing::trace!("KtlsStream::poll_read, returning after drain");
8294
return task::Poll::Ready(Ok(()));
8395
}
8496

85-
tracing::trace!("KtlsStream::poll_read, forwarding to inner IO");
86-
this.inner.poll_read(cx, buf)
97+
let read_res = this.inner.as_mut().poll_read(cx, buf);
98+
if let task::Poll::Ready(Err(e)) = &read_res {
99+
// 5 is a generic "input/output error", it happens when
100+
// using poll_read on a kTLS socket that just received
101+
// a control message
102+
if let Some(5) = e.raw_os_error() {
103+
// could be a control message, let's check
104+
let fd = this.inner.as_raw_fd();
105+
106+
let mut cmsgspace = cmsg_space!(nix::sys::time::TimeVal);
107+
let mut iov = [IoSliceMut::new(buf.initialize_unfilled())];
108+
let flags = MsgFlags::empty();
109+
110+
let r = nix::sys::socket::recvmsg::<SockaddrIn>(
111+
fd,
112+
&mut iov,
113+
Some(&mut cmsgspace),
114+
flags,
115+
);
116+
let r = match r {
117+
Ok(r) => r,
118+
Err(nix::errno::Errno::EAGAIN) => {
119+
unreachable!("expected a control message, got EAGAIN")
120+
}
121+
Err(e) => {
122+
// ok I guess it really failed then
123+
tracing::trace!(?e, "recvmsg failed");
124+
return Err(e.into()).into();
125+
}
126+
};
127+
let cmsg = r
128+
.cmsgs()
129+
.next()
130+
.expect("we should've received exactly one control message");
131+
match cmsg {
132+
// cf. RFC 5246, Section 6.2.1
133+
// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1
134+
ControlMessageOwned::TlsGetRecordType(t) => {
135+
match t {
136+
// change_cipher_spec
137+
20 => {
138+
panic!(
139+
"received TLS change_cipher_spec, this isn't supported by ktls"
140+
)
141+
}
142+
// alert
143+
21 => {
144+
panic!("received TLS alert, this isn't supported by ktls")
145+
}
146+
// handshake
147+
22 => {
148+
// TODO: this is where we receive TLS 1.3 resumption tickets,
149+
// should those be stored anywhere? I'm not even sure what
150+
// format they have at this point
151+
tracing::trace!(
152+
"ignoring handshake message (probably a resumption ticket)"
153+
);
154+
}
155+
// application data
156+
23 => {
157+
unreachable!("received TLS application in recvmsg, this is supposed to happen in the poll_read codepath")
158+
}
159+
_ => {
160+
// just ignore the message type then
161+
tracing::trace!("received message_type {t:#?}");
162+
}
163+
}
164+
}
165+
_ => panic!("unexpected cmsg type: {cmsg:#?}"),
166+
};
167+
168+
// FIXME: this is hacky, but can we do better?
169+
// after we handled (..ignored) the control message, we don't
170+
// know whether the scoket is still ready to be read or not.
171+
//
172+
// we could try looping (tricky code structure), but we can't,
173+
// for example, just call `poll_read`, which might fail not
174+
// not with EAGAIN/EWOULDBLOCK, but because _another_ control
175+
// message is available.
176+
cx.waker().wake_by_ref();
177+
return task::Poll::Pending;
178+
}
179+
}
180+
181+
read_res
87182
}
88183
}
89184

src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ use tokio::{
1414
mod ffi;
1515
use crate::ffi::CryptoInfo;
1616

17+
mod async_read_ready;
18+
pub use async_read_ready::AsyncReadReady;
19+
1720
mod ktls_stream;
1821
pub use ktls_stream::KtlsStream;
1922

@@ -203,7 +206,7 @@ pub async fn config_ktls_server<IO>(
203206
mut stream: tokio_rustls::server::TlsStream<CorkStream<IO>>,
204207
) -> Result<KtlsStream<IO>, Error>
205208
where
206-
IO: AsRawFd + AsyncRead + AsyncWrite + Unpin,
209+
IO: AsRawFd + AsyncRead + AsyncReadReady + AsyncWrite + Unpin,
207210
{
208211
stream.get_mut().0.corked = true;
209212
let drained = drain(&mut stream).await.map_err(Error::DrainError)?;

0 commit comments

Comments
 (0)