Skip to content

Commit 6846acd

Browse files
dimaestrookanarus
andauthored
tls: basic tls implementation with tokio-rustls (tokio-only!) (#441)
* enhance(Dir): support conditional fetch with `If-None-Match` `If-Modified-Since` * next: test conditional fetch impl * tls: implement basic tls implementation with tokio-rustls (tokio-only!) * fix(session): fix AsyncRead & AsyncWrite imports for session * fix(session): fix unsupported WS over TLS & cfg handling * fix(session): fix trait bounds when ws feature is disabled * fix(session): omit Send trait to unbreak builds on other runtimes (hopefully) * fix(etag): merge issues fix --------- Co-authored-by: kanarus <kanarus786@gmail.com>
1 parent 9187715 commit 6846acd

6 files changed

Lines changed: 281 additions & 30 deletions

File tree

ohkami/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ ctrlc = { version = "3.4", optional = true }
4747
num_cpus = { version = "1.16", optional = true }
4848
futures-util = { version = "0.3", optional = true, default-features = false }
4949
mews = { version = "0.2", optional = true }
50+
rustls = { version = "0.23.23", optional = true }
51+
tokio-rustls = { version = "0.26.2", optional = true }
5052

5153

5254
[features]
@@ -89,6 +91,7 @@ nightly = []
8991
openapi = ["dep:ohkami_openapi", "ohkami_macros/openapi"]
9092
sse = ["ohkami_lib/stream"]
9193
ws = ["ohkami_lib/stream", "dep:mews"]
94+
tls = ["__rt_native__", "rt_tokio", "dep:rustls", "dep:tokio-rustls"]
9295

9396
##### internal #####
9497
__rt__ = []

ohkami/src/header/etag.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,4 @@ impl ETag<'static> {
123123
.then_some(Self::Strong(value.into()))
124124
.ok_or(ETagError::InvalidCharactor)
125125
}
126-
}
126+
}

ohkami/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ pub use ohkami::{Ohkami, Route};
218218
pub mod fang;
219219
pub use fang::{handler, Fang, FangProc};
220220

221+
#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
222+
mod tls;
223+
221224
pub mod format;
222225

223226
pub mod header;

ohkami/src/ohkami/mod.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ use std::sync::Arc;
1414
#[cfg(feature="__rt_native__")]
1515
use crate::{__rt__, Session};
1616

17+
#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
18+
use tokio_rustls::TlsAcceptor;
19+
20+
#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
21+
use crate::tls::TlsStream;
22+
1723
/// # Ohkami - a smart wolf who serves your web app
1824
///
1925
/// ## Definition
@@ -589,6 +595,100 @@ impl Ohkami {
589595
wg.await;
590596
}
591597

598+
#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
599+
/// Bind this `Ohkami` to an address and start serving with TLS/HTTPS support!
600+
///
601+
/// This method works like `howl` but upgrades connections to HTTPS using the provided
602+
/// rustls configuration. This functionality is only available with the `rt_tokio` feature.
603+
///
604+
/// ### Parameters
605+
///
606+
/// - `bind`: Same as `howl`, can be a socket address or TcpListener
607+
/// - `tls_config`: A rustls server configuration containing your certificates and keys
608+
///
609+
/// ### Example
610+
///
611+
/// ```no_run
612+
/// use ohkami::prelude::*;
613+
/// use rustls::{ServerConfig, Certificate, PrivateKey};
614+
/// use std::fs::File;
615+
/// use std::io::BufReader;
616+
///
617+
/// async fn hello() -> &'static str {
618+
/// "Hello, secure ohkami!"
619+
/// }
620+
///
621+
/// #[tokio::main]
622+
/// async fn main() -> std::io::Result<()> {
623+
/// // Initialize rustls crypto provider
624+
/// match rustls::crypto::ring::default_provider().install_default() {
625+
// Ok(_) => println!("Successfully installed rustls crypto provider"),
626+
// Err(e) => {
627+
// eprintln!("Failed to install rustls crypto provider: {:?}", e);
628+
// std::process::exit(1);
629+
// }
630+
// }
631+
/// // Load certificates and private key
632+
/// let cert_file = File::open("path/to/cert.pem")?;
633+
/// let key_file = File::open("path/to/key.pem")?;
634+
///
635+
/// let cert_chain = rustls_pemfile::certs(&mut BufReader::new(cert_file))
636+
/// .map(|certs| certs.into_iter().map(Certificate).collect())
637+
/// .unwrap_or_default();
638+
///
639+
/// let key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key_file))
640+
/// .next()
641+
/// .map(|key| PrivateKey(key))
642+
/// .expect("Failed to load private key");
643+
///
644+
/// // Build TLS configuration
645+
/// let tls_config = ServerConfig::builder()
646+
/// .with_safe_defaults()
647+
/// .with_no_client_auth()
648+
/// .with_single_cert(cert_chain, key)
649+
/// .expect("Failed to build TLS configuration");
650+
///
651+
/// // Create and run Ohkami with HTTPS
652+
/// Ohkami::new((
653+
/// "/".GET(hello),
654+
/// )).howl_tls("0.0.0.0:8443", tls_config).await;
655+
///
656+
/// Ok(())
657+
/// }
658+
/// ```
659+
pub async fn howl_tls<T>(self, bind: impl __rt__::IntoTcpListener<T>, tls_config: rustls::ServerConfig) {
660+
let (router, _) = self.into_router().finalize();
661+
let router = Arc::new(router);
662+
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
663+
664+
let listener = bind.ino_tcp_listener().await;
665+
let (wg, ctrl_c) = (sync::WaitGroup::new(), sync::CtrlC::new());
666+
667+
while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await {
668+
let Ok((tcp_stream, addr)) = accept else { continue };
669+
670+
let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else { continue };
671+
672+
let session = Session::new(
673+
router.clone(),
674+
TlsStream(tls_stream),
675+
addr.ip()
676+
);
677+
678+
let wg = wg.add();
679+
tokio::spawn(async move {
680+
session.manage().await;
681+
wg.done();
682+
});
683+
}
684+
685+
crate::DEBUG!("interrupted, trying graceful shutdown...");
686+
drop(listener);
687+
688+
crate::DEBUG!("waiting {} session(s) to finish...", wg.count());
689+
wg.await;
690+
}
691+
592692
#[cfg(feature="rt_worker")]
593693
#[doc(hidden)]
594694
pub async fn __worker__(self,

ohkami/src/session/mod.rs

Lines changed: 122 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,42 @@
22

33
use std::{any::Any, pin::Pin, sync::Arc, time::Duration};
44
use std::panic::{AssertUnwindSafe, catch_unwind};
5-
use crate::__rt__::TcpStream;
5+
use crate::__rt__::{AsyncRead, AsyncWrite};
66
use crate::response::Upgrade;
77
use crate::util::timeout_in;
88
use crate::router::r#final::Router;
99
use crate::{Request, Response};
1010

11-
pub(crate) struct Session {
12-
router: Arc<Router>,
13-
connection: TcpStream,
14-
ip: std::net::IpAddr,
11+
#[cfg(feature="ws")]
12+
use crate::__rt__::TcpStream;
13+
14+
pub(crate) struct Session<S> {
15+
router: Arc<Router>,
16+
connection: S,
17+
ip: std::net::IpAddr,
18+
}
19+
20+
#[cfg(feature="ws")]
21+
pub(crate) trait WebSocketUpgradeable {
22+
fn into_websocket_stream(self) -> Result<TcpStream, &'static str>;
1523
}
1624

17-
impl Session {
25+
#[cfg(feature="ws")]
26+
impl WebSocketUpgradeable for TcpStream {
27+
fn into_websocket_stream(self) -> Result<TcpStream, &'static str> {
28+
Ok(self)
29+
}
30+
}
31+
32+
#[cfg(feature="ws")]
33+
impl<S> Session<S>
34+
where
35+
S: AsyncRead + AsyncWrite + Unpin + WebSocketUpgradeable,
36+
{
1837
pub(crate) fn new(
19-
router: Arc<Router>,
20-
connection: TcpStream,
21-
ip: std::net::IpAddr
38+
router: Arc<Router>,
39+
connection: S,
40+
ip: std::net::IpAddr
2241
) -> Self {
2342
Self {
2443
router,
@@ -31,11 +50,11 @@ impl Session {
3150
#[cold] #[inline(never)]
3251
fn panicking(panic: Box<dyn Any + Send>) -> Response {
3352
if let Some(msg) = panic.downcast_ref::<String>() {
34-
crate::WARNING!("panic: {msg}");
53+
crate::WARNING!("[Panicked]: {msg}");
3554
} else if let Some(msg) = panic.downcast_ref::<&str>() {
36-
crate::WARNING!("panic: {msg}");
55+
crate::WARNING!("[Panicked]: {msg}");
3756
} else {
38-
crate::WARNING!("panic");
57+
crate::WARNING!("[Panicked]");
3958
}
4059
crate::Response::InternalServerError()
4160
}
@@ -66,33 +85,107 @@ impl Session {
6685
}
6786
}
6887
}).await {
69-
None => crate::WARNING!("\
88+
None => crate::WARNING!("[WARNING] \
7089
Session timeouted. In Ohkami, Keep-Alive timeout \
7190
is set to 42 seconds by default and is configurable \
7291
by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\
7392
"),
7493

7594
Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"),
7695

77-
#[cfg(feature="ws")]
7896
Some(Upgrade::WebSocket(ws)) => {
79-
crate::DEBUG!("WebSocket session started");
80-
81-
let aborted = ws.manage_with_timeout(
82-
Duration::from_secs(crate::CONFIG.websocket_timeout()),
83-
self.connection
84-
).await;
85-
if aborted {
86-
crate::WARNING!("\
87-
WebSocket session aborted by timeout. In Ohkami, \
88-
WebSocket timeout is set to 3600 seconds (1 hour) \
89-
by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \
90-
environment variable.\
91-
");
92-
}
97+
match self.connection.into_websocket_stream() {
98+
Ok(tcp_stream) => {
99+
crate::DEBUG!("WebSocket session started");
93100

94-
crate::DEBUG!("WebSocket session finished");
101+
let aborted = ws.manage_with_timeout(
102+
Duration::from_secs(crate::CONFIG.websocket_timeout()),
103+
tcp_stream
104+
).await;
105+
if aborted {
106+
crate::WARNING!("[WARNING] \
107+
WebSocket session aborted by timeout. In Ohkami, \
108+
WebSocket timeout is set to 3600 seconds (1 hour) \
109+
by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \
110+
environment variable.\
111+
");
112+
}
113+
114+
crate::DEBUG!("WebSocket session finished");
115+
}
116+
Err(msg) => {
117+
crate::WARNING!("[WARNING] {}", msg);
118+
}
119+
}
95120
}
96121
}
97122
}
98123
}
124+
125+
// There has to be some cleaner implementation to apply the conditional trait bounds in this...
126+
#[cfg(not(feature="ws"))]
127+
impl<S> Session<S>
128+
where
129+
S: AsyncRead + AsyncWrite + Unpin,
130+
{
131+
pub(crate) fn new(
132+
router: Arc<Router>,
133+
connection: S,
134+
ip: std::net::IpAddr
135+
) -> Self {
136+
Self {
137+
router,
138+
connection,
139+
ip
140+
}
141+
}
142+
143+
pub(crate) async fn manage(mut self) {
144+
#[cold] #[inline(never)]
145+
fn panicking(panic: Box<dyn Any + Send>) -> Response {
146+
if let Some(msg) = panic.downcast_ref::<String>() {
147+
crate::WARNING!("[Panicked]: {msg}");
148+
} else if let Some(msg) = panic.downcast_ref::<&str>() {
149+
crate::WARNING!("[Panicked]: {msg}");
150+
} else {
151+
crate::WARNING!("[Panicked]");
152+
}
153+
crate::Response::InternalServerError()
154+
}
155+
156+
match timeout_in(Duration::from_secs(crate::CONFIG.keepalive_timeout()), async {
157+
let mut req = Request::init(self.ip);
158+
let mut req = unsafe {Pin::new_unchecked(&mut req)};
159+
loop {
160+
req.clear();
161+
match req.as_mut().read(&mut self.connection).await {
162+
Ok(Some(())) => {
163+
let close = matches!(req.headers.Connection(), Some("close" | "Close"));
164+
165+
let res = match catch_unwind(AssertUnwindSafe({
166+
let req = req.as_mut();
167+
|| self.router.handle(req.get_mut())
168+
})) {
169+
Ok(future) => future.await,
170+
Err(panic) => panicking(panic),
171+
};
172+
let upgrade = res.send(&mut self.connection).await;
173+
174+
if !upgrade.is_none() {break upgrade}
175+
if close {break Upgrade::None}
176+
}
177+
Ok(None) => break Upgrade::None,
178+
Err(res) => {res.send(&mut self.connection).await;},
179+
}
180+
}
181+
}).await {
182+
None => crate::WARNING!("[WARNING] \
183+
Session timeouted. In Ohkami, Keep-Alive timeout \
184+
is set to 42 seconds by default and is configurable \
185+
by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\
186+
"),
187+
188+
Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"),
189+
}
190+
}
191+
}

ohkami/src/tls/mod.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use tokio::io::{AsyncRead, AsyncWrite};
2+
pub struct TlsStream(pub tokio_rustls::server::TlsStream<tokio::net::TcpStream>);
3+
4+
#[cfg(feature="ws")]
5+
impl crate::session::WebSocketUpgradeable for TlsStream {
6+
fn into_websocket_stream(self) -> Result<crate::__rt__::TcpStream, &'static str> {
7+
Err("WebSocket connections are not supported over TLS yet")
8+
}
9+
}
10+
11+
impl AsyncRead for TlsStream {
12+
fn poll_read(
13+
mut self: std::pin::Pin<&mut Self>,
14+
cx: &mut std::task::Context<'_>,
15+
buf: &mut tokio::io::ReadBuf<'_>
16+
) -> std::task::Poll<std::io::Result<()>> {
17+
match std::pin::Pin::new(&mut self.0).poll_read(cx, buf) {
18+
std::task::Poll::Ready(Err(e)) => {
19+
if e.to_string().contains("close_notify") {
20+
std::task::Poll::Ready(Ok(()))
21+
} else {
22+
std::task::Poll::Ready(Err(e))
23+
}
24+
},
25+
other => other,
26+
}
27+
}
28+
}
29+
30+
impl AsyncWrite for TlsStream {
31+
fn poll_write(
32+
mut self: std::pin::Pin<&mut Self>,
33+
cx: &mut std::task::Context<'_>,
34+
buf: &[u8]
35+
) -> std::task::Poll<std::io::Result<usize>> {
36+
std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
37+
}
38+
39+
fn poll_flush(
40+
mut self: std::pin::Pin<&mut Self>,
41+
cx: &mut std::task::Context<'_>
42+
) -> std::task::Poll<std::io::Result<()>> {
43+
std::pin::Pin::new(&mut self.0).poll_flush(cx)
44+
}
45+
46+
fn poll_shutdown(
47+
mut self: std::pin::Pin<&mut Self>,
48+
cx: &mut std::task::Context<'_>
49+
) -> std::task::Poll<std::io::Result<()>> {
50+
std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
51+
}
52+
}

0 commit comments

Comments
 (0)