Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ pub trait FangProc {

built-in:

- `BasicAuth`, `Cors`, `Jwt` (authentication/security)
- `BasicAuth`, `Cors`, `Csrf`, `Jwt` (authentication/security)
- `Context` (reuqest context)
- `Enamel` (security headers; experimantal)
- `Timeout` (handling timeout; native runtime only)
Expand Down
3 changes: 3 additions & 0 deletions ohkami/src/fang/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ pub use basicauth::BasicAuth;
mod cors;
pub use cors::Cors;

mod csrf;
pub use csrf::Csrf;

mod jwt;
pub use jwt::{Jwt, JwtToken};

Expand Down
243 changes: 243 additions & 0 deletions ohkami/src/fang/builtin/csrf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
use crate::{Request, Response, IntoResponse, Fang, FangProc};
use std::sync::Arc;

/// # Built-in CSRF protection fang.
///
/// The implementation is based on the way of Go 1.25 net/http's `CrossOriginProtection`:
///
/// - doc: https://go.dev/doc/go1.25#nethttppkgnethttp
/// - code: https://cs.opensource.google/go/go/+/refs/tags/go1.25.0:src/net/http/csrf.go
///
/// providing a token-less CSRF protection mechanism, with support for byppassing trusted origins.
///
/// ## Usage
///
/// ### Single Server Service
///
/// Just `Csrf::new()` and add it to your Ohkami app.
///
/// ```no_run
/// use ohkami::{Ohkami, Route, fang::Csrf};
///
/// #[tokio::main]
/// async fn main() {
/// Ohkami::new((
/// Csrf::new(),
/// "/".GET(|| async {"Hello, CSRF!"}),
/// )).howl("0.0.0.0:3000").await
/// }
/// ```
///
/// ### Multi Server Service
///
/// If you have multiple servers, you can use `Csrf::with_trusted_origins`
/// to specify trusted origins.
///
/// ```no_run
/// use ohkami::{Ohkami, Route, fang::Csrf};
///
/// #[tokio::main]
/// async fn main() {
/// Ohkami::new((
/// Csrf::with_trusted_origins([
/// "https://example.com",
/// "https://example.org",
/// ]),
/// "/".GET(|| async {"Hello, CSRF!"}),
/// )).howl("0.0.0.0:5000").await
/// }
/// ```
#[derive(Clone)]
pub struct Csrf {
trusted_origins: Arc<Vec<&'static str>>,
}

impl Csrf {
pub fn new() -> Self {
Csrf {
trusted_origins: Arc::new(vec![]),
}
}

pub fn with_trusted_origins(trusted_origins: impl IntoIterator<Item = &'static str>) -> Self {
let trusted_origins = trusted_origins.into_iter().collect::<Vec<_>>();

for origin in &trusted_origins {
let Some((scheme, rest)) = origin.split_once("://") else {
panic!("invalid origin `{origin}`: scheme is required")
};
if !matches!(scheme, "http" | "https") {
panic!("invalid origin `{origin}`: scheme must be 'http' or 'https'");
}
if rest.contains(['/', '?', '#']) {
panic!("invalid origin `{origin}`: path, query and fragment are not allowed");
}
let (host, port) = rest.split_once(':').map_or((rest, None), |(h, p)| (h, Some(p)));
if port.is_some_and(|p| !p.chars().all(|c| c.is_ascii_digit())) {
panic!("invalid origin `{origin}`: port must be a number");
}
if !host.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '.' | '_')) {
panic!("invalid origin `{origin}`: invalid host");
}
if !host.starts_with(|c: char| c.is_ascii_alphabetic()) {
panic!("invalid origin `{origin}`: host must start with an alphabetic character");
}
}

Csrf { trusted_origins: Arc::new(trusted_origins) }
}
}

pub enum CsrfError {
InvalidSecFetchSite,
OriginNotMatchHost,
NoHostHeader,
}
impl IntoResponse for CsrfError {
fn into_response(self) -> Response {
match self {
CsrfError::InvalidSecFetchSite => Response::Forbidden()
.with_text("cross-origin request detected from Sec-Fetch-Site header"),
CsrfError::OriginNotMatchHost => Response::Forbidden()
.with_text("cross-origin request detected, and/or browser is out of date: Sec-Fetch-Site is missing, and Origin does not match Host"),
CsrfError::NoHostHeader => Response::BadRequest(),
}
}
}

impl Csrf {
pub fn verify(&self, req: &Request) -> Result<(), CsrfError> {
let is_trusted = || req.headers.origin().is_some_and(|it| self.trusted_origins.contains(&it));

if req.method.is_safe() {
Ok(())
} else if let Some(sec_fetch_site) = req.headers.sec_fetch_site() {
match sec_fetch_site {
"same-origin" | "none" => Ok(()),
_ => is_trusted().then_some(()).ok_or(CsrfError::InvalidSecFetchSite),
}
} else {
match (req.headers.origin(), req.headers.host()) {
(None, _) => Ok(()), // No Origin header, so we assume it's same-origin or not a browser request.
(_, None) => Err(CsrfError::NoHostHeader),
(Some(origin), Some(host)) if matches!(
origin.strip_suffix(host),
Some("http://" | "https://")
) => Ok(()),
_ => is_trusted().then_some(()).ok_or(CsrfError::OriginNotMatchHost),
}
}
}
}

const _: () = {
pub struct CsrfProc<I: FangProc> {
csrf: Csrf,
inner: I,
}

impl<I: FangProc> Fang<I> for Csrf {
type Proc = CsrfProc<I>;

fn chain(&self, inner: I) -> Self::Proc {
CsrfProc { csrf: self.clone(), inner }
}
}

impl<I: FangProc> FangProc for CsrfProc<I> {
async fn bite<'b>(&'b self, req: &'b mut Request) -> Response {
match self.csrf.verify(req) {
Ok(()) => self.inner.bite(req).await,
Err(e) => e.into_response(),
}
}
}
};

#[cfg(test)]
#[cfg(feature="__rt_native__")]
mod tests {
//! based on https://cs.opensource.google/go/go/+/refs/tags/go1.25.0:src/net/http/csrf_test.go

use super::*;
use crate::testing::*;
use crate::{Ohkami, Route};

macro_rules! x {($method:ident) => {
TestRequest::$method("/").header("host", "example.com")
}}

#[test]
fn test_sec_fetch_site() {
let t = Ohkami::new((
Csrf::new(),
"/".GET(async || ()).PUT(async || ()).POST(async || ()),
)).test();

crate::__rt__::testing::block_on(async {
for (req, expected) in [
(x!(POST).header("sec-fetch-site", "same-origin"), Status::OK),
(x!(POST).header("sec-fetch-site", "none"), Status::OK),
(x!(POST).header("sec-fetch-site", "cross-site"), Status::Forbidden),
(x!(POST).header("sec-fetch-site", "same-site"), Status::Forbidden),

(x!(POST), Status::OK),
(x!(POST).header("origin", "https://example.com"), Status::OK),
(x!(POST).header("origin", "https://attacker.example"), Status::Forbidden),
(x!(POST).header("origin", "null"), Status::Forbidden),

(x!(GET).header("sec-fetch-site", "cross-site"), Status::OK),
(x!(HEAD).header("sec-fetch-site", "cross-site"), Status::OK),
(x!(OPTIONS).header("sec-fetch-site", "cross-site"), Status::NotFound), // see `fang::handler::Handler::default_options_with`
(x!(PUT).header("sec-fetch-site", "cross-site"), Status::Forbidden),
] {
let res = t.oneshot(req).await;
assert_eq!(res.status(), expected);
}
});
}

#[test]
fn test_trusted_origins() {
let t = Ohkami::new((
Csrf::with_trusted_origins(["https://trusted.example"]),
"/".POST(async || ()),
)).test();

crate::__rt__::testing::block_on(async {
for (req, expected) in [
(x!(POST).header("origin", "https://trusted.example"), Status::OK),
(x!(POST).header("origin", "https://trusted.example").header("sec-fetch-site", "cross-site"), Status::OK),
(x!(POST).header("origin", "https://attacker.example"), Status::Forbidden),
(x!(POST).header("origin", "https://attacker.example").header("sec-fetch-site", "cross-site"), Status::Forbidden),
] {
let res = t.oneshot(req).await;
assert_eq!(res.status(), expected);
}
});
}

#[test]
fn test_invalid_trusted_origins() {
for (trusted_origin, should_judged_as_invalid) in [
("https://example.com", false),
("https://example.com:8080", false),
("http://example.com", false),
("example.com", true), // missing scheme
("https://", true), // missing host
("https://example.com/", true), // path is not allowed
("https://example.com/path", true), // path is not allowed
("https://example.com?query=1", true), // query is not allowed
("https://example.com#fragment", true), // fragment is not allowed
("https://ex ample.com", true), // invalid host
("", true), // empty string
("null", true), // missing scheme
("https://example.com:port", true), // invalid port
] {
let is_judged_as_invalid = std::panic::catch_unwind(|| {
let _ = Csrf::with_trusted_origins([trusted_origin]);
}).is_err();
assert_eq!(is_judged_as_invalid, should_judged_as_invalid, "unexpected result for trusted origin `{trusted_origin}`");
}
}
}
15 changes: 12 additions & 3 deletions ohkami/src/request/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ impl Method {
}
}
#[cfg(feature="rt_worker")]
#[inline(always)] pub(crate) const fn from_worker(w: ::worker::Method) -> Option<Self> {
#[inline(always)]
pub(crate) const fn from_worker(w: ::worker::Method) -> Option<Self> {
match w {
::worker::Method::Get => Some(Self::GET),
::worker::Method::Put => Some(Self::PUT),
Expand All @@ -36,7 +37,8 @@ impl Method {
}
}

#[inline] pub const fn as_str(&self) -> &'static str {
#[inline]
pub const fn as_str(&self) -> &'static str {
match self {
Self::GET => "GET",
Self::PUT => "PUT",
Expand All @@ -47,8 +49,15 @@ impl Method {
Self::OPTIONS => "OPTIONS",
}
}

#[inline]
pub const fn is_safe(&self) -> bool {
matches!(self, Self::GET | Self::HEAD | Self::OPTIONS)
}
}
#[allow(non_snake_case)] impl Method {

#[allow(non_snake_case)]
impl Method {
pub const fn isGET(&self) -> bool {
matches!(self, Method::GET)
}
Expand Down
5 changes: 2 additions & 3 deletions ohkami/src/testing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg(debug_assertions)]
#![cfg(feature="__rt__")]

//! Ohkami testing tools
//!
Expand Down Expand Up @@ -29,15 +30,13 @@
//! }
//! ```

use crate::{Response, Request, Ohkami, Status, Method};
pub use crate::{Response, Request, Ohkami, Status, Method};
use crate::router::r#final::Router;

use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use std::{pin::Pin, future::Future, format as f};


pub trait Testing {
fn test(self) -> TestingOhkami;
}
Expand Down