Skip to content

Commit bf32251

Browse files
Merge pull request #28 from hapsoc/nix-fork
feat!: Use ktls-recvmsg crate to add recvmsg fallback, closes #24
2 parents 7c377ca + 56d28b0 commit bf32251

File tree

6 files changed

+198
-59
lines changed

6 files changed

+198
-59
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Configures kTLS for tokio-rustls client and server connections.
1111
"""
1212

1313
[dependencies]
14-
libc = "0.2.146"
14+
libc = { version = "0.2.148", features = ["const-extern-fn"] }
1515
thiserror = "1.0.40"
1616
tracing = "0.1.37"
1717
tokio-rustls = "0.24.1"
@@ -22,6 +22,7 @@ pin-project-lite = "0.2.9"
2222
tokio = { version = "1.28.2", features = ["net", "macros", "io-util"] }
2323
futures = "0.3.28"
2424
ktls-sys = "1.0.0"
25+
ktls-recvmsg = { version = "0.1.1" }
2526

2627
[dev-dependencies]
2728
const-random = "0.1.15"

src/async_read_ready.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use std::{io, task};
2+
3+
pub trait AsyncReadReady {
4+
/// cf. https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.poll_read_ready
5+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>>;
6+
}
7+
8+
impl AsyncReadReady for tokio::net::TcpStream {
9+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
10+
tokio::net::TcpStream::poll_read_ready(self, cx)
11+
}
12+
}

src/cork_stream.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
use std::{
2-
io,
3-
pin::Pin,
4-
task::{Context, Poll},
5-
};
1+
use std::{io, pin::Pin, task};
62

73
use rustls::internal::msgs::codec::Codec;
84
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
95

6+
use crate::AsyncReadReady;
7+
108
enum State {
119
ReadHeader { header_buf: [u8; 5], offset: usize },
1210
ReadPayload { msg_size: usize, offset: usize },
@@ -59,9 +57,9 @@ where
5957
#[inline]
6058
fn poll_read(
6159
self: Pin<&mut Self>,
62-
cx: &mut Context<'_>,
60+
cx: &mut task::Context<'_>,
6361
buf: &mut ReadBuf<'_>,
64-
) -> Poll<io::Result<()>> {
62+
) -> task::Poll<io::Result<()>> {
6563
let this = unsafe { self.get_unchecked_mut() };
6664
let mut io = unsafe { Pin::new_unchecked(&mut this.io) };
6765

@@ -75,7 +73,7 @@ where
7573
"corked, returning empty read (but waking to prevent stalls)"
7674
);
7775
cx.waker().wake_by_ref();
78-
return Poll::Ready(Ok(()));
76+
return task::Poll::Ready(Ok(()));
7977
}
8078

8179
let left = header_buf.len() - *offset;
@@ -97,7 +95,7 @@ where
9795
buf.put_slice(&header_buf[..*offset]);
9896
*state = State::Passthrough;
9997

100-
return Poll::Ready(Ok(()));
98+
return task::Poll::Ready(Ok(()));
10199
}
102100
tracing::trace!("read {} bytes off of header", rest.filled().len());
103101
}
@@ -127,7 +125,7 @@ where
127125
}
128126
}
129127

130-
return Poll::Ready(Ok(()));
128+
return task::Poll::Ready(Ok(()));
131129
} else {
132130
// keep trying
133131
}
@@ -156,7 +154,7 @@ where
156154
let new_filled = buf.filled().len() + just_read;
157155
buf.set_filled(new_filled);
158156

159-
return Poll::Ready(Ok(()));
157+
return task::Poll::Ready(Ok(()));
160158
}
161159
State::Passthrough => {
162160
// we encountered EOF while reading, or saw an invalid header and we're just
@@ -168,28 +166,40 @@ where
168166
}
169167
}
170168

169+
impl<IO> AsyncReadReady for CorkStream<IO>
170+
where
171+
IO: AsyncReadReady,
172+
{
173+
fn poll_read_ready(&self, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
174+
self.io.poll_read_ready(cx)
175+
}
176+
}
177+
171178
impl<IO> AsyncWrite for CorkStream<IO>
172179
where
173180
IO: AsyncWrite,
174181
{
175182
#[inline]
176183
fn poll_write(
177184
self: Pin<&mut Self>,
178-
cx: &mut Context<'_>,
185+
cx: &mut task::Context<'_>,
179186
buf: &[u8],
180-
) -> Poll<io::Result<usize>> {
187+
) -> task::Poll<io::Result<usize>> {
181188
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
182189
io.poll_write(cx, buf)
183190
}
184191

185192
#[inline]
186-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193+
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
187194
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
188195
io.poll_flush(cx)
189196
}
190197

191198
#[inline]
192-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199+
fn poll_shutdown(
200+
self: Pin<&mut Self>,
201+
cx: &mut task::Context<'_>,
202+
) -> task::Poll<io::Result<()>> {
193203
let io = unsafe { self.map_unchecked_mut(|s| &mut s.io) };
194204
io.poll_shutdown(cx)
195205
}

src/ktls_stream.rs

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1-
use std::{io, os::unix::prelude::AsRawFd, pin::Pin, task};
1+
use ktls_recvmsg::{recvmsg, ControlMessageOwned, Errno, MsgFlags, SockaddrIn};
2+
use std::{
3+
io::{self, IoSliceMut},
4+
os::unix::prelude::AsRawFd,
5+
pin::Pin,
6+
task,
7+
};
28

39
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
410

11+
use crate::AsyncReadReady;
12+
513
// A wrapper around `IO` that sends a `close_notify` when shut down or dropped.
614
pin_project_lite::pin_project! {
715
pub struct KtlsStream<IO>
@@ -54,16 +62,16 @@ where
5462

5563
impl<IO> AsyncRead for KtlsStream<IO>
5664
where
57-
IO: AsRawFd + AsyncRead,
65+
IO: AsRawFd + AsyncRead + AsyncReadReady,
5866
{
5967
fn poll_read(
6068
self: Pin<&mut Self>,
6169
cx: &mut task::Context<'_>,
6270
buf: &mut ReadBuf<'_>,
6371
) -> task::Poll<io::Result<()>> {
64-
tracing::trace!(remaining = %buf.remaining(), "KtlsStream::poll_read");
72+
tracing::trace!(buf.remaining = %buf.remaining(), "KtlsStream::poll_read");
6573

66-
let this = self.project();
74+
let mut this = self.project();
6775

6876
if let Some((drain_index, drained)) = this.drained.as_mut() {
6977
let drained = &drained[*drain_index..];
@@ -79,11 +87,99 @@ where
7987
}
8088
cx.waker().wake_by_ref();
8189

90+
tracing::trace!("KtlsStream::poll_read, returning after drain");
8291
return task::Poll::Ready(Ok(()));
8392
}
8493

85-
tracing::trace!("KtlsStream::poll_read, forwarding to inner IO");
86-
this.inner.poll_read(cx, buf)
94+
let read_res = this.inner.as_mut().poll_read(cx, buf);
95+
if let task::Poll::Ready(Err(e)) = &read_res {
96+
// 5 is a generic "input/output error", it happens when
97+
// using poll_read on a kTLS socket that just received
98+
// a control message
99+
if let Some(5) = e.raw_os_error() {
100+
// could be a control message, let's check
101+
let fd = this.inner.as_raw_fd();
102+
103+
// XXX: recvmsg wants a `&mut Vec<u8>` so it's able to resize it
104+
// I guess? Or so there's a clear separation between uninitialized
105+
// and initialized? We could probably get read of that heap alloc, idk.
106+
107+
// let mut cmsgspace =
108+
// [0u8; unsafe { libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _ }];
109+
let mut cmsgspace = Vec::with_capacity(unsafe {
110+
libc::CMSG_SPACE(std::mem::size_of::<u8>() as _) as _
111+
});
112+
113+
let mut iov = [IoSliceMut::new(buf.initialize_unfilled())];
114+
let flags = MsgFlags::empty();
115+
116+
let r = recvmsg::<SockaddrIn>(fd, &mut iov, Some(&mut cmsgspace), flags);
117+
let r = match r {
118+
Ok(r) => r,
119+
Err(Errno::EAGAIN) => {
120+
unreachable!("expected a control message, got EAGAIN")
121+
}
122+
Err(e) => {
123+
// ok I guess it really failed then
124+
tracing::trace!(?e, "recvmsg failed");
125+
return Err(e.into()).into();
126+
}
127+
};
128+
let cmsg = r
129+
.cmsgs()
130+
.next()
131+
.expect("we should've received exactly one control message");
132+
match cmsg {
133+
// cf. RFC 5246, Section 6.2.1
134+
// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1
135+
ControlMessageOwned::TlsGetRecordType(t) => {
136+
match t {
137+
// change_cipher_spec
138+
20 => {
139+
panic!(
140+
"received TLS change_cipher_spec, this isn't supported by ktls"
141+
)
142+
}
143+
// alert
144+
21 => {
145+
panic!("received TLS alert, this isn't supported by ktls")
146+
}
147+
// handshake
148+
22 => {
149+
// TODO: this is where we receive TLS 1.3 resumption tickets,
150+
// should those be stored anywhere? I'm not even sure what
151+
// format they have at this point
152+
tracing::trace!(
153+
"ignoring handshake message (probably a resumption ticket)"
154+
);
155+
}
156+
// application data
157+
23 => {
158+
unreachable!("received TLS application in recvmsg, this is supposed to happen in the poll_read codepath")
159+
}
160+
_ => {
161+
// just ignore the message type then
162+
tracing::trace!("received message_type {t:#?}");
163+
}
164+
}
165+
}
166+
_ => panic!("unexpected cmsg type: {cmsg:#?}"),
167+
};
168+
169+
// FIXME: this is hacky, but can we do better?
170+
// after we handled (..ignored) the control message, we don't
171+
// know whether the scoket is still ready to be read or not.
172+
//
173+
// we could try looping (tricky code structure), but we can't,
174+
// for example, just call `poll_read`, which might fail not
175+
// not with EAGAIN/EWOULDBLOCK, but because _another_ control
176+
// message is available.
177+
cx.waker().wake_by_ref();
178+
return task::Poll::Pending;
179+
}
180+
}
181+
182+
read_res
87183
}
88184
}
89185

src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ use tokio::{
1414
mod ffi;
1515
use crate::ffi::CryptoInfo;
1616

17+
mod async_read_ready;
18+
pub use async_read_ready::AsyncReadReady;
19+
1720
mod ktls_stream;
1821
pub use ktls_stream::KtlsStream;
1922

@@ -203,7 +206,7 @@ pub async fn config_ktls_server<IO>(
203206
mut stream: tokio_rustls::server::TlsStream<CorkStream<IO>>,
204207
) -> Result<KtlsStream<IO>, Error>
205208
where
206-
IO: AsRawFd + AsyncRead + AsyncWrite + Unpin,
209+
IO: AsRawFd + AsyncRead + AsyncReadReady + AsyncWrite + Unpin,
207210
{
208211
stream.get_mut().0.corked = true;
209212
let drained = drain(&mut stream).await.map_err(Error::DrainError)?;

0 commit comments

Comments
 (0)