diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 641c9718f..900cbce23 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -47,6 +47,8 @@ ctrlc = { version = "3.4", optional = true } num_cpus = { version = "1.16", optional = true } futures-util = { version = "0.3", optional = true, default-features = false } mews = { version = "0.2", optional = true } +rustls = { version = "0.23.23", optional = true } +tokio-rustls = { version = "0.26.2", optional = true } [features] @@ -89,6 +91,7 @@ nightly = [] openapi = ["dep:ohkami_openapi", "ohkami_macros/openapi"] sse = ["ohkami_lib/stream"] ws = ["ohkami_lib/stream", "dep:mews"] +tls = ["__rt_native__", "rt_tokio", "dep:rustls", "dep:tokio-rustls"] ##### internal ##### __rt__ = [] diff --git a/ohkami/src/header/etag.rs b/ohkami/src/header/etag.rs index bcbf20d69..d0056efbe 100644 --- a/ohkami/src/header/etag.rs +++ b/ohkami/src/header/etag.rs @@ -123,4 +123,4 @@ impl ETag<'static> { .then_some(Self::Strong(value.into())) .ok_or(ETagError::InvalidCharactor) } -} +} \ No newline at end of file diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 56489c804..5debb0347 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -218,6 +218,9 @@ pub use ohkami::{Ohkami, Route}; pub mod fang; pub use fang::{handler, Fang, FangProc}; +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +mod tls; + pub mod format; pub mod header; diff --git a/ohkami/src/ohkami/mod.rs b/ohkami/src/ohkami/mod.rs index 3296fcc86..3eab6a751 100644 --- a/ohkami/src/ohkami/mod.rs +++ b/ohkami/src/ohkami/mod.rs @@ -14,6 +14,12 @@ use std::sync::Arc; #[cfg(feature="__rt_native__")] use crate::{__rt__, Session}; +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +use tokio_rustls::TlsAcceptor; + +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +use crate::tls::TlsStream; + /// # Ohkami - a smart wolf who serves your web app /// /// ## Definition @@ -589,6 +595,100 @@ impl Ohkami { wg.await; } + #[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] + /// Bind this `Ohkami` to an address and start serving with TLS/HTTPS support! + /// + /// This method works like `howl` but upgrades connections to HTTPS using the provided + /// rustls configuration. This functionality is only available with the `rt_tokio` feature. + /// + /// ### Parameters + /// + /// - `bind`: Same as `howl`, can be a socket address or TcpListener + /// - `tls_config`: A rustls server configuration containing your certificates and keys + /// + /// ### Example + /// + /// ```no_run + /// use ohkami::prelude::*; + /// use rustls::{ServerConfig, Certificate, PrivateKey}; + /// use std::fs::File; + /// use std::io::BufReader; + /// + /// async fn hello() -> &'static str { + /// "Hello, secure ohkami!" + /// } + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// // Initialize rustls crypto provider + /// match rustls::crypto::ring::default_provider().install_default() { + // Ok(_) => println!("Successfully installed rustls crypto provider"), + // Err(e) => { + // eprintln!("Failed to install rustls crypto provider: {:?}", e); + // std::process::exit(1); + // } + // } + /// // Load certificates and private key + /// let cert_file = File::open("path/to/cert.pem")?; + /// let key_file = File::open("path/to/key.pem")?; + /// + /// let cert_chain = rustls_pemfile::certs(&mut BufReader::new(cert_file)) + /// .map(|certs| certs.into_iter().map(Certificate).collect()) + /// .unwrap_or_default(); + /// + /// let key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key_file)) + /// .next() + /// .map(|key| PrivateKey(key)) + /// .expect("Failed to load private key"); + /// + /// // Build TLS configuration + /// let tls_config = ServerConfig::builder() + /// .with_safe_defaults() + /// .with_no_client_auth() + /// .with_single_cert(cert_chain, key) + /// .expect("Failed to build TLS configuration"); + /// + /// // Create and run Ohkami with HTTPS + /// Ohkami::new(( + /// "/".GET(hello), + /// )).howl_tls("0.0.0.0:8443", tls_config).await; + /// + /// Ok(()) + /// } + /// ``` + pub async fn howl_tls(self, bind: impl __rt__::IntoTcpListener, tls_config: rustls::ServerConfig) { + let (router, _) = self.into_router().finalize(); + let router = Arc::new(router); + let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config)); + + let listener = bind.ino_tcp_listener().await; + let (wg, ctrl_c) = (sync::WaitGroup::new(), sync::CtrlC::new()); + + while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await { + let Ok((tcp_stream, addr)) = accept else { continue }; + + let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else { continue }; + + let session = Session::new( + router.clone(), + TlsStream(tls_stream), + addr.ip() + ); + + let wg = wg.add(); + tokio::spawn(async move { + session.manage().await; + wg.done(); + }); + } + + crate::DEBUG!("interrupted, trying graceful shutdown..."); + drop(listener); + + crate::DEBUG!("waiting {} session(s) to finish...", wg.count()); + wg.await; + } + #[cfg(feature="rt_worker")] #[doc(hidden)] pub async fn __worker__(self, diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 4de4d3d73..474fb2677 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -2,23 +2,42 @@ use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; use std::panic::{AssertUnwindSafe, catch_unwind}; -use crate::__rt__::TcpStream; +use crate::__rt__::{AsyncRead, AsyncWrite}; use crate::response::Upgrade; use crate::util::timeout_in; use crate::router::r#final::Router; use crate::{Request, Response}; -pub(crate) struct Session { - router: Arc, - connection: TcpStream, - ip: std::net::IpAddr, +#[cfg(feature="ws")] +use crate::__rt__::TcpStream; + +pub(crate) struct Session { + router: Arc, + connection: S, + ip: std::net::IpAddr, +} + +#[cfg(feature="ws")] +pub(crate) trait WebSocketUpgradeable { + fn into_websocket_stream(self) -> Result; } -impl Session { +#[cfg(feature="ws")] +impl WebSocketUpgradeable for TcpStream { + fn into_websocket_stream(self) -> Result { + Ok(self) + } +} + +#[cfg(feature="ws")] +impl Session +where + S: AsyncRead + AsyncWrite + Unpin + WebSocketUpgradeable, +{ pub(crate) fn new( - router: Arc, - connection: TcpStream, - ip: std::net::IpAddr + router: Arc, + connection: S, + ip: std::net::IpAddr ) -> Self { Self { router, @@ -31,11 +50,11 @@ impl Session { #[cold] #[inline(never)] fn panicking(panic: Box) -> Response { if let Some(msg) = panic.downcast_ref::() { - crate::WARNING!("panic: {msg}"); + crate::WARNING!("[Panicked]: {msg}"); } else if let Some(msg) = panic.downcast_ref::<&str>() { - crate::WARNING!("panic: {msg}"); + crate::WARNING!("[Panicked]: {msg}"); } else { - crate::WARNING!("panic"); + crate::WARNING!("[Panicked]"); } crate::Response::InternalServerError() } @@ -66,7 +85,7 @@ impl Session { } } }).await { - None => crate::WARNING!("\ + None => crate::WARNING!("[WARNING] \ Session timeouted. In Ohkami, Keep-Alive timeout \ is set to 42 seconds by default and is configurable \ by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\ @@ -74,25 +93,99 @@ impl Session { Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"), - #[cfg(feature="ws")] Some(Upgrade::WebSocket(ws)) => { - crate::DEBUG!("WebSocket session started"); - - let aborted = ws.manage_with_timeout( - Duration::from_secs(crate::CONFIG.websocket_timeout()), - self.connection - ).await; - if aborted { - crate::WARNING!("\ - WebSocket session aborted by timeout. In Ohkami, \ - WebSocket timeout is set to 3600 seconds (1 hour) \ - by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ - environment variable.\ - "); - } + match self.connection.into_websocket_stream() { + Ok(tcp_stream) => { + crate::DEBUG!("WebSocket session started"); - crate::DEBUG!("WebSocket session finished"); + let aborted = ws.manage_with_timeout( + Duration::from_secs(crate::CONFIG.websocket_timeout()), + tcp_stream + ).await; + if aborted { + crate::WARNING!("[WARNING] \ + WebSocket session aborted by timeout. In Ohkami, \ + WebSocket timeout is set to 3600 seconds (1 hour) \ + by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ + environment variable.\ + "); + } + + crate::DEBUG!("WebSocket session finished"); + } + Err(msg) => { + crate::WARNING!("[WARNING] {}", msg); + } + } } } } } + +// There has to be some cleaner implementation to apply the conditional trait bounds in this... +#[cfg(not(feature="ws"))] +impl Session +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new( + router: Arc, + connection: S, + ip: std::net::IpAddr + ) -> Self { + Self { + router, + connection, + ip + } + } + + pub(crate) async fn manage(mut self) { + #[cold] #[inline(never)] + fn panicking(panic: Box) -> Response { + if let Some(msg) = panic.downcast_ref::() { + crate::WARNING!("[Panicked]: {msg}"); + } else if let Some(msg) = panic.downcast_ref::<&str>() { + crate::WARNING!("[Panicked]: {msg}"); + } else { + crate::WARNING!("[Panicked]"); + } + crate::Response::InternalServerError() + } + + match timeout_in(Duration::from_secs(crate::CONFIG.keepalive_timeout()), async { + let mut req = Request::init(self.ip); + let mut req = unsafe {Pin::new_unchecked(&mut req)}; + loop { + req.clear(); + match req.as_mut().read(&mut self.connection).await { + Ok(Some(())) => { + let close = matches!(req.headers.Connection(), Some("close" | "Close")); + + let res = match catch_unwind(AssertUnwindSafe({ + let req = req.as_mut(); + || self.router.handle(req.get_mut()) + })) { + Ok(future) => future.await, + Err(panic) => panicking(panic), + }; + let upgrade = res.send(&mut self.connection).await; + + if !upgrade.is_none() {break upgrade} + if close {break Upgrade::None} + } + Ok(None) => break Upgrade::None, + Err(res) => {res.send(&mut self.connection).await;}, + } + } + }).await { + None => crate::WARNING!("[WARNING] \ + Session timeouted. In Ohkami, Keep-Alive timeout \ + is set to 42 seconds by default and is configurable \ + by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\ + "), + + Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"), + } + } +} \ No newline at end of file diff --git a/ohkami/src/tls/mod.rs b/ohkami/src/tls/mod.rs new file mode 100644 index 000000000..2afc66de1 --- /dev/null +++ b/ohkami/src/tls/mod.rs @@ -0,0 +1,52 @@ +use tokio::io::{AsyncRead, AsyncWrite}; +pub struct TlsStream(pub tokio_rustls::server::TlsStream); + +#[cfg(feature="ws")] +impl crate::session::WebSocketUpgradeable for TlsStream { + fn into_websocket_stream(self) -> Result { + Err("WebSocket connections are not supported over TLS yet") + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_> + ) -> std::task::Poll> { + match std::pin::Pin::new(&mut self.0).poll_read(cx, buf) { + std::task::Poll::Ready(Err(e)) => { + if e.to_string().contains("close_notify") { + std::task::Poll::Ready(Ok(())) + } else { + std::task::Poll::Ready(Err(e)) + } + }, + other => other, + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8] + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_shutdown(cx) + } +} \ No newline at end of file