From 4bf5c29ec1158e31eb59da60494144173f7a9808 Mon Sep 17 00:00:00 2001 From: kanarus Date: Wed, 19 Mar 2025 14:08:47 +0900 Subject: [PATCH 1/8] enhance(Dir): support conditional fetch with `If-None-Match` `If-Modified-Since` --- ohkami/src/header/etag.rs | 126 +++++++++++++++++++++++++++++++++ ohkami/src/header/mod.rs | 3 + ohkami/src/ohkami/dir.rs | 115 ++++++++++++++++++++++-------- ohkami/src/ohkami/routing.rs | 2 +- ohkami/src/response/headers.rs | 3 +- ohkami_lib/src/time.rs | 90 ++++++++++++++++++++--- 6 files changed, 300 insertions(+), 39 deletions(-) create mode 100644 ohkami/src/header/etag.rs diff --git a/ohkami/src/header/etag.rs b/ohkami/src/header/etag.rs new file mode 100644 index 000000000..7fca0fd74 --- /dev/null +++ b/ohkami/src/header/etag.rs @@ -0,0 +1,126 @@ +use std::borrow::Cow; + +#[derive(Clone)] +pub enum ETag<'header> { + Any, + Strong(Cow<'header, str>), + Weak(Cow<'header, str>), +} + +pub enum ETagError { + InvalidFormat, + InvalidCharactor, +} +impl std::fmt::Debug for ETagError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + ETagError::InvalidFormat => "InvalidFormat", + ETagError::InvalidCharactor => "InvalidCharactor", + }) + } +} +impl std::fmt::Display for ETagError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + ETagError::InvalidFormat => "InvalidFormat(Etag must be * or a strong/weak tag)", + ETagError::InvalidCharactor => "InvalidCharactor(Etag can only contain ASCII characters)", + }) + } +} +impl std::error::Error for ETagError {} + +impl<'header> ETag<'header> { + pub fn serialize(&self) -> Cow<'static, str> { + match self { + ETag::Any => Cow::Borrowed("*"), + ETag::Strong(value) => Cow::Owned(format!("\"{value}\"")), + ETag::Weak(value) => Cow::Owned(format!("W/\"{value}\"")), + } + } + + /// Parse a single ETag. + pub fn parse(mut raw: &'header str) -> Result { + if raw == "*" { + Ok(ETag::Any) + } else { + let is_weak = raw.starts_with("W/"); + if is_weak { + raw = &raw[2..]; + } + + raw = (raw.len() >= 2 && raw.starts_with('"') && raw.ends_with('"')) + .then(|| &raw[1..raw.len() - 1]) + .ok_or(ETagError::InvalidFormat)?; + + let _ = raw.is_ascii() + .then_some(()) + .ok_or(ETagError::InvalidCharactor)?; + + Ok(if is_weak { + ETag::Weak(Cow::Borrowed(raw)) + } else { + ETag::Strong(Cow::Borrowed(raw)) + }) + } + } + + /// Parse comma-separated ETags into an iterator of `Result`. + /// Invalid ETag is returned as `Err`. + pub fn try_iter_from(raw: &'header str) -> impl Iterator> + 'header { + raw.split(", ").map(ETag::parse) + } + + /// Parse comma-separated ETags into an iterator of `ETag`. + /// Invalid ETag is just ignored. + /// + /// ## Example + /// + /// ``` + /// use ohkami::header::ETag; + /// + /// # fn main() { + /// let mut etags = ETag::iter_from( + /// r#""abc123", W/"def456", "ghi789""# + /// ); + /// + /// assert_eq!(etags.next(), Some(ETag::Strong("abc123"))); + /// assert_eq!(etags.next(), Some(ETag::Weak("def456"))); + /// assert_eq!(etags.next(), Some(ETag::Strong("ghi789"))); + /// assert_eq!(etags.next(), None); + /// + /// let mut etags = ETag::iter_from("*"); + /// assert_eq!(etags.next(), Some(ETag::Any)); + /// assert_eq!(etags.next(), None); + /// # } + /// ``` + pub fn iter_from(raw: &'header str) -> impl Iterator + 'header { + raw.split(", ").filter_map(|it| ETag::parse(it).ok()) + } + + pub fn matches(&self, other: &ETag<'_>) -> bool { + match (self, other) { + (ETag::Any, _) | (_, ETag::Any) => true, + | (ETag::Strong(a), ETag::Strong(b)) + | (ETag::Strong(a), ETag::Weak(b)) + | (ETag::Weak(a), ETag::Strong(b)) + | (ETag::Weak(a), ETag::Weak(b)) + => a == b, + } + } + + pub fn into_owned(self) -> ETag<'static> { + match self { + ETag::Any => ETag::Any, + ETag::Strong(cow) => ETag::Strong(Cow::Owned(cow.into_owned())), + ETag::Weak(cow) => ETag::Weak(Cow::Owned(cow.into_owned())), + } + } +} + +impl ETag<'static> { + pub fn new(value: String) -> Result { + value.is_ascii() + .then_some(Self::Strong(value.into())) + .ok_or(ETagError::InvalidCharactor) + } +} diff --git a/ohkami/src/header/mod.rs b/ohkami/src/header/mod.rs index 0d400ea81..cb328280f 100644 --- a/ohkami/src/header/mod.rs +++ b/ohkami/src/header/mod.rs @@ -4,6 +4,9 @@ mod append; pub use append::append; pub(crate) use append::Append; +mod etag; +pub use etag::ETag; + mod setcookie; pub(crate) use setcookie::*; diff --git a/ohkami/src/ohkami/dir.rs b/ohkami/src/ohkami/dir.rs index 34da4344a..9bdfd4b77 100644 --- a/ohkami/src/ohkami/dir.rs +++ b/ohkami/src/ohkami/dir.rs @@ -1,8 +1,9 @@ #![cfg(feature="__rt_native__")] use crate::handler::{Handler, IntoHandler}; -use crate::response::Content; -use std::fs::File; +use crate::header::ETag; +use ohkami_lib::time::ImfFixdate; +use std::{io, fs::File}; use std::path::{PathBuf, Path}; pub struct Dir { @@ -12,23 +13,24 @@ pub struct Dir { /*=== config ===*/ pub(crate) serve_dotfiles: bool, pub(crate) omit_extensions: &'static [&'static str], + pub(crate) etag: Option String>, } impl Dir { - pub(super) fn new(route: &'static str, dir_path: std::path::PathBuf) -> std::io::Result { + pub(super) fn new(route: &'static str, dir_path: PathBuf) -> io::Result { let dir_path = dir_path.canonicalize()?; if !dir_path.is_dir() { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, + return Err(io::Error::new( + io::ErrorKind::InvalidInput, format!("{} is not directory", dir_path.display())) ) } let mut files = Vec::<(PathBuf, File)>::new(); { fn fetch_entries( - dir: std::path::PathBuf - ) -> std::io::Result> { + dir: PathBuf + ) -> io::Result> { dir.read_dir()? .map(|de| de.map(|de| de.path())) .collect() @@ -39,7 +41,7 @@ impl Dir { if entry.is_file() { files.push(( entry.iter().skip(dir_path.iter().count()).collect(), - std::fs::File::open(entry)? + File::open(entry)? )); } else if entry.is_dir() { @@ -56,6 +58,7 @@ impl Dir { files, serve_dotfiles: false, omit_extensions: &[], + etag: None, }) } @@ -75,16 +78,46 @@ impl Dir { self.omit_extensions = extensions_to_omit; self } + + /// Set a function to generate ETag for each file. + pub fn etag(mut self, etag: impl Into String>>) -> Self { + self.etag = etag.into(); + self + } } #[derive(Clone)] pub(super) struct StaticFileHandler { - mime: &'static str, + last_modified: ImfFixdate, + last_modified_str: String, + etag: Option>, + mime: &'static str, content: std::sync::Arc>, } impl StaticFileHandler { - pub(super) fn new(path: &Path, file: std::fs::File) -> std::io::Result { + pub(super) fn new( + path: &Path, + file: File, + get_etag: Option String>, + ) -> io::Result { + let last_modified_str = ohkami_lib::time::UTCDateTime::from_unix_timestamp( + file + .metadata()? + .modified()? + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + ).into_imf_fixdate(); + + let last_modified = ImfFixdate::parse(&last_modified_str) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let etag = get_etag + .map(|f| ETag::new(f(&file))) + .transpose() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mime = ::mime_guess::from_path(path) .first_raw() .unwrap_or("application/octet-stream"); @@ -92,12 +125,18 @@ impl StaticFileHandler { let mut content = vec![ u8::default(); file.metadata().unwrap().len() as usize - ]; {use std::io::Read; + ]; {use io::Read; let mut file = file; file.read_exact(&mut content)?; } - Ok(Self { mime, content:std::sync::Arc::new(content) }) + Ok(Self { + last_modified, + last_modified_str, + etag, + mime, + content: std::sync::Arc::new(content) + }) } } @@ -105,25 +144,43 @@ impl IntoHandler for StaticFileHandler { fn n_params(&self) -> usize {0} fn into_handler(self) -> Handler { - let this: &'static StaticFileHandler - = Box::leak(Box::new(self)); - - Handler::new(|_| Box::pin(async { - let mut res = crate::Response::OK(); - { - let content: &'static [u8] = &this.content; - res.headers.set() - .ContentType(this.mime) - .ContentLength(ohkami_lib::num::itoa(content.len())); - res.content = Content::Payload(content.into()); + let this: &'static StaticFileHandler = Box::leak(Box::new(self)); + + Handler::new(|req| Box::pin(async { + use crate::{Response, header::ETag}; + + if let (Some(if_none_match), Some(etag)) = (req.headers.IfNoneMatch(), &this.etag) { + if ETag::iter_from(if_none_match).any(|it| it.matches(etag)) { + return Response::NotModified(); + } } - res + if let Some(if_modified_since) = req.headers.IfModifiedSince() { + let Ok(if_modified_since) = ImfFixdate::parse(if_modified_since) else { + return Response::BadRequest(); + }; + if if_modified_since >= this.last_modified { + return Response::NotModified(); + } + } + + Response::OK() + .with_payload(this.mime, &*this.content) + .with_headers(|h| h + .LastModified(&*this.last_modified_str) + .ETag(this.etag.as_ref().map(|etag| etag.serialize())) + ) }), #[cfg(feature="openapi")] {use crate::openapi; - openapi::Operation::with(openapi::Responses::new([( - 200, - openapi::Response::when("OK") - .content(this.mime, openapi::string().format("binary")) - )])) + openapi::Operation::with(openapi::Responses::new([ + ( + 200, + openapi::Response::when("OK") + .content(this.mime, openapi::string().format("binary")) + ), + ( + 304, + openapi::Response::when("Not Modified") + ) + ])) }) } } diff --git a/ohkami/src/ohkami/routing.rs b/ohkami/src/ohkami/routing.rs index a25f11df8..51b2cd9ed 100644 --- a/ohkami/src/ohkami/routing.rs +++ b/ohkami/src/ohkami/routing.rs @@ -195,7 +195,7 @@ const _: () = { }; for (mut path, file) in self.files { - let handler = StaticFileHandler::new(&path, file) + let handler = StaticFileHandler::new(&path, file, self.etag) .expect(&format!("can't serve file: `{}`", path.display())); let file_name = path.file_name().unwrap().to_str() diff --git a/ohkami/src/response/headers.rs b/ohkami/src/response/headers.rs index 9d56a1600..67f5dd2f3 100644 --- a/ohkami/src/response/headers.rs +++ b/ohkami/src/response/headers.rs @@ -188,7 +188,7 @@ macro_rules! Header { } } }; -} Header! {47; +} Header! {48; AcceptRanges: b"Accept-Ranges", AccessControlAllowCredentials: b"Access-Control-Allow-Credentials", AccessControlAllowHeaders: b"Access-Control-Allow-Headers", @@ -219,6 +219,7 @@ macro_rules! Header { Expires: b"Expires", Link: b"Link", Location: b"Location", + LastModified: b"Last-Modified", ProxyAuthenticate: b"Proxy-Authenticate", ReferrerPolicy: b"Referrer-Policy", Refresh: b"Refresh", diff --git a/ohkami_lib/src/time.rs b/ohkami_lib/src/time.rs index d7726757f..016e40abc 100644 --- a/ohkami_lib/src/time.rs +++ b/ohkami_lib/src/time.rs @@ -1,6 +1,5 @@ //! Most parts are based on [chrono](https://github.com/chronotope/chrono); MIT. - /// Current datetime by **IMF-fixdate** format like `Sun, 06 Nov 1994 08:49:37 GMT`, used in `Date` header. /// /// (reference:[https://datatracker.ietf.org/doc/html/rfc9110#name-date-time-formats](https://datatracker.ietf.org/doc/html/rfc9110#name-date-time-formats)) @@ -8,7 +7,87 @@ UTCDateTime::from_unix_timestamp(unix_timestamp).into_imf_fixdate() } +const SHORT_WEEKDAYS: [&[u8; 3]; 7 ] = [b"Sun", b"Mon", b"Tue", b"Wed", b"Thu", b"Fri", b"Sat"]; +const SHORT_MONTHS: [&[u8; 3]; 12] = [b"Jan", b"Feb", b"Mar", b"Apr", b"May", b"Jun", b"Jul", b"Aug", b"Sep", b"Oct", b"Nov", b"Dec"]; + +const IMF_FIXDATE_LEN: usize = str::len("Sun, 06 Nov 1994 08:49:37 GMT"); + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub struct ImfFixdate { + year: u16, + month: u8, + day: u8, + hour: u8, + min: u8, + sec: u8, +} +impl ImfFixdate { + pub fn parse(s: &str) -> Result { + let mut r = ::byte_reader::Reader::new(s.as_bytes()); + + let _ = r.consume_oneof(SHORT_WEEKDAYS) + .ok_or_else(|| format!("invalid weekday: `{s}`"))?; + + r.consume(b", ").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let day = r.read_uint() + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| format!("invalid day: `{s}`"))?; + + r.consume(b" ").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let month = r.consume_oneof(SHORT_MONTHS) + .and_then(|index| u8::try_from(index).ok()) + .ok_or_else(|| format!("invalid month: `{s}`"))?; + + r.consume(b" ").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let year = r.read_uint() + .and_then(|n| u16::try_from(n).ok()) + .ok_or_else(|| format!("invalid year: `{s}`"))?; + + r.consume(b" ").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let hour = r.read_uint() + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| format!("invalid hour: `{s}`"))?; + + r.consume(b":").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let min = r.read_uint() + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| format!("invalid minute: `{s}`"))?; + + r.consume(b":").ok_or_else(|| format!("invalid separator: `{s}`"))?; + + let sec = r.read_uint() + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| format!("invalid second: `{s}`"))?; + + r.consume(b" GMT").ok_or_else(|| format!("invalid timezone: `{s}`"))?; + + Ok(ImfFixdate { year, month, day, hour, min, sec }) + } +} +impl PartialOrd for ImfFixdate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for ImfFixdate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + std::cmp::Ordering::Equal + .then(self.year.cmp(&other.year)) + .then(self.month.cmp(&other.month)) + .then(self.day.cmp(&other.day)) + .then(self.hour.cmp(&other.hour)) + .then(self.min.cmp(&other.min)) + .then(self.sec.cmp(&other.sec)) + } +} + /// date time on UTC *to the second* +#[derive(Clone, Copy)] pub struct UTCDateTime { date: Date, time: Time, @@ -28,11 +107,6 @@ impl UTCDateTime { } pub fn into_imf_fixdate(self) -> String { - const SHORT_WEEKDAYS: [&[u8; 3]; 7 ] = [b"Sun", b"Mon", b"Tue", b"Wed", b"Thu", b"Fri", b"Sat"]; - const SHORT_MONTHS: [&[u8; 3]; 12] = [b"Jan", b"Feb", b"Mar", b"Apr", b"May", b"Jun", b"Jul", b"Aug", b"Sep", b"Oct", b"Nov", b"Dec"]; - - const IMF_FIXDATE_LEN: usize = str::len("Sun, 06 Nov 1994 08:49:37 GMT"); - let mut buf = [std::mem::MaybeUninit::::uninit(); IMF_FIXDATE_LEN]; let mut i = 0; @@ -102,7 +176,7 @@ impl UTCDateTime { } /// (year << 13) | of -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] struct Date(i32); impl Date { fn from_days(days: i32) -> Self { @@ -180,7 +254,7 @@ impl Date { } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq)] struct Time { secs: u32, } From ef6effa22f98cd83e92d349ded0f4489110208ec Mon Sep 17 00:00:00 2001 From: kanarus Date: Wed, 19 Mar 2025 14:14:27 +0900 Subject: [PATCH 2/8] next: test conditional fetch impl --- examples/static_files/src/main.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/static_files/src/main.rs b/examples/static_files/src/main.rs index 17fe57f92..9b72ed713 100644 --- a/examples/static_files/src/main.rs +++ b/examples/static_files/src/main.rs @@ -3,21 +3,24 @@ use ohkami::prelude::*; struct Options { omit_dot_html: bool, serve_dotfiles: bool, + etag: Option String>, } impl Default for Options { fn default() -> Self { Self { omit_dot_html: false, serve_dotfiles: false, + etag: None, } } } -fn ohkami(Options { omit_dot_html, serve_dotfiles }: Options) -> Ohkami { +fn ohkami(Options { omit_dot_html, serve_dotfiles, etag }: Options) -> Ohkami { Ohkami::new(( "/".Dir("./public") .omit_extensions(if omit_dot_html {&["html"]} else {&[]}) - .serve_dotfiles(serve_dotfiles), + .serve_dotfiles(serve_dotfiles) + .etag(etag), )) } From 3674200df32310f6fc00497f0481173e982fdf8d Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 00:40:58 +0700 Subject: [PATCH 3/8] tls: implement basic tls implementation with tokio-rustls (tokio-only!) --- ohkami/Cargo.toml | 3 ++ ohkami/src/lib.rs | 3 ++ ohkami/src/ohkami/mod.rs | 100 ++++++++++++++++++++++++++++++++++++++ ohkami/src/session/mod.rs | 40 ++++++++------- ohkami/src/tls/mod.rs | 47 ++++++++++++++++++ 5 files changed, 175 insertions(+), 18 deletions(-) create mode 100644 ohkami/src/tls/mod.rs diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 641c9718f..900cbce23 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -47,6 +47,8 @@ ctrlc = { version = "3.4", optional = true } num_cpus = { version = "1.16", optional = true } futures-util = { version = "0.3", optional = true, default-features = false } mews = { version = "0.2", optional = true } +rustls = { version = "0.23.23", optional = true } +tokio-rustls = { version = "0.26.2", optional = true } [features] @@ -89,6 +91,7 @@ nightly = [] openapi = ["dep:ohkami_openapi", "ohkami_macros/openapi"] sse = ["ohkami_lib/stream"] ws = ["ohkami_lib/stream", "dep:mews"] +tls = ["__rt_native__", "rt_tokio", "dep:rustls", "dep:tokio-rustls"] ##### internal ##### __rt__ = [] diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 56489c804..5debb0347 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -218,6 +218,9 @@ pub use ohkami::{Ohkami, Route}; pub mod fang; pub use fang::{handler, Fang, FangProc}; +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +mod tls; + pub mod format; pub mod header; diff --git a/ohkami/src/ohkami/mod.rs b/ohkami/src/ohkami/mod.rs index 3296fcc86..e0f821439 100644 --- a/ohkami/src/ohkami/mod.rs +++ b/ohkami/src/ohkami/mod.rs @@ -14,6 +14,12 @@ use std::sync::Arc; #[cfg(feature="__rt_native__")] use crate::{__rt__, Session}; +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +use tokio_rustls::TlsAcceptor; + +#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +use crate::tls::TlsStream; + /// # Ohkami - a smart wolf who serves your web app /// /// ## Definition @@ -589,6 +595,100 @@ impl Ohkami { wg.await; } + #[cfg(all(feature="__rt_native__", feature="rt_tokio"))] + /// Bind this `Ohkami` to an address and start serving with TLS/HTTPS support! + /// + /// This method works like `howl` but upgrades connections to HTTPS using the provided + /// rustls configuration. This functionality is only available with the `rt_tokio` feature. + /// + /// ### Parameters + /// + /// - `bind`: Same as `howl`, can be a socket address or TcpListener + /// - `tls_config`: A rustls server configuration containing your certificates and keys + /// + /// ### Example + /// + /// ```no_run + /// use ohkami::prelude::*; + /// use rustls::{ServerConfig, Certificate, PrivateKey}; + /// use std::fs::File; + /// use std::io::BufReader; + /// + /// async fn hello() -> &'static str { + /// "Hello, secure ohkami!" + /// } + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// // Initialize rustls crypto provider + /// match rustls::crypto::ring::default_provider().install_default() { + // Ok(_) => println!("Successfully installed rustls crypto provider"), + // Err(e) => { + // eprintln!("Failed to install rustls crypto provider: {:?}", e); + // std::process::exit(1); + // } + // } + /// // Load certificates and private key + /// let cert_file = File::open("path/to/cert.pem")?; + /// let key_file = File::open("path/to/key.pem")?; + /// + /// let cert_chain = rustls_pemfile::certs(&mut BufReader::new(cert_file)) + /// .map(|certs| certs.into_iter().map(Certificate).collect()) + /// .unwrap_or_default(); + /// + /// let key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key_file)) + /// .next() + /// .map(|key| PrivateKey(key)) + /// .expect("Failed to load private key"); + /// + /// // Build TLS configuration + /// let tls_config = ServerConfig::builder() + /// .with_safe_defaults() + /// .with_no_client_auth() + /// .with_single_cert(cert_chain, key) + /// .expect("Failed to build TLS configuration"); + /// + /// // Create and run Ohkami with HTTPS + /// Ohkami::new(( + /// "/".GET(hello), + /// )).howl_tls("0.0.0.0:8443", tls_config).await; + /// + /// Ok(()) + /// } + /// ``` + pub async fn howl_tls(self, bind: impl __rt__::IntoTcpListener, tls_config: rustls::ServerConfig) { + let (router, _) = self.into_router().finalize(); + let router = Arc::new(router); + let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config)); + + let listener = bind.ino_tcp_listener().await; + let (wg, ctrl_c) = (sync::WaitGroup::new(), sync::CtrlC::new()); + + while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await { + let Ok((tcp_stream, addr)) = accept else { continue }; + + let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else { continue }; + + let session = Session::new( + router.clone(), + TlsStream(tls_stream), + addr.ip() + ); + + let wg = wg.add(); + tokio::spawn(async move { + session.manage().await; + wg.done(); + }); + } + + crate::DEBUG!("interrupted, trying graceful shutdown..."); + drop(listener); + + crate::DEBUG!("waiting {} session(s) to finish...", wg.count()); + wg.await; + } + #[cfg(feature="rt_worker")] #[doc(hidden)] pub async fn __worker__(self, diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 4de4d3d73..0420694de 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -1,29 +1,33 @@ -#![cfg(feature="__rt_native__")] - -use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; -use std::panic::{AssertUnwindSafe, catch_unwind}; -use crate::__rt__::TcpStream; +#[cfg(feature = "__rt_native__")] +use crate::router::r#final::Router; use crate::response::Upgrade; use crate::util::timeout_in; -use crate::router::r#final::Router; use crate::{Request, Response}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; -pub(crate) struct Session { - router: Arc, - connection: TcpStream, - ip: std::net::IpAddr, +#[cfg(feature = "rt_tokio")] +use tokio::io::{AsyncRead, AsyncWrite}; +#[cfg(any(feature = "rt_async-std", feature = "rt_smol"))] +use futures_util::io::{AsyncRead, AsyncWrite}; +#[cfg(feature = "rt_glommio")] +use glommio::io::{AsyncRead, AsyncWrite}; + +pub(crate) struct Session { + router: Arc, + connection: S, // Changed connection to generic type for TcpStream and TlsStream + ip: std::net::IpAddr, } -impl Session { - pub(crate) fn new( - router: Arc, - connection: TcpStream, - ip: std::net::IpAddr - ) -> Self { +impl Session +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new(router: Arc, connection: S, ip: std::net::IpAddr) -> Self { Self { router, connection, - ip + ip, } } @@ -95,4 +99,4 @@ impl Session { } } } -} +} \ No newline at end of file diff --git a/ohkami/src/tls/mod.rs b/ohkami/src/tls/mod.rs new file mode 100644 index 000000000..a9e9f384f --- /dev/null +++ b/ohkami/src/tls/mod.rs @@ -0,0 +1,47 @@ +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +pub struct TlsStream(pub tokio_rustls::server::TlsStream); + +impl AsyncRead for TlsStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_> + ) -> std::task::Poll> { + match std::pin::Pin::new(&mut self.0).poll_read(cx, buf) { + std::task::Poll::Ready(Err(e)) => { + if e.to_string().contains("close_notify") { + std::task::Poll::Ready(Ok(())) + //Re-impl TlsStream's AsyncRead & AsyncWrite just for this, to prevent panic on abrupt client TLS connection close. Probably not a great idea, but it works I guess... + } else { + std::task::Poll::Ready(Err(e)) + } + }, + other => other, + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8] + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_shutdown(cx) + } +} \ No newline at end of file From b45481419102d53b9dfa33cef773642fa468a9a0 Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 19:21:10 +0700 Subject: [PATCH 4/8] fix(session): fix AsyncRead & AsyncWrite imports for session --- ohkami/src/ohkami/mod.rs | 2 +- ohkami/src/session/mod.rs | 10 ++-------- ohkami/src/tls/mod.rs | 4 +--- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/ohkami/src/ohkami/mod.rs b/ohkami/src/ohkami/mod.rs index e0f821439..3eab6a751 100644 --- a/ohkami/src/ohkami/mod.rs +++ b/ohkami/src/ohkami/mod.rs @@ -595,7 +595,7 @@ impl Ohkami { wg.await; } - #[cfg(all(feature="__rt_native__", feature="rt_tokio"))] + #[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] /// Bind this `Ohkami` to an address and start serving with TLS/HTTPS support! /// /// This method works like `howl` but upgrades connections to HTTPS using the provided diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 0420694de..1546a4f97 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -5,17 +5,11 @@ use crate::util::timeout_in; use crate::{Request, Response}; use std::panic::{AssertUnwindSafe, catch_unwind}; use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; - -#[cfg(feature = "rt_tokio")] -use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(any(feature = "rt_async-std", feature = "rt_smol"))] -use futures_util::io::{AsyncRead, AsyncWrite}; -#[cfg(feature = "rt_glommio")] -use glommio::io::{AsyncRead, AsyncWrite}; +use crate::__rt__::{AsyncRead, AsyncWrite}; pub(crate) struct Session { router: Arc, - connection: S, // Changed connection to generic type for TcpStream and TlsStream + connection: S, ip: std::net::IpAddr, } diff --git a/ohkami/src/tls/mod.rs b/ohkami/src/tls/mod.rs index a9e9f384f..f7ee21bdd 100644 --- a/ohkami/src/tls/mod.rs +++ b/ohkami/src/tls/mod.rs @@ -1,5 +1,4 @@ -use tokio::io::AsyncRead; -use tokio::io::AsyncWrite; +use tokio::io::{AsyncRead, AsyncWrite}; pub struct TlsStream(pub tokio_rustls::server::TlsStream); impl AsyncRead for TlsStream { @@ -12,7 +11,6 @@ impl AsyncRead for TlsStream { std::task::Poll::Ready(Err(e)) => { if e.to_string().contains("close_notify") { std::task::Poll::Ready(Ok(())) - //Re-impl TlsStream's AsyncRead & AsyncWrite just for this, to prevent panic on abrupt client TLS connection close. Probably not a great idea, but it works I guess... } else { std::task::Poll::Ready(Err(e)) } From 3e9f84908111583de593468df9ab1f2b570f3a50 Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 21:25:13 +0700 Subject: [PATCH 5/8] fix(session): fix unsupported WS over TLS & cfg handling --- ohkami/src/session/mod.rs | 81 ++++++++++++++++++++++++++------------- ohkami/src/tls/mod.rs | 7 ++++ 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 1546a4f97..4b967f9bf 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -1,11 +1,15 @@ -#[cfg(feature = "__rt_native__")] -use crate::router::r#final::Router; +#![cfg(feature="__rt_native__")] + +use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use crate::__rt__::{AsyncRead, AsyncWrite}; use crate::response::Upgrade; use crate::util::timeout_in; +use crate::router::r#final::Router; use crate::{Request, Response}; -use std::panic::{AssertUnwindSafe, catch_unwind}; -use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; -use crate::__rt__::{AsyncRead, AsyncWrite}; + +#[cfg(feature="ws")] +use crate::__rt__::TcpStream; pub(crate) struct Session { router: Arc, @@ -13,15 +17,31 @@ pub(crate) struct Session { ip: std::net::IpAddr, } -impl Session +#[cfg(feature="ws")] +pub(crate) trait WebSocketUpgradeable { + fn into_websocket_stream(self) -> Result; +} + +#[cfg(feature="ws")] +impl WebSocketUpgradeable for TcpStream { + fn into_websocket_stream(self) -> Result { + Ok(self) + } +} + +impl Session where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static + WebSocketUpgradeable, { - pub(crate) fn new(router: Arc, connection: S, ip: std::net::IpAddr) -> Self { + pub(crate) fn new( + router: Arc, + connection: S, + ip: std::net::IpAddr + ) -> Self { Self { router, connection, - ip, + ip } } @@ -29,11 +49,11 @@ where #[cold] #[inline(never)] fn panicking(panic: Box) -> Response { if let Some(msg) = panic.downcast_ref::() { - crate::WARNING!("panic: {msg}"); + crate::WARNING!("[Panicked]: {msg}"); } else if let Some(msg) = panic.downcast_ref::<&str>() { - crate::WARNING!("panic: {msg}"); + crate::WARNING!("[Panicked]: {msg}"); } else { - crate::WARNING!("panic"); + crate::WARNING!("[Panicked]"); } crate::Response::InternalServerError() } @@ -64,7 +84,7 @@ where } } }).await { - None => crate::WARNING!("\ + None => crate::WARNING!("[WARNING] \ Session timeouted. In Ohkami, Keep-Alive timeout \ is set to 42 seconds by default and is configurable \ by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\ @@ -74,22 +94,29 @@ where #[cfg(feature="ws")] Some(Upgrade::WebSocket(ws)) => { - crate::DEBUG!("WebSocket session started"); + match self.connection.into_websocket_stream() { + Ok(tcp_stream) => { + crate::DEBUG!("WebSocket session started"); - let aborted = ws.manage_with_timeout( - Duration::from_secs(crate::CONFIG.websocket_timeout()), - self.connection - ).await; - if aborted { - crate::WARNING!("\ - WebSocket session aborted by timeout. In Ohkami, \ - WebSocket timeout is set to 3600 seconds (1 hour) \ - by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ - environment variable.\ - "); - } + let aborted = ws.manage_with_timeout( + Duration::from_secs(crate::CONFIG.websocket_timeout()), + tcp_stream + ).await; + if aborted { + crate::WARNING!("[WARNING] \ + WebSocket session aborted by timeout. In Ohkami, \ + WebSocket timeout is set to 3600 seconds (1 hour) \ + by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ + environment variable.\ + "); + } - crate::DEBUG!("WebSocket session finished"); + crate::DEBUG!("WebSocket session finished"); + } + Err(msg) => { + crate::WARNING!("[WARNING] {}", msg); + } + } } } } diff --git a/ohkami/src/tls/mod.rs b/ohkami/src/tls/mod.rs index f7ee21bdd..2afc66de1 100644 --- a/ohkami/src/tls/mod.rs +++ b/ohkami/src/tls/mod.rs @@ -1,6 +1,13 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub struct TlsStream(pub tokio_rustls::server::TlsStream); +#[cfg(feature="ws")] +impl crate::session::WebSocketUpgradeable for TlsStream { + fn into_websocket_stream(self) -> Result { + Err("WebSocket connections are not supported over TLS yet") + } +} + impl AsyncRead for TlsStream { fn poll_read( mut self: std::pin::Pin<&mut Self>, From 727c7c2d95a91fece6389d68a6c94dd623ac8e90 Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 21:45:06 +0700 Subject: [PATCH 6/8] fix(session): fix trait bounds when ws feature is disabled --- ohkami/src/session/mod.rs | 70 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 4b967f9bf..5e09b0bd3 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -29,6 +29,7 @@ impl WebSocketUpgradeable for TcpStream { } } +#[cfg(feature="ws")] impl Session where S: AsyncRead + AsyncWrite + Unpin + Send + 'static + WebSocketUpgradeable, @@ -92,7 +93,6 @@ where Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"), - #[cfg(feature="ws")] Some(Upgrade::WebSocket(ws)) => { match self.connection.into_websocket_stream() { Ok(tcp_stream) => { @@ -120,4 +120,72 @@ where } } } +} + +// There has to be some cleaner implementation to apply the conditional trait bounds in this... +#[cfg(not(feature="ws"))] +impl Session +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub(crate) fn new( + router: Arc, + connection: S, + ip: std::net::IpAddr + ) -> Self { + Self { + router, + connection, + ip + } + } + + pub(crate) async fn manage(mut self) { + #[cold] #[inline(never)] + fn panicking(panic: Box) -> Response { + if let Some(msg) = panic.downcast_ref::() { + crate::WARNING!("[Panicked]: {msg}"); + } else if let Some(msg) = panic.downcast_ref::<&str>() { + crate::WARNING!("[Panicked]: {msg}"); + } else { + crate::WARNING!("[Panicked]"); + } + crate::Response::InternalServerError() + } + + match timeout_in(Duration::from_secs(crate::CONFIG.keepalive_timeout()), async { + let mut req = Request::init(self.ip); + let mut req = unsafe {Pin::new_unchecked(&mut req)}; + loop { + req.clear(); + match req.as_mut().read(&mut self.connection).await { + Ok(Some(())) => { + let close = matches!(req.headers.Connection(), Some("close" | "Close")); + + let res = match catch_unwind(AssertUnwindSafe({ + let req = req.as_mut(); + || self.router.handle(req.get_mut()) + })) { + Ok(future) => future.await, + Err(panic) => panicking(panic), + }; + let upgrade = res.send(&mut self.connection).await; + + if !upgrade.is_none() {break upgrade} + if close {break Upgrade::None} + } + Ok(None) => break Upgrade::None, + Err(res) => {res.send(&mut self.connection).await;}, + } + } + }).await { + None => crate::WARNING!("[WARNING] \ + Session timeouted. In Ohkami, Keep-Alive timeout \ + is set to 42 seconds by default and is configurable \ + by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\ + "), + + Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"), + } + } } \ No newline at end of file From 5a325cbfbe3fd45c144c5074bface5cc446afa1c Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 22:11:31 +0700 Subject: [PATCH 7/8] fix(session): omit Send trait to unbreak builds on other runtimes (hopefully) --- ohkami/src/session/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 5e09b0bd3..474fb2677 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -32,7 +32,7 @@ impl WebSocketUpgradeable for TcpStream { #[cfg(feature="ws")] impl Session where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static + WebSocketUpgradeable, + S: AsyncRead + AsyncWrite + Unpin + WebSocketUpgradeable, { pub(crate) fn new( router: Arc, @@ -126,7 +126,7 @@ where #[cfg(not(feature="ws"))] impl Session where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S: AsyncRead + AsyncWrite + Unpin, { pub(crate) fn new( router: Arc, From 739ca2315c3dcc6f8da730b33baa917b3776f683 Mon Sep 17 00:00:00 2001 From: voldtman Date: Thu, 20 Mar 2025 22:36:51 +0700 Subject: [PATCH 8/8] fix(etag): merge issues fix --- ohkami/src/header/etag.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ohkami/src/header/etag.rs b/ohkami/src/header/etag.rs index c7348568a..d0056efbe 100644 --- a/ohkami/src/header/etag.rs +++ b/ohkami/src/header/etag.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; #[derive(Clone, Debug, PartialEq)] - pub enum ETag<'header> { Any, Strong(Cow<'header, str>), @@ -84,6 +83,9 @@ impl<'header> ETag<'header> { /// r#""abc123", W/"def456", "ghi789""# /// ); /// + /// assert_eq!(etags.next(), Some(ETag::Strong("abc123".into()))); + /// assert_eq!(etags.next(), Some(ETag::Weak("def456".into()))); + /// assert_eq!(etags.next(), Some(ETag::Strong("ghi789".into()))); /// assert_eq!(etags.next(), None); /// /// let mut etags = ETag::iter_from("*"); @@ -121,4 +123,4 @@ impl ETag<'static> { .then_some(Self::Strong(value.into())) .ok_or(ETagError::InvalidCharactor) } -} +} \ No newline at end of file