Skip to content
Merged
3 changes: 3 additions & 0 deletions ohkami/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ ctrlc = { version = "3.4", optional = true }
num_cpus = { version = "1.16", optional = true }
futures-util = { version = "0.3", optional = true, default-features = false }
mews = { version = "0.2", optional = true }
rustls = { version = "0.23.23", optional = true }
tokio-rustls = { version = "0.26.2", optional = true }


[features]
Expand Down Expand Up @@ -89,6 +91,7 @@ nightly = []
openapi = ["dep:ohkami_openapi", "ohkami_macros/openapi"]
sse = ["ohkami_lib/stream"]
ws = ["ohkami_lib/stream", "dep:mews"]
tls = ["__rt_native__", "rt_tokio", "dep:rustls", "dep:tokio-rustls"]

##### internal #####
__rt__ = []
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/header/etag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ impl ETag<'static> {
.then_some(Self::Strong(value.into()))
.ok_or(ETagError::InvalidCharactor)
}
}
}
3 changes: 3 additions & 0 deletions ohkami/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ pub use ohkami::{Ohkami, Route};
pub mod fang;
pub use fang::{handler, Fang, FangProc};

#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
mod tls;

pub mod format;

pub mod header;
Expand Down
100 changes: 100 additions & 0 deletions ohkami/src/ohkami/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ use std::sync::Arc;
#[cfg(feature="__rt_native__")]
use crate::{__rt__, Session};

#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
use tokio_rustls::TlsAcceptor;

#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
use crate::tls::TlsStream;

/// # Ohkami - a smart wolf who serves your web app
///
/// ## Definition
Expand Down Expand Up @@ -589,6 +595,100 @@ impl Ohkami {
wg.await;
}

#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))]
/// Bind this `Ohkami` to an address and start serving with TLS/HTTPS support!
///
/// This method works like `howl` but upgrades connections to HTTPS using the provided
/// rustls configuration. This functionality is only available with the `rt_tokio` feature.
///
/// ### Parameters
///
/// - `bind`: Same as `howl`, can be a socket address or TcpListener
/// - `tls_config`: A rustls server configuration containing your certificates and keys
///
/// ### Example
///
/// ```no_run
/// use ohkami::prelude::*;
/// use rustls::{ServerConfig, Certificate, PrivateKey};
/// use std::fs::File;
/// use std::io::BufReader;
///
/// async fn hello() -> &'static str {
/// "Hello, secure ohkami!"
/// }
///
/// #[tokio::main]
/// async fn main() -> std::io::Result<()> {
/// // Initialize rustls crypto provider
/// match rustls::crypto::ring::default_provider().install_default() {
// Ok(_) => println!("Successfully installed rustls crypto provider"),
// Err(e) => {
// eprintln!("Failed to install rustls crypto provider: {:?}", e);
// std::process::exit(1);
// }
// }
/// // Load certificates and private key
/// let cert_file = File::open("path/to/cert.pem")?;
/// let key_file = File::open("path/to/key.pem")?;
///
/// let cert_chain = rustls_pemfile::certs(&mut BufReader::new(cert_file))
/// .map(|certs| certs.into_iter().map(Certificate).collect())
/// .unwrap_or_default();
///
/// let key = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(key_file))
/// .next()
/// .map(|key| PrivateKey(key))
/// .expect("Failed to load private key");
///
/// // Build TLS configuration
/// let tls_config = ServerConfig::builder()
/// .with_safe_defaults()
/// .with_no_client_auth()
/// .with_single_cert(cert_chain, key)
/// .expect("Failed to build TLS configuration");
///
/// // Create and run Ohkami with HTTPS
/// Ohkami::new((
/// "/".GET(hello),
/// )).howl_tls("0.0.0.0:8443", tls_config).await;
///
/// Ok(())
/// }
/// ```
pub async fn howl_tls<T>(self, bind: impl __rt__::IntoTcpListener<T>, tls_config: rustls::ServerConfig) {
let (router, _) = self.into_router().finalize();
let router = Arc::new(router);
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));

let listener = bind.ino_tcp_listener().await;
let (wg, ctrl_c) = (sync::WaitGroup::new(), sync::CtrlC::new());

while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await {
let Ok((tcp_stream, addr)) = accept else { continue };

let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else { continue };

let session = Session::new(
router.clone(),
TlsStream(tls_stream),
addr.ip()
);

let wg = wg.add();
tokio::spawn(async move {
session.manage().await;
wg.done();
});
}

crate::DEBUG!("interrupted, trying graceful shutdown...");
drop(listener);

crate::DEBUG!("waiting {} session(s) to finish...", wg.count());
wg.await;
}

#[cfg(feature="rt_worker")]
#[doc(hidden)]
pub async fn __worker__(self,
Expand Down
151 changes: 122 additions & 29 deletions ohkami/src/session/mod.rs
Copy link
Copy Markdown
Member

@kanarus kanarus Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently Upgrade::WebSocket holds WebSocket<TcpStream> ( TcpStream is the default type param of mews::WebSocket ), so ws.manage_with_timeout() can't accept generic S.

Could you work around this to upgrade only when Session<TcpStream> for now ?

Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,42 @@

use std::{any::Any, pin::Pin, sync::Arc, time::Duration};
use std::panic::{AssertUnwindSafe, catch_unwind};
use crate::__rt__::TcpStream;
use crate::__rt__::{AsyncRead, AsyncWrite};
use crate::response::Upgrade;
use crate::util::timeout_in;
use crate::router::r#final::Router;
use crate::{Request, Response};

pub(crate) struct Session {
router: Arc<Router>,
connection: TcpStream,
ip: std::net::IpAddr,
#[cfg(feature="ws")]
use crate::__rt__::TcpStream;

pub(crate) struct Session<S> {
router: Arc<Router>,
connection: S,
ip: std::net::IpAddr,
}

#[cfg(feature="ws")]
pub(crate) trait WebSocketUpgradeable {
fn into_websocket_stream(self) -> Result<TcpStream, &'static str>;
}

impl Session {
#[cfg(feature="ws")]
impl WebSocketUpgradeable for TcpStream {
fn into_websocket_stream(self) -> Result<TcpStream, &'static str> {
Ok(self)
}
}

#[cfg(feature="ws")]
impl<S> Session<S>
where
S: AsyncRead + AsyncWrite + Unpin + WebSocketUpgradeable,
{
pub(crate) fn new(
router: Arc<Router>,
connection: TcpStream,
ip: std::net::IpAddr
router: Arc<Router>,
connection: S,
ip: std::net::IpAddr
) -> Self {
Self {
router,
Expand All @@ -31,11 +50,11 @@ impl Session {
#[cold] #[inline(never)]
fn panicking(panic: Box<dyn Any + Send>) -> Response {
if let Some(msg) = panic.downcast_ref::<String>() {
crate::WARNING!("panic: {msg}");
crate::WARNING!("[Panicked]: {msg}");
} else if let Some(msg) = panic.downcast_ref::<&str>() {
crate::WARNING!("panic: {msg}");
crate::WARNING!("[Panicked]: {msg}");
} else {
crate::WARNING!("panic");
crate::WARNING!("[Panicked]");
}
crate::Response::InternalServerError()
}
Expand Down Expand Up @@ -66,33 +85,107 @@ impl Session {
}
}
}).await {
None => crate::WARNING!("\
None => crate::WARNING!("[WARNING] \
Session timeouted. In Ohkami, Keep-Alive timeout \
is set to 42 seconds by default and is configurable \
by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\
"),

Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"),

#[cfg(feature="ws")]
Some(Upgrade::WebSocket(ws)) => {
crate::DEBUG!("WebSocket session started");

let aborted = ws.manage_with_timeout(
Duration::from_secs(crate::CONFIG.websocket_timeout()),
self.connection
).await;
if aborted {
crate::WARNING!("\
WebSocket session aborted by timeout. In Ohkami, \
WebSocket timeout is set to 3600 seconds (1 hour) \
by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \
environment variable.\
");
}
match self.connection.into_websocket_stream() {
Ok(tcp_stream) => {
crate::DEBUG!("WebSocket session started");

crate::DEBUG!("WebSocket session finished");
let aborted = ws.manage_with_timeout(
Duration::from_secs(crate::CONFIG.websocket_timeout()),
tcp_stream
).await;
if aborted {
crate::WARNING!("[WARNING] \
WebSocket session aborted by timeout. In Ohkami, \
WebSocket timeout is set to 3600 seconds (1 hour) \
by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \
environment variable.\
");
}

crate::DEBUG!("WebSocket session finished");
}
Err(msg) => {
crate::WARNING!("[WARNING] {}", msg);
}
}
}
}
}
}

// There has to be some cleaner implementation to apply the conditional trait bounds in this...
#[cfg(not(feature="ws"))]
impl<S> Session<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
router: Arc<Router>,
connection: S,
ip: std::net::IpAddr
) -> Self {
Self {
router,
connection,
ip
}
}

pub(crate) async fn manage(mut self) {
#[cold] #[inline(never)]
fn panicking(panic: Box<dyn Any + Send>) -> Response {
if let Some(msg) = panic.downcast_ref::<String>() {
crate::WARNING!("[Panicked]: {msg}");
} else if let Some(msg) = panic.downcast_ref::<&str>() {
crate::WARNING!("[Panicked]: {msg}");
} else {
crate::WARNING!("[Panicked]");
}
crate::Response::InternalServerError()
}

match timeout_in(Duration::from_secs(crate::CONFIG.keepalive_timeout()), async {
let mut req = Request::init(self.ip);
let mut req = unsafe {Pin::new_unchecked(&mut req)};
loop {
req.clear();
match req.as_mut().read(&mut self.connection).await {
Ok(Some(())) => {
let close = matches!(req.headers.Connection(), Some("close" | "Close"));

let res = match catch_unwind(AssertUnwindSafe({
let req = req.as_mut();
|| self.router.handle(req.get_mut())
})) {
Ok(future) => future.await,
Err(panic) => panicking(panic),
};
let upgrade = res.send(&mut self.connection).await;

if !upgrade.is_none() {break upgrade}
if close {break Upgrade::None}
}
Ok(None) => break Upgrade::None,
Err(res) => {res.send(&mut self.connection).await;},
}
}
}).await {
None => crate::WARNING!("[WARNING] \
Session timeouted. In Ohkami, Keep-Alive timeout \
is set to 42 seconds by default and is configurable \
by `OHKAMI_KEEPALIVE_TIMEOUT` environment variable.\
"),

Some(Upgrade::None) => crate::DEBUG!("about to shutdown connection"),
}
}
}
52 changes: 52 additions & 0 deletions ohkami/src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use tokio::io::{AsyncRead, AsyncWrite};
pub struct TlsStream(pub tokio_rustls::server::TlsStream<tokio::net::TcpStream>);

#[cfg(feature="ws")]
impl crate::session::WebSocketUpgradeable for TlsStream {
fn into_websocket_stream(self) -> Result<crate::__rt__::TcpStream, &'static str> {
Err("WebSocket connections are not supported over TLS yet")
}
}

impl AsyncRead for TlsStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>
) -> std::task::Poll<std::io::Result<()>> {
match std::pin::Pin::new(&mut self.0).poll_read(cx, buf) {
std::task::Poll::Ready(Err(e)) => {
if e.to_string().contains("close_notify") {
std::task::Poll::Ready(Ok(()))
} else {
std::task::Poll::Ready(Err(e))
}
},
other => other,
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8]
) -> std::task::Poll<std::io::Result<usize>> {
std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>
) -> std::task::Poll<std::io::Result<()>> {
std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
}
}