Skip to content

Commit 1d08245

Browse files
committed
wip
1 parent 4f472ba commit 1d08245

File tree

3 files changed

+83
-79
lines changed

3 files changed

+83
-79
lines changed

ntex-tls/src/rustls/client.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ impl TlsClientFilter {
3232
cfg: Arc<ClientConfig>,
3333
domain: ServerName<'static>,
3434
) -> Result<Io<Layer<TlsClientFilter, F>>, io::Error> {
35-
let mut session = ClientConnection::new(cfg, domain).map_err(io::Error::other)?;
36-
37-
Stream::new(&mut session).handshake(&io).await?;
38-
Ok(io.add_filter(TlsClientFilter {
35+
let session = ClientConnection::new(cfg, domain).map_err(io::Error::other)?;
36+
let io = io.add_filter(TlsClientFilter {
3937
session: RefCell::new(session),
40-
}))
38+
});
39+
super::stream::handshake(&io.filter().session, &io).await?;
40+
Ok(io)
4141
}
4242
}

ntex-tls/src/rustls/server.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ impl TlsServerFilter {
4545
timeout: Millis,
4646
) -> Result<Io<Layer<TlsServerFilter, F>>, io::Error> {
4747
time::timeout(timeout, async {
48-
let mut session = ServerConnection::new(cfg).map_err(io::Error::other)?;
49-
50-
Stream::new(&mut session).handshake(&io).await?;
51-
Ok(io.add_filter(TlsServerFilter {
48+
let session = ServerConnection::new(cfg).map_err(io::Error::other)?;
49+
let io = io.add_filter(TlsServerFilter {
5250
session: RefCell::new(session),
53-
}))
51+
});
52+
53+
super::stream::handshake(&io.filter().session, &io).await?;
54+
Ok(io)
5455
})
5556
.await
5657
.map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "rustls handshake timeout"))

ntex-tls/src/rustls/stream.rs

Lines changed: 72 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::cell::RefCell;
12
use std::io::{self, Read, Write};
23
use std::{any, cmp, future::poll_fn, ops::Deref, ops::DerefMut, task::ready, task::Poll};
34

@@ -22,75 +23,6 @@ where
2223
S: DerefMut + Deref<Target = ConnectionCommon<SD>>,
2324
SD: SideData,
2425
{
25-
pub(crate) async fn handshake<F>(&mut self, io: &Io<F>) -> Result<(), io::Error> {
26-
let session = &mut self.session;
27-
28-
loop {
29-
let result = io.with_buf(|buf| {
30-
let mut wrp = Wrapper(buf);
31-
let mut result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
32-
33-
while session.wants_write() {
34-
result = session.write_tls(&mut wrp).map(|_| ());
35-
if result.is_err() {
36-
break;
37-
}
38-
}
39-
if session.wants_read() {
40-
let has_data = buf.with_read_buf(|rbuf| {
41-
rbuf.with_src(|b| {
42-
b.as_ref().map(|b| !b.is_empty()).unwrap_or_default()
43-
})
44-
});
45-
46-
if has_data {
47-
result = match session.read_tls(&mut wrp) {
48-
Ok(0) => Err(io::Error::new(
49-
io::ErrorKind::NotConnected,
50-
"disconnected",
51-
)),
52-
Ok(_) => Ok(()),
53-
Err(e) => Err(e),
54-
};
55-
56-
session.process_new_packets().map_err(|err| {
57-
// In case we have an alert to send describing this error,
58-
// try a last-gasp write -- but don't predate the primary
59-
// error.
60-
let _ = session.write_tls(&mut wrp);
61-
io::Error::new(io::ErrorKind::InvalidData, err)
62-
})?;
63-
} else {
64-
result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
65-
}
66-
}
67-
68-
Ok::<_, io::Error>(result)
69-
})??;
70-
71-
match result {
72-
Ok(()) => return Ok(()),
73-
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
74-
if !session.is_handshaking() {
75-
return Ok(());
76-
}
77-
poll_fn(|cx| {
78-
match ready!(io.poll_read_notify(cx))? {
79-
Some(_) => Ok(()),
80-
None => Err(io::Error::new(
81-
io::ErrorKind::NotConnected,
82-
"disconnected",
83-
)),
84-
}?;
85-
Poll::Ready(Ok::<_, io::Error>(()))
86-
})
87-
.await?;
88-
}
89-
Err(e) => return Err(e),
90-
}
91-
}
92-
}
93-
9426
pub(crate) fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
9527
const H2: &[u8] = b"h2";
9628

@@ -197,6 +129,77 @@ where
197129
}
198130
}
199131

132+
pub(crate) async fn handshake<F, S, SD>(
133+
session: &RefCell<S>,
134+
io: &Io<F>,
135+
) -> Result<(), io::Error>
136+
where
137+
S: DerefMut + Deref<Target = ConnectionCommon<SD>>,
138+
SD: SideData,
139+
{
140+
loop {
141+
let (result, handshaking) = io.with_buf(|buf| {
142+
let mut session = session.borrow_mut();
143+
let mut wrp = Wrapper(buf);
144+
let mut result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
145+
146+
while session.wants_write() {
147+
result = session.write_tls(&mut wrp).map(|_| ());
148+
if result.is_err() {
149+
break;
150+
}
151+
}
152+
if session.wants_read() {
153+
let has_data = buf.with_read_buf(|rbuf| {
154+
rbuf.with_src(|b| b.as_ref().map(|b| !b.is_empty()).unwrap_or_default())
155+
});
156+
157+
if has_data {
158+
result = match session.read_tls(&mut wrp) {
159+
Ok(0) => {
160+
Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected"))
161+
}
162+
Ok(_) => Ok(()),
163+
Err(e) => Err(e),
164+
};
165+
166+
session.process_new_packets().map_err(|err| {
167+
// In case we have an alert to send describing this error,
168+
// try a last-gasp write -- but don't predate the primary
169+
// error.
170+
let _ = session.write_tls(&mut wrp);
171+
io::Error::new(io::ErrorKind::InvalidData, err)
172+
})?;
173+
} else {
174+
result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
175+
}
176+
}
177+
178+
Ok::<_, io::Error>((result, session.is_handshaking()))
179+
})??;
180+
181+
match result {
182+
Ok(()) => return Ok(()),
183+
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
184+
if !handshaking {
185+
return Ok(());
186+
}
187+
poll_fn(|cx| {
188+
match ready!(io.poll_read_notify(cx))? {
189+
Some(_) => Ok(()),
190+
None => {
191+
Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected"))
192+
}
193+
}?;
194+
Poll::Ready(Ok::<_, io::Error>(()))
195+
})
196+
.await?;
197+
}
198+
Err(e) => return Err(e),
199+
}
200+
}
201+
}
202+
200203
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
201204

202205
impl io::Read for Wrapper<'_, '_> {

0 commit comments

Comments
 (0)