Skip to content

Commit 4d3d94f

Browse files
boquan-fangBoquan Fang
andauthored
feat(s2n-quic-dc): implement dcQUIC's server, client, and io providers (#2752)
Co-authored-by: Boquan Fang <[email protected]>
1 parent a2ad54e commit 4d3d94f

File tree

8 files changed

+1044
-0
lines changed

8 files changed

+1044
-0
lines changed

dc/s2n-quic-dc/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,14 @@ pin-project-lite = "0.2"
4545
rand = { version = "0.9", features = ["small_rng"] }
4646
rand_chacha = "0.9"
4747
s2n-codec = { version = "=0.63.0", path = "../../common/s2n-codec", default-features = false }
48+
s2n-quic = { version = "=1.63.0", path = "../../quic/s2n-quic", features = ["unstable-provider-dc"] }
4849
s2n-quic-core = { version = "=0.63.0", path = "../../quic/s2n-quic-core", default-features = false }
4950
s2n-quic-platform = { version = "=0.63.0", path = "../../quic/s2n-quic-platform" }
5051
slotmap = "1"
5152
hashbrown = "0.15"
5253
thiserror = "2"
5354
tokio = { version = "1", default-features = false, features = ["sync"] }
55+
tokio-util = "0.7"
5456
tracing = "0.1"
5557
tracing-subscriber = { version = "0.3", features = [
5658
"env-filter",

dc/s2n-quic-dc/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub mod msg;
1414
pub mod packet;
1515
pub mod path;
1616
pub mod pool;
17+
pub mod psk;
1718
pub mod random;
1819
pub mod recovery;
1920
pub mod socket;

dc/s2n-quic-dc/src/psk.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
pub mod client;
5+
pub mod io;
6+
pub mod server;

dc/s2n-quic-dc/src/psk/client.rs

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use super::io::{self, HandshakeFailed};
5+
use crate::path::secret;
6+
use s2n_quic::{
7+
provider::{event::Subscriber as Sub, tls::Provider as Prov},
8+
Connection,
9+
};
10+
use std::{net::SocketAddr, sync::Arc, time::Duration};
11+
use tokio::runtime::Runtime;
12+
use tokio_util::sync::DropGuard;
13+
14+
mod builder;
15+
16+
pub use crate::path::secret::HandshakeKind;
17+
pub use builder::Builder;
18+
19+
#[derive(Clone)]
20+
pub struct Provider {
21+
state: Arc<State>,
22+
}
23+
24+
struct State {
25+
// This is always present in production, but for testing purposes we sometimes run within the
26+
// deterministic simulation framework. In that case there's no runtime for us to push work
27+
// into.
28+
runtime: Option<(Arc<Runtime>, DropGuard)>,
29+
map: secret::Map,
30+
client: io::Client,
31+
local_addr: SocketAddr,
32+
}
33+
34+
fn make_runtime() -> (Arc<Runtime>, DropGuard) {
35+
let runtime = Arc::new(
36+
tokio::runtime::Builder::new_current_thread()
37+
.enable_all()
38+
.build()
39+
.unwrap(),
40+
);
41+
42+
let token = tokio_util::sync::CancellationToken::new();
43+
let cancelled = token.clone().cancelled_owned();
44+
let rt = runtime.clone();
45+
std::thread::Builder::new()
46+
.name(String::from("hs-client"))
47+
.spawn(move || {
48+
rt.block_on(cancelled);
49+
})
50+
.unwrap();
51+
52+
(runtime, token.drop_guard())
53+
}
54+
55+
impl State {
56+
fn new_runtime<
57+
Provider: Prov + Clone + Send + Sync + 'static,
58+
Subscriber: Sub + Send + Sync + 'static,
59+
Event: s2n_quic::provider::event::Subscriber,
60+
>(
61+
addr: SocketAddr,
62+
map: secret::Map,
63+
tls_materials_provider: Provider,
64+
subscriber: Subscriber,
65+
builder: Builder<Event>,
66+
) -> io::Result<Self> {
67+
let (runtime, rt_guard) = make_runtime();
68+
let guard = runtime.enter();
69+
let client = io::Client::bind::<Provider, Subscriber, Event>(
70+
addr,
71+
map.clone(),
72+
tls_materials_provider,
73+
subscriber,
74+
builder,
75+
)?;
76+
drop(guard);
77+
78+
Ok(Self {
79+
map,
80+
runtime: Some((runtime, rt_guard)),
81+
local_addr: client.local_addr()?,
82+
client,
83+
})
84+
}
85+
}
86+
87+
impl Provider {
88+
/// Returns a [`Builder`] which is able to configure the [`Provider`]
89+
pub fn builder() -> Builder<impl s2n_quic::provider::event::Subscriber> {
90+
Builder::default()
91+
}
92+
93+
pub fn new<
94+
Provider: Prov + Clone + Send + Sync + 'static,
95+
Subscriber: Sub + Send + Sync + 'static,
96+
Event: s2n_quic::provider::event::Subscriber,
97+
>(
98+
addr: SocketAddr,
99+
map: secret::Map,
100+
tls_materials_provider: Provider,
101+
subscriber: Subscriber,
102+
query_event_callback: fn(&mut Connection, Duration),
103+
builder: Builder<Event>,
104+
) -> io::Result<Self> {
105+
let state = State::new_runtime(
106+
addr,
107+
map.clone(),
108+
tls_materials_provider,
109+
subscriber,
110+
builder,
111+
)?;
112+
let state = Arc::new(state);
113+
114+
// Avoid holding onto the state unintentionally after it's no longer needed.
115+
let weak = Arc::downgrade(&state);
116+
map.register_request_handshake(Box::new(move |peer| {
117+
if let Some(state) = weak.upgrade() {
118+
let runtime = state.runtime.as_ref().map(|v| &v.0).unwrap();
119+
let client = state.client.clone();
120+
// Drop the JoinHandle -- we're not actually going to block on the join handle's
121+
// result. The future will keep running in the background.
122+
runtime.spawn(async move {
123+
if let Err(HandshakeFailed { .. }) =
124+
client.connect(peer, query_event_callback).await
125+
{
126+
// failure has already been logged, no further action required.
127+
}
128+
});
129+
}
130+
}));
131+
132+
Ok(Self { state })
133+
}
134+
135+
/// Handshake asynchronously with a peer.
136+
///
137+
/// This method can be called with any async runtime.
138+
#[inline]
139+
pub async fn handshake_with(
140+
&self,
141+
peer: SocketAddr,
142+
query_event_callback: fn(&mut Connection, Duration),
143+
) -> std::io::Result<HandshakeKind> {
144+
let (_peer, kind) = self
145+
.handshake_with_entry(peer, query_event_callback)
146+
.await?;
147+
Ok(kind)
148+
}
149+
150+
/// Handshake asynchronously with a peer, returning an entry for secret derivation
151+
///
152+
/// This method can be called with any async runtime.
153+
#[inline]
154+
#[doc(hidden)]
155+
pub async fn handshake_with_entry(
156+
&self,
157+
peer: SocketAddr,
158+
query_event_callback: fn(&mut Connection, Duration),
159+
) -> std::io::Result<(secret::map::Peer, HandshakeKind)> {
160+
// Unconditionally request a background handshake. This schedules any re-handshaking
161+
// needed.
162+
if self.state.runtime.is_some() {
163+
let _ = self.background_handshake_with(peer, query_event_callback);
164+
}
165+
166+
if let Some(peer) = self.state.map.get_tracked(peer) {
167+
return Ok((peer, HandshakeKind::Cached));
168+
}
169+
170+
let state = self.state.clone();
171+
if let Some((runtime, _)) = self.state.runtime.as_ref() {
172+
runtime
173+
.spawn(async move { state.client.connect(peer, query_event_callback).await })
174+
.await??;
175+
} else {
176+
state.client.connect(peer, query_event_callback).await?;
177+
}
178+
179+
// already recorded a metric above in get_tracked.
180+
let peer = self.state.map.get_untracked(peer).ok_or_else(|| {
181+
std::io::Error::new(
182+
std::io::ErrorKind::NotFound,
183+
format!("handshake failed to exchange credentials for {peer}"),
184+
)
185+
})?;
186+
187+
Ok((peer, HandshakeKind::Fresh))
188+
}
189+
190+
/// Handshake with a peer in the background.∂
191+
#[inline]
192+
pub fn background_handshake_with(
193+
&self,
194+
peer: SocketAddr,
195+
query_event_callback: fn(&mut Connection, Duration),
196+
) -> std::io::Result<HandshakeKind> {
197+
if self.state.map.contains(&peer) {
198+
return Ok(HandshakeKind::Cached);
199+
}
200+
201+
let client = self.state.client.clone();
202+
if let Some((runtime, _)) = self.state.runtime.as_ref() {
203+
// Drop the JoinHandle -- we're not actually going to block on the join handle's
204+
// result. The future will keep running in the background.
205+
runtime.spawn(async move {
206+
if let Err(HandshakeFailed { .. }) =
207+
client.connect(peer, query_event_callback).await
208+
{
209+
// error already logged
210+
}
211+
});
212+
} else {
213+
panic!("background_handshake_with not supported with deterministic testing");
214+
}
215+
216+
// Technically this might not be true (the handshake may get deduplicated), but it's close
217+
// enough to accurate that we're OK claiming it's true.
218+
Ok(HandshakeKind::Fresh)
219+
}
220+
221+
/// Handshake synchronously with a peer.
222+
///
223+
/// This method will block the calling thread and will panic if called from within a Tokio
224+
/// runtime.
225+
// We duplicate the implementation of this method with handshake_with so that we preserve the fast
226+
// path (not interacting with the runtime at all) for cached handshakes.
227+
#[inline]
228+
pub fn blocking_handshake_with(
229+
&self,
230+
peer: SocketAddr,
231+
query_event_callback: fn(&mut Connection, Duration),
232+
) -> std::io::Result<HandshakeKind> {
233+
// Unconditionally request a background handshake. This schedules any re-handshaking
234+
// needed.
235+
if self.state.runtime.is_some() {
236+
let _ = self.background_handshake_with(peer, query_event_callback);
237+
}
238+
239+
if self.state.map.contains(&peer) {
240+
return Ok(HandshakeKind::Cached);
241+
}
242+
243+
let fut = self.state.client.connect(peer, query_event_callback);
244+
if let Some((runtime, _)) = self.state.runtime.as_ref() {
245+
runtime.block_on(fut)?
246+
} else {
247+
panic!("blocking_handshake_with not supported with deterministic testing");
248+
}
249+
250+
debug_assert!(self.state.map.contains(&peer));
251+
252+
Ok(HandshakeKind::Fresh)
253+
}
254+
255+
/// This forces a handshake with the given peer, ignoring whether there's already an entry or
256+
/// not.
257+
#[inline]
258+
#[doc(hidden)]
259+
pub async fn unconditionally_handshake_with_entry(
260+
&self,
261+
peer: SocketAddr,
262+
query_event_callback: fn(&mut Connection, Duration),
263+
) -> std::io::Result<secret::map::Peer> {
264+
let state = self.state.clone();
265+
if let Some((runtime, _)) = self.state.runtime.as_ref() {
266+
runtime
267+
.spawn(async move { state.client.connect(peer, query_event_callback).await })
268+
.await??;
269+
} else {
270+
return Err(std::io::Error::new(
271+
std::io::ErrorKind::InvalidInput,
272+
"missing runtime for handshake client",
273+
));
274+
}
275+
276+
// Don't bother recording metrics on access.
277+
let peer = self.state.map.get_untracked(peer).ok_or_else(|| {
278+
std::io::Error::new(
279+
std::io::ErrorKind::NotFound,
280+
format!("handshake failed to exchange credentials for {peer}"),
281+
)
282+
})?;
283+
284+
Ok(peer)
285+
}
286+
287+
// FIXME: Remove Result (breaking change)
288+
#[inline]
289+
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
290+
Ok(self.state.local_addr)
291+
}
292+
293+
pub fn map(&self) -> &secret::Map {
294+
&self.state.map
295+
}
296+
}

0 commit comments

Comments
 (0)