diff --git a/README.md b/README.md index ce2d368d9..ae6541cfe 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ pub trait FangProc { built-in: -- `BasicAuth`, `Cors`, `Jwt` (authentication/security) +- `BasicAuth`, `Cors`, `Csrf`, `Jwt` (authentication/security) - `Context` (reuqest context) - `Enamel` (security headers; experimantal) - `Timeout` (handling timeout; native runtime only) diff --git a/ohkami/src/fang/builtin.rs b/ohkami/src/fang/builtin.rs index 300ac6093..081c85090 100644 --- a/ohkami/src/fang/builtin.rs +++ b/ohkami/src/fang/builtin.rs @@ -4,6 +4,9 @@ pub use basicauth::BasicAuth; mod cors; pub use cors::Cors; +mod csrf; +pub use csrf::Csrf; + mod jwt; pub use jwt::{Jwt, JwtToken}; diff --git a/ohkami/src/fang/builtin/csrf.rs b/ohkami/src/fang/builtin/csrf.rs new file mode 100644 index 000000000..243cfa2ad --- /dev/null +++ b/ohkami/src/fang/builtin/csrf.rs @@ -0,0 +1,243 @@ +use crate::{Request, Response, IntoResponse, Fang, FangProc}; +use std::sync::Arc; + +/// # Built-in CSRF protection fang. +/// +/// The implementation is based on the way of Go 1.25 net/http's `CrossOriginProtection`: +/// +/// - doc: https://go.dev/doc/go1.25#nethttppkgnethttp +/// - code: https://cs.opensource.google/go/go/+/refs/tags/go1.25.0:src/net/http/csrf.go +/// +/// providing a token-less CSRF protection mechanism, with support for byppassing trusted origins. +/// +/// ## Usage +/// +/// ### Single Server Service +/// +/// Just `Csrf::new()` and add it to your Ohkami app. +/// +/// ```no_run +/// use ohkami::{Ohkami, Route, fang::Csrf}; +/// +/// #[tokio::main] +/// async fn main() { +/// Ohkami::new(( +/// Csrf::new(), +/// "/".GET(|| async {"Hello, CSRF!"}), +/// )).howl("0.0.0.0:3000").await +/// } +/// ``` +/// +/// ### Multi Server Service +/// +/// If you have multiple servers, you can use `Csrf::with_trusted_origins` +/// to specify trusted origins. +/// +/// ```no_run +/// use ohkami::{Ohkami, Route, fang::Csrf}; +/// +/// #[tokio::main] +/// async fn main() { +/// Ohkami::new(( +/// Csrf::with_trusted_origins([ +/// "https://example.com", +/// "https://example.org", +/// ]), +/// "/".GET(|| async {"Hello, CSRF!"}), +/// )).howl("0.0.0.0:5000").await +/// } +/// ``` +#[derive(Clone)] +pub struct Csrf { + trusted_origins: Arc>, +} + +impl Csrf { + pub fn new() -> Self { + Csrf { + trusted_origins: Arc::new(vec![]), + } + } + + pub fn with_trusted_origins(trusted_origins: impl IntoIterator) -> Self { + let trusted_origins = trusted_origins.into_iter().collect::>(); + + for origin in &trusted_origins { + let Some((scheme, rest)) = origin.split_once("://") else { + panic!("invalid origin `{origin}`: scheme is required") + }; + if !matches!(scheme, "http" | "https") { + panic!("invalid origin `{origin}`: scheme must be 'http' or 'https'"); + } + if rest.contains(['/', '?', '#']) { + panic!("invalid origin `{origin}`: path, query and fragment are not allowed"); + } + let (host, port) = rest.split_once(':').map_or((rest, None), |(h, p)| (h, Some(p))); + if port.is_some_and(|p| !p.chars().all(|c| c.is_ascii_digit())) { + panic!("invalid origin `{origin}`: port must be a number"); + } + if !host.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '.' | '_')) { + panic!("invalid origin `{origin}`: invalid host"); + } + if !host.starts_with(|c: char| c.is_ascii_alphabetic()) { + panic!("invalid origin `{origin}`: host must start with an alphabetic character"); + } + } + + Csrf { trusted_origins: Arc::new(trusted_origins) } + } +} + +pub enum CsrfError { + InvalidSecFetchSite, + OriginNotMatchHost, + NoHostHeader, +} +impl IntoResponse for CsrfError { + fn into_response(self) -> Response { + match self { + CsrfError::InvalidSecFetchSite => Response::Forbidden() + .with_text("cross-origin request detected from Sec-Fetch-Site header"), + CsrfError::OriginNotMatchHost => Response::Forbidden() + .with_text("cross-origin request detected, and/or browser is out of date: Sec-Fetch-Site is missing, and Origin does not match Host"), + CsrfError::NoHostHeader => Response::BadRequest(), + } + } +} + +impl Csrf { + pub fn verify(&self, req: &Request) -> Result<(), CsrfError> { + let is_trusted = || req.headers.origin().is_some_and(|it| self.trusted_origins.contains(&it)); + + if req.method.is_safe() { + Ok(()) + } else if let Some(sec_fetch_site) = req.headers.sec_fetch_site() { + match sec_fetch_site { + "same-origin" | "none" => Ok(()), + _ => is_trusted().then_some(()).ok_or(CsrfError::InvalidSecFetchSite), + } + } else { + match (req.headers.origin(), req.headers.host()) { + (None, _) => Ok(()), // No Origin header, so we assume it's same-origin or not a browser request. + (_, None) => Err(CsrfError::NoHostHeader), + (Some(origin), Some(host)) if matches!( + origin.strip_suffix(host), + Some("http://" | "https://") + ) => Ok(()), + _ => is_trusted().then_some(()).ok_or(CsrfError::OriginNotMatchHost), + } + } + } +} + +const _: () = { + pub struct CsrfProc { + csrf: Csrf, + inner: I, + } + + impl Fang for Csrf { + type Proc = CsrfProc; + + fn chain(&self, inner: I) -> Self::Proc { + CsrfProc { csrf: self.clone(), inner } + } + } + + impl FangProc for CsrfProc { + async fn bite<'b>(&'b self, req: &'b mut Request) -> Response { + match self.csrf.verify(req) { + Ok(()) => self.inner.bite(req).await, + Err(e) => e.into_response(), + } + } + } +}; + +#[cfg(test)] +#[cfg(feature="__rt_native__")] +mod tests { + //! based on https://cs.opensource.google/go/go/+/refs/tags/go1.25.0:src/net/http/csrf_test.go + + use super::*; + use crate::testing::*; + use crate::{Ohkami, Route}; + + macro_rules! x {($method:ident) => { + TestRequest::$method("/").header("host", "example.com") + }} + + #[test] + fn test_sec_fetch_site() { + let t = Ohkami::new(( + Csrf::new(), + "/".GET(async || ()).PUT(async || ()).POST(async || ()), + )).test(); + + crate::__rt__::testing::block_on(async { + for (req, expected) in [ + (x!(POST).header("sec-fetch-site", "same-origin"), Status::OK), + (x!(POST).header("sec-fetch-site", "none"), Status::OK), + (x!(POST).header("sec-fetch-site", "cross-site"), Status::Forbidden), + (x!(POST).header("sec-fetch-site", "same-site"), Status::Forbidden), + + (x!(POST), Status::OK), + (x!(POST).header("origin", "https://example.com"), Status::OK), + (x!(POST).header("origin", "https://attacker.example"), Status::Forbidden), + (x!(POST).header("origin", "null"), Status::Forbidden), + + (x!(GET).header("sec-fetch-site", "cross-site"), Status::OK), + (x!(HEAD).header("sec-fetch-site", "cross-site"), Status::OK), + (x!(OPTIONS).header("sec-fetch-site", "cross-site"), Status::NotFound), // see `fang::handler::Handler::default_options_with` + (x!(PUT).header("sec-fetch-site", "cross-site"), Status::Forbidden), + ] { + let res = t.oneshot(req).await; + assert_eq!(res.status(), expected); + } + }); + } + + #[test] + fn test_trusted_origins() { + let t = Ohkami::new(( + Csrf::with_trusted_origins(["https://trusted.example"]), + "/".POST(async || ()), + )).test(); + + crate::__rt__::testing::block_on(async { + for (req, expected) in [ + (x!(POST).header("origin", "https://trusted.example"), Status::OK), + (x!(POST).header("origin", "https://trusted.example").header("sec-fetch-site", "cross-site"), Status::OK), + (x!(POST).header("origin", "https://attacker.example"), Status::Forbidden), + (x!(POST).header("origin", "https://attacker.example").header("sec-fetch-site", "cross-site"), Status::Forbidden), + ] { + let res = t.oneshot(req).await; + assert_eq!(res.status(), expected); + } + }); + } + + #[test] + fn test_invalid_trusted_origins() { + for (trusted_origin, should_judged_as_invalid) in [ + ("https://example.com", false), + ("https://example.com:8080", false), + ("http://example.com", false), + ("example.com", true), // missing scheme + ("https://", true), // missing host + ("https://example.com/", true), // path is not allowed + ("https://example.com/path", true), // path is not allowed + ("https://example.com?query=1", true), // query is not allowed + ("https://example.com#fragment", true), // fragment is not allowed + ("https://ex ample.com", true), // invalid host + ("", true), // empty string + ("null", true), // missing scheme + ("https://example.com:port", true), // invalid port + ] { + let is_judged_as_invalid = std::panic::catch_unwind(|| { + let _ = Csrf::with_trusted_origins([trusted_origin]); + }).is_err(); + assert_eq!(is_judged_as_invalid, should_judged_as_invalid, "unexpected result for trusted origin `{trusted_origin}`"); + } + } +} diff --git a/ohkami/src/request/method.rs b/ohkami/src/request/method.rs index d4310da60..a8eed5ad6 100644 --- a/ohkami/src/request/method.rs +++ b/ohkami/src/request/method.rs @@ -23,7 +23,8 @@ impl Method { } } #[cfg(feature="rt_worker")] - #[inline(always)] pub(crate) const fn from_worker(w: ::worker::Method) -> Option { + #[inline(always)] + pub(crate) const fn from_worker(w: ::worker::Method) -> Option { match w { ::worker::Method::Get => Some(Self::GET), ::worker::Method::Put => Some(Self::PUT), @@ -36,7 +37,8 @@ impl Method { } } - #[inline] pub const fn as_str(&self) -> &'static str { + #[inline] + pub const fn as_str(&self) -> &'static str { match self { Self::GET => "GET", Self::PUT => "PUT", @@ -47,8 +49,15 @@ impl Method { Self::OPTIONS => "OPTIONS", } } + + #[inline] + pub const fn is_safe(&self) -> bool { + matches!(self, Self::GET | Self::HEAD | Self::OPTIONS) + } } -#[allow(non_snake_case)] impl Method { + +#[allow(non_snake_case)] +impl Method { pub const fn isGET(&self) -> bool { matches!(self, Method::GET) } diff --git a/ohkami/src/testing/mod.rs b/ohkami/src/testing/mod.rs index c48472b8a..2450faa6e 100644 --- a/ohkami/src/testing/mod.rs +++ b/ohkami/src/testing/mod.rs @@ -1,4 +1,5 @@ #![cfg(debug_assertions)] +#![cfg(feature="__rt__")] //! Ohkami testing tools //! @@ -29,15 +30,13 @@ //! } //! ``` -use crate::{Response, Request, Ohkami, Status, Method}; +pub use crate::{Response, Request, Ohkami, Status, Method}; use crate::router::r#final::Router; - use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; use std::{pin::Pin, future::Future, format as f}; - pub trait Testing { fn test(self) -> TestingOhkami; }