Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions ohkami/src/ohkami/mod.rs
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Original file line number Diff line number Diff line change
Expand Up @@ -1266,3 +1266,182 @@ mod test {
is_send_sync_static(o);
}
}

#[cfg(test)]
#[cfg(feature = "__rt_native__")]
mod nested_fang_regression_test {
use crate::claw::status;
use crate::fang::{Context, FangAction};
use crate::testing::{Status, TestRequest, Tester};
use crate::{Ohkami, Request, Response, Route};

#[derive(Clone, Debug, PartialEq, Eq)]
struct Principal(&'static str);

#[derive(Clone)]
struct ParentAuthFang;

impl FangAction for ParentAuthFang {
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
match req.headers.authorization() {
Some("Bearer ops-token") => {
req.context.set(Principal("ops-user"));
Ok(())
}
_ => Err(Response::Unauthorized()),
}
}
}

#[derive(Clone)]
struct OpsAuthorizationFang;

impl FangAction for OpsAuthorizationFang {
async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
match req.context.get::<Principal>() {
Some(Principal("ops-user")) => Ok(()),
_ => Err(Response::Unauthorized()),
}
}
}

async fn routing_health_handler() -> &'static str {
"health"
}

async fn routing_override_set_handler() -> &'static str {
"set"
}

async fn routing_override_clear_handler() -> status::NoContent {
status::NoContent
}

async fn metrics_handler() -> &'static str {
"metrics"
}

async fn accounting_reconciliation_handler() -> &'static str {
"reconciliation"
}

#[test]
fn parent_context_auth_is_visible_to_nested_top_level_fang_in_realistic_order() {
crate::__rt__::testing::block_on(async {
let ops_routes = Ohkami::new((
OpsAuthorizationFang,
"/routing/health".GET(routing_health_handler),
"/routing/override".POST(routing_override_set_handler),
"/routing/override".DELETE(routing_override_clear_handler),
"/metrics".GET(metrics_handler),
"/accounting/reconciliation".GET(accounting_reconciliation_handler),
));

let protected_routes = Ohkami::new((ParentAuthFang, "/ops".By(ops_routes)));

let app = Ohkami::new((Context::new(()), "/api".By(protected_routes)));

let tester = app.test();

let health_res = tester
.oneshot(
TestRequest::GET("/api/ops/routing/health")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(health_res.status(), Status::OK);

let set_override_res = tester
.oneshot(
TestRequest::POST("/api/ops/routing/override")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(set_override_res.status(), Status::OK);

let clear_override_res = tester
.oneshot(
TestRequest::DELETE("/api/ops/routing/override")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(clear_override_res.status(), Status::NoContent);

let metrics_res = tester
.oneshot(
TestRequest::GET("/api/ops/metrics")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(metrics_res.status(), Status::OK);

let accounting_res = tester
.oneshot(
TestRequest::GET("/api/ops/accounting/reconciliation")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(accounting_res.status(), Status::OK);
});
}

#[test]
fn parent_context_auth_is_visible_to_nested_local_route_fangs_in_realistic_order() {
crate::__rt__::testing::block_on(async {
let ops_routes = Ohkami::new((
"/routing/health".GET((OpsAuthorizationFang, routing_health_handler)),
"/routing/override".POST((OpsAuthorizationFang, routing_override_set_handler)),
"/routing/override".DELETE((OpsAuthorizationFang, routing_override_clear_handler)),
"/metrics".GET((OpsAuthorizationFang, metrics_handler)),
"/accounting/reconciliation"
.GET((OpsAuthorizationFang, accounting_reconciliation_handler)),
));

let protected_routes = Ohkami::new((ParentAuthFang, "/ops".By(ops_routes)));

let app = Ohkami::new((Context::new(()), "/api".By(protected_routes)));

let tester = app.test();

let health_res = tester
.oneshot(
TestRequest::GET("/api/ops/routing/health")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(health_res.status(), Status::OK);

let set_override_res = tester
.oneshot(
TestRequest::POST("/api/ops/routing/override")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(set_override_res.status(), Status::OK);

let clear_override_res = tester
.oneshot(
TestRequest::DELETE("/api/ops/routing/override")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(clear_override_res.status(), Status::NoContent);

let metrics_res = tester
.oneshot(
TestRequest::GET("/api/ops/metrics")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(metrics_res.status(), Status::OK);

let accounting_res = tester
.oneshot(
TestRequest::GET("/api/ops/accounting/reconciliation")
.header("Authorization", "Bearer ops-token"),
)
.await;
assert_eq!(accounting_res.status(), Status::OK);
});
}
}
58 changes: 26 additions & 32 deletions ohkami/src/router/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,42 @@ impl FangsList {
Self(Vec::new())
}

fn add(&mut self, id: ID, fangs: Arc<dyn Fangs>) {
fn add_inner(&mut self, id: ID, fangs: Arc<dyn Fangs>) {
if self.0.iter().all(|(_id, _)| *_id != id) {
self.0.push((id, fangs));
}
}
pub(super) fn append(&mut self, another: Self) {
for (id, fangs) in another.0.into_iter() {
self.add(id, fangs)
fn add_outer(&mut self, id: ID, fangs: Arc<dyn Fangs>) {
if self.0.iter().all(|(_id, _)| *_id != id) {
self.0.insert(0, (id, fangs));
}
}

/// yield from most inner fangs
fn into_iter(self) -> impl Iterator<Item = Arc<dyn Fangs>> {
self.0.into_iter().map(|(_, fangs)| fangs)
pub(super) fn append_inner(&mut self, another: Self) {
for (id, fangs) in another.0.into_iter() {
self.add_inner(id, fangs);
}
}

pub(super) fn into_proc_with(self, h: Handler) -> IntoProcWith {
let mut iter = self.into_iter();

#[cfg(not(feature = "openapi"))]
match iter.next() {
None => h.proc,
Some(most_inner) => {
iter.fold(most_inner.build(h.proc), |proc, fangs| fangs.build(proc))
}
{
self.0
.into_iter()
.rfold(h.proc, |proc, (_, most_inner_fangs)| {
most_inner_fangs.build(proc)
})
}
#[cfg(feature = "openapi")]
match iter.next() {
None => (h.proc, h.openapi_operation),
Some(most_inner) => iter.fold(
(
most_inner.build(h.proc),
most_inner.openapi_map_operation(h.openapi_operation),
),
|(proc, operation), fangs| {
(fangs.build(proc), fangs.openapi_map_operation(operation))
{
self.0.into_iter().rfold(
(h.proc, h.openapi_operation),
|(proc, op), (_, most_inner_fangs)| {
(
most_inner_fangs.build(proc),
most_inner_fangs.openapi_map_operation(op),
)
},
),
)
}
}
}
Expand Down Expand Up @@ -362,10 +360,6 @@ impl Node {
}
}

fn append_fangs(&mut self, fangs: FangsList) {
self.fangses.append(fangs);
}

fn set_handler(&mut self, new_handler: Handler, allow_override: bool) -> Result<(), String> {
if self.handler.is_some() && !allow_override {
return Err(format!("Conflicting handler registering"));
Expand Down Expand Up @@ -414,7 +408,7 @@ impl Node {
panic!("Unexpectedly called `Node::merge_here` where `another_root` is not root node")
};

self.append_fangs(another_root_fangses);
self.fangses.append_inner(another_root_fangses);

if let Some(h) = another_root_handler {
self.set_handler(h, allow_override_handler)?;
Expand All @@ -432,10 +426,10 @@ impl Node {
for child in &mut self.children {
child.apply_fangs(id, fangs.clone())
}

// Add even when `self.handler.is_none()`. They are used later
// for applying to `Handler::default_notfound`s in `finalize`.
self.fangses.add(id, fangs);
// This `fangses` must be added by `_outer` to *wrap* existing fangs.
self.fangses.add_outer(id, fangs);
}
}

Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/router/final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ const _: (/* conversions */) = {
let child = base.children.pop().unwrap(/* base.children.len() == 1 */);
base.children = child.children;
base.handler = child.handler;
base.fangses.append(child.fangses);
base.fangses.append_inner(child.fangses);
base.pattern = Some(match base.pattern {
None => child.pattern.unwrap(/* not root */),
Some(p) => p.merge_statics(child.pattern.unwrap(/* not root */)).unwrap(/* both are Pattern::Static */)
Expand Down
Loading