Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions ohkami/src/fang/builtin/csrf.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{Request, Response, IntoResponse, Fang, FangProc};
use std::sync::Arc;

/// # Built-in CSRF protection fang.
/// # Built-in CSRF protection fang
///
/// The implementation is based on the way of Go 1.25 net/http's `CrossOriginProtection`:
///
Expand Down Expand Up @@ -33,6 +33,8 @@ use std::sync::Arc;
/// If you have multiple servers, you can use `Csrf::with_trusted_origins`
/// to specify trusted origins.
///
/// **NOTE**: wildcards (like `https://*.a.domain`) are not supported in trusted origins.
///
/// ```no_run
/// use ohkami::{Ohkami, Route, fang::Csrf};
///
Expand Down Expand Up @@ -63,24 +65,25 @@ impl Csrf {
let trusted_origins = trusted_origins.into_iter().collect::<Vec<_>>();

for origin in &trusted_origins {
let Some((scheme, rest)) = origin.split_once("://") else {
panic!("invalid origin `{origin}`: scheme is required")
let Some(("http" | "https", rest)) = origin.split_once("://") else {
panic!("[Csrf::with_trusted_origins] invalid origin `{origin}`: 'http' or 'https' scheme is required")
};
if !matches!(scheme, "http" | "https") {
panic!("invalid origin `{origin}`: scheme must be 'http' or 'https'");
}
if rest.contains(['/', '?', '#']) {
panic!("invalid origin `{origin}`: path, query and fragment are not allowed");
}
let (host, port) = rest.split_once(':').map_or((rest, None), |(h, p)| (h, Some(p)));
if port.is_some_and(|p| !p.chars().all(|c| c.is_ascii_digit())) {
panic!("invalid origin `{origin}`: port must be a number");
}
if !host.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '.' | '_')) {
panic!("invalid origin `{origin}`: invalid host");
panic!("[Csrf::with_trusted_origins] invalid origin `{origin}`: port must be a number");
}
if !host.starts_with(|c: char| c.is_ascii_alphabetic()) {
panic!("invalid origin `{origin}`: host must start with an alphabetic character");
panic!("[Csrf::with_trusted_origins] invalid origin `{origin}`: host must start with an alphabetic character");
}
if !host.split('.').all(|part|
!part.is_empty() && part.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
) {
if host.contains(['/', '?', '#']) {
// helpful error message for common mistake
panic!("[Csrf::with_trusted_origins] invalid origin `{origin}`: path, query and fragment are not allowed");
} else {
panic!("[Csrf::with_trusted_origins] invalid origin `{origin}`: invalid host");
}
}
}

Expand Down