diff --git a/Taskfile.yaml b/Taskfile.yaml index d4b8289e7..060ee2e5d 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -67,7 +67,8 @@ tasks: test:examples: dir: ./examples cmds: - - cargo test + - chmod +x ./test.sh + - ./test.sh test:samples: dir: ./samples diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 02776cb80..4b183ff9c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -2,6 +2,7 @@ resolver = "3" members = [ "sse", + "jwt", "form", "hello", "chatgpt", diff --git a/examples/jwt/.env.sample b/examples/jwt/.env.sample new file mode 100644 index 000000000..6b893cb0c --- /dev/null +++ b/examples/jwt/.env.sample @@ -0,0 +1 @@ +JWT_SECRET=your-jwt-secret-key diff --git a/examples/jwt/.gitignore b/examples/jwt/.gitignore new file mode 100644 index 000000000..2eea525d8 --- /dev/null +++ b/examples/jwt/.gitignore @@ -0,0 +1 @@ +.env \ No newline at end of file diff --git a/examples/jwt/Cargo.toml b/examples/jwt/Cargo.toml new file mode 100644 index 000000000..7fa52589c --- /dev/null +++ b/examples/jwt/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "jwt" +version = "0.1.0" +edition = "2024" + +[dependencies] +ohkami = { workspace = true } +tokio = { workspace = true } +dotenvy = "0.15" diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs new file mode 100644 index 000000000..01a79d80b --- /dev/null +++ b/examples/jwt/src/main.rs @@ -0,0 +1,105 @@ +use ohkami::prelude::*; +use ohkami::fang::{JWT, JWTToken}; + +fn jwt() -> JWT { + JWT::default(std::env::var("JWT_SECRET").unwrap()) +} + +#[derive(Serialize, Deserialize)] +struct JwtPayload { + sub: String, + exp: u64, +} + +trait JwtSub: 'static { + fn sub() -> String; +} + +struct DefaultJwtSub; +impl JwtSub for DefaultJwtSub { + fn sub() -> String {"ohkami".to_string()} +} + +#[derive(Serialize)] +#[cfg_attr(test, derive(Deserialize))] +struct AuthResponse { + token: JWTToken, +} + +async fn auth() -> JSON { + let token = jwt().issue(JwtPayload { + sub: S::sub(), + exp: ohkami::util::unix_timestamp() + 86400, + }); + + JSON(AuthResponse { token }) +} + +async fn private( + Context(_): Context<'_, JwtPayload>, +) -> &'static str { + "Hello, private!" +} + +fn ohkami() -> Ohkami { + Ohkami::new(( + "/auth".GET(auth::), + "/private".GET((jwt(), private)), + )) +} + +#[tokio::main] +async fn main() { + dotenvy::dotenv().ok(); + ohkami::().howl("0.0.0.0:3000").await +} + +#[cfg(test)] +mod test { + use super::*; + use ohkami::testing::*; + + /// regression test for https://github.com/ohkami-rs/ohkami/issues/433 + /// + /// run with `OHKAMI_REQUEST_BUFSIZE=4096` or larger + #[tokio::test] + async fn test_large_jwt() { + struct LargeJwtSub; + impl JwtSub for LargeJwtSub { + fn sub() -> String { + const SENTENCE: &'static str = "\ + Lorem ipsum dolor sit amet, consectetur adipiscing elit. \ + Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. \ + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris \ + nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in \ + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla \ + pariatur. Excepteur sint occaecat cupidatat non proident, sunt in \ + culpa qui officia deserunt mollit anim id est laborum.\ + "; + + let sub = SENTENCE.repeat((1 << 11) / SENTENCE.len() + 1); + + // `sub` itself is already larger than default `request_bufsize` + assert!(sub.len() > (1 << 11)); + + sub + } + } + + dotenvy::dotenv().ok(); + + let t = ohkami::().test(); + + let req = TestRequest::GET("/auth"); + let res = t.oneshot(req).await; + let AuthResponse { token } = res.json() + .expect("`/auth` response doesn't contain a token") + .expect("`/auth` response is not `AuthResponse`"); + + let req = TestRequest::GET("/private") + .header("Authorization", format!("Bearer {token}")); + let res = t.oneshot(req).await; + assert_eq!(res.status().code(), 200); + assert_eq!(res.text(), Some("Hello, private!")); + } +} diff --git a/examples/test.sh b/examples/test.sh new file mode 100755 index 000000000..4f330f933 --- /dev/null +++ b/examples/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -Cue + +EXAMPLES=$(pwd) + +cd $EXAMPLES/jwt && \ + cp .env.sample .env +cd $EXAMPLES/jwt && \ + cargo test 2>&1 | grep 'Unexpected end of headers' \ + && echo '---> expected error' \ + || exit 1 +cd $EXAMPLES/jwt && \ + OHKAMI_REQUEST_BUFSIZE=4096 cargo test diff --git a/ohkami/src/config.rs b/ohkami/src/config.rs index 2cbfe3c73..5ac8dcd4b 100644 --- a/ohkami/src/config.rs +++ b/ohkami/src/config.rs @@ -1,4 +1,7 @@ pub(crate) struct Config { + #[cfg(feature="__rt_native__")] + request_bufsize: std::sync::LazyLock, + #[cfg(feature="__rt_native__")] keepalive_timeout: std::sync::LazyLock, @@ -9,6 +12,13 @@ pub(crate) struct Config { impl Config { #[cfg(feature="__rt_native__")] + #[inline] + pub(crate) fn request_bufsize(&self) -> usize { + *(&*self.request_bufsize) + } + + #[cfg(feature="__rt_native__")] + #[inline] pub(crate) fn keepalive_timeout(&self) -> u64 { *(&*self.keepalive_timeout) } @@ -23,15 +33,28 @@ impl Config { impl Config { pub(super) const fn new() -> Self { Self { + #[cfg(feature="__rt_native__")] + request_bufsize: std::sync::LazyLock::new(|| std::env::var("OHKAMI_REQUEST_BUFSIZE") + .ok() + .map(|v| v.parse().ok()) + .flatten() + .unwrap_or(1 << 11) + ), + #[cfg(feature="__rt_native__")] keepalive_timeout: std::sync::LazyLock::new(|| std::env::var("OHKAMI_KEEPALIVE_TIMEOUT") - .ok().map(|v| v.parse().ok()).flatten() + .ok() + .map(|v| v.parse().ok()) + .flatten() .unwrap_or(42) ), + #[cfg(feature="__rt_native__")] #[cfg(feature="ws")] websocket_timeout: std::sync::LazyLock::new(|| std::env::var("OHKAMI_WEBSOCKET_TIMEOUT") - .ok().map(|v| v.parse().ok()).flatten() + .ok() + .map(|v| v.parse().ok()) + .flatten() .unwrap_or(42) ), } diff --git a/ohkami/src/fang/builtin/jwt.rs b/ohkami/src/fang/builtin/jwt.rs index 43fe40e88..ae562070b 100644 --- a/ohkami/src/fang/builtin/jwt.rs +++ b/ohkami/src/fang/builtin/jwt.rs @@ -446,7 +446,7 @@ impl Deserialize<'de>> JWT { let mut req = unsafe {Pin::new_unchecked(&mut req)}; crate::__rt__::testing::block_on({ req.as_mut().read(&mut req_bytes) - }); + }).unwrap(); assert_eq!( my_jwt.verified(&req.as_ref()).unwrap(), @@ -462,7 +462,7 @@ impl Deserialize<'de>> JWT { let mut req = unsafe {Pin::new_unchecked(&mut req)}; crate::__rt__::testing::block_on({ req.as_mut().read(&mut req_bytes) - }); + }).unwrap(); assert_eq!( my_jwt.verified(&req.as_ref()).unwrap_err().status, diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 35785aad3..56489c804 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -149,27 +149,29 @@ mod __rt__ { #[cfg(test)] pub(crate) mod testing { - pub(crate) fn block_on(future: impl std::future::Future) { + pub(crate) fn block_on(future: F) -> F::Output { #[cfg(feature="rt_tokio")] - tokio::runtime::Builder::new_current_thread() + return tokio::runtime::Builder::new_current_thread() .enable_all() - .build().unwrap() + .build() + .unwrap() .block_on(future); #[cfg(feature="rt_async-std")] - async_std::task::block_on(future); + return async_std::task::block_on(future); #[cfg(feature="rt_smol")] - smol::block_on(future); + return smol::block_on(future); #[cfg(feature="rt_nio")] - nio::runtime::Builder::new_multi_thread() + return nio::runtime::Builder::new_multi_thread() .enable_all() - .build().unwrap() + .build() + .unwrap() .block_on(future); #[cfg(feature="rt_glommio")] - glommio::LocalExecutor::default().run(future); + return glommio::LocalExecutor::default().run(future); } pub(crate) const PORT: u16 = { diff --git a/ohkami/src/request/_test_parse.rs b/ohkami/src/request/_test_parse.rs index 93c460ee2..5c4ebfe6a 100644 --- a/ohkami/src/request/_test_parse.rs +++ b/ohkami/src/request/_test_parse.rs @@ -1,7 +1,11 @@ #![cfg(all(test, feature="__rt_native__", feature="DEBUG"))] #[allow(unused)] -use super::{Request, Method, BUF_SIZE, Path, QueryParams, Context}; +use super::{Request, Method, Path, QueryParams, Context}; + +use super::{RequestHeader, RequestHeaders}; +use std::pin::Pin; +use ohkami_lib::{Slice, CowSlice}; #[test] fn parse_path() { @@ -18,48 +22,47 @@ fn parse_path() { assert_eq!(&*path, "/"); } -#[test] fn test_parse_request() { - use super::{RequestHeader, RequestHeaders}; - use std::pin::Pin; - use ohkami_lib::{Slice, CowSlice}; - - fn metadataize(input: &str) -> Box<[u8; BUF_SIZE]> { - let mut buf = [0; BUF_SIZE]; - buf[..input.len().min(BUF_SIZE)] - .copy_from_slice(&input.as_bytes()[..input.len().min(BUF_SIZE)]); - Box::new(buf) - } +macro_rules! assert_parse { + ($case:expr, $expected:expr) => { + let mut case = $case.as_bytes(); - macro_rules! assert_parse { - ($case:expr, $expected:expr) => { - let mut case = $case.as_bytes(); - - let mut actual = Request::init(crate::util::IP_0000); - let mut actual = unsafe {Pin::new_unchecked(&mut actual)}; - - crate::__rt__::testing::block_on({ - actual.as_mut().read(&mut case) - }); - - let expected = $expected; - - println!(""); - - let __panic_message = format!("\n\ - ===== actual =====\n\ - {actual:#?}\n\ - \n\ - ===== expected =====\n\ - {expected:#?}\n\ - \n\ - "); - - if actual.get_mut() != &expected { - panic!("{__panic_message}") - } - }; - } + let mut actual = Request::init(crate::util::IP_0000); + let mut actual = unsafe {Pin::new_unchecked(&mut actual)}; + + let result = crate::__rt__::testing::block_on({ + actual.as_mut().read(&mut case) + }); + + assert_eq!(result, Ok(Some(()))); + + let expected = $expected; + + println!(""); + let __panic_message = format!("\n\ + ===== actual =====\n\ + {actual:#?}\n\ + \n\ + ===== expected =====\n\ + {expected:#?}\n\ + \n\ + "); + + if actual.get_mut() != &expected { + panic!("{__panic_message}") + } + }; +} + +fn metadataize(input: &str) -> Box<[u8]> { + let buf_size = crate::CONFIG.request_bufsize(); + let mut buf = vec![0; buf_size]; + buf[..input.len().min(buf_size)] + .copy_from_slice(&input.as_bytes()[..input.len().min(buf_size)]); + buf.into_boxed_slice() +} + +#[test] fn test_parse_request() { const CASE_1: &str = "\ GET /hello.html HTTP/1.1\r\n\ User-Agent: Mozilla/4.0\r\n\ diff --git a/ohkami/src/request/mod.rs b/ohkami/src/request/mod.rs index 440294ec9..d47007f3f 100644 --- a/ohkami/src/request/mod.rs +++ b/ohkami/src/request/mod.rs @@ -34,10 +34,9 @@ use { std::borrow::Cow, }; - -#[cfg(feature="__rt_native__")] -pub(crate) const BUF_SIZE: usize = 1 << 10; #[cfg(feature="__rt_native__")] +/// reject requests having `Content-Length` larger than this limit +/// (as `413 Payload Too Large`) for resource security reason pub(crate) const PAYLOAD_LIMIT: usize = 1 << 32; /// # HTTP Request @@ -97,7 +96,7 @@ pub(crate) const PAYLOAD_LIMIT: usize = 1 << 32; /// ``` pub struct Request { #[cfg(feature="__rt_native__")] - pub(super/* for test */) __buf__: Box<[u8; BUF_SIZE]>, + pub(super/* for test */) __buf__: Box<[u8]>, #[cfg(feature="rt_worker")] pub(super/* for test */) __url__: std::mem::MaybeUninit<::worker::Url>, @@ -189,7 +188,7 @@ impl Request { ip: crate::util::IP_0000/* tetative */, #[cfg(feature="__rt_native__")] - __buf__: Box::new([0; BUF_SIZE]), + __buf__: vec![0u8; crate::CONFIG.request_bufsize()].into_boxed_slice(), #[cfg(feature="rt_worker")] __url__: std::mem::MaybeUninit::uninit(), #[cfg(feature="rt_lambda")] @@ -219,7 +218,6 @@ impl Request { } #[cfg(feature="__rt_native__")] - #[inline] pub(crate) async fn read( mut self: Pin<&mut Self>, stream: &mut (impl AsyncRead + Unpin), @@ -264,9 +262,26 @@ impl Request { while r.consume("\r\n").is_none() { let key_bytes = r.read_while(|b| b != &b':'); - r.consume(": ").ok_or_else(Response::BadRequest)?; + r.consume(": ").ok_or_else(|| { + crate::WARNING!("\ + [Request::read] Unexpected end of headers! \ + Maybe request buffer size is not enough. \ + Try to set `OHKAMI_REQUEST_BUFSIZE` to larger value \ + (default: 2048).\ + "); + Response::BadRequest() + })?; + let value = CowSlice::Ref(Slice::from_bytes(r.read_while(|b| b != &b'\r'))); - r.consume("\r\n").ok_or_else(Response::BadRequest)?; + r.consume("\r\n").ok_or_else(|| { + crate::WARNING!("\ + [Request::read] Unexpected end of headers! \ + Maybe request buffer size is not enough. \ + Try to set `OHKAMI_REQUEST_BUFSIZE` to larger value \ + (default: 2048).\ + "); + Response::BadRequest() + })?; if let Some(key) = RequestHeader::from_bytes(key_bytes) { self.headers.append(key, value); @@ -329,8 +344,8 @@ impl Request { } - #[cfg(debug_assertions/* for `ohkami::testing` */)] #[cfg(any(feature="rt_worker", feature="rt_lambda"))] + #[cfg(debug_assertions/* for `ohkami::testing` */)] /// Used in `testing` module pub(crate) async fn read(mut self: Pin<&mut Self>, raw_bytes: &mut &[u8]