From fb889f60084f32dd088c868e5c36b8c8b6a876e7 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Tue, 3 Jan 2023 18:59:53 +0100 Subject: [PATCH 1/2] conduit-axum: Adjust `fallback_to_conduit()` fn to return `Response` --- conduit-axum/src/fallback.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/conduit-axum/src/fallback.rs b/conduit-axum/src/fallback.rs index f0faab0a585..9cb42b31cf3 100644 --- a/conduit-axum/src/fallback.rs +++ b/conduit-axum/src/fallback.rs @@ -37,15 +37,18 @@ impl ConduitFallback for axum::Router { async fn fallback_to_conduit( handler: Extension>, request: Request, -) -> Result { +) -> AxumResponse { if let Err(response) = check_content_length(&request) { - return Ok(response); + return response.into_response(); } let (parts, body) = request.into_parts(); let now = StartInstant::now(); - let full_body = hyper::body::to_bytes(body).await?; + let full_body = match hyper::body::to_bytes(body).await { + Ok(body) => body, + Err(err) => return server_error_response(&err), + }; let request = Request::from_parts(parts, full_body); let handler = handler.clone(); @@ -63,7 +66,8 @@ async fn fallback_to_conduit( .unwrap_or_else(|e| server_error_response(&*e)) }) .await - .map_err(Into::into) + .map_err(ServiceError::from) + .into_response() } #[derive(Clone, Debug)] From 907e4859b01ece97b4afc4aab5e8b655a470aa89 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Tue, 3 Jan 2023 18:59:54 +0100 Subject: [PATCH 2/2] conduit-axum: Use `ConduitAxumHandler` struct to save `conduit::Handler` ... instead of adding the handler to the extensions of each request. --- conduit-axum/src/fallback.rs | 83 ++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 33 deletions(-) diff --git a/conduit-axum/src/fallback.rs b/conduit-axum/src/fallback.rs index 9cb42b31cf3..91bb0a3c7e6 100644 --- a/conduit-axum/src/fallback.rs +++ b/conduit-axum/src/fallback.rs @@ -4,6 +4,8 @@ use crate::file_stream::FileStream; use crate::{spawn_blocking, AxumResponse, ConduitResponse}; use std::error::Error; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use axum::body::{Body, HttpBody}; @@ -29,45 +31,60 @@ pub trait ConduitFallback { impl ConduitFallback for axum::Router { fn conduit_fallback(self, handler: impl Handler) -> Self { - let handler: Arc = Arc::new(handler); - self.fallback(fallback_to_conduit.layer(Extension(handler))) + self.fallback(ConduitAxumHandler(Arc::new(handler))) } } -async fn fallback_to_conduit( - handler: Extension>, - request: Request, -) -> AxumResponse { - if let Err(response) = check_content_length(&request) { - return response.into_response(); +#[derive(Debug)] +pub struct ConduitAxumHandler(pub Arc); + +impl Clone for ConduitAxumHandler { + fn clone(&self) -> Self { + Self(self.0.clone()) } +} - let (parts, body) = request.into_parts(); - let now = StartInstant::now(); - - let full_body = match hyper::body::to_bytes(body).await { - Ok(body) => body, - Err(err) => return server_error_response(&err), - }; - let request = Request::from_parts(parts, full_body); - - let handler = handler.clone(); - spawn_blocking(move || { - let mut request = ConduitRequest::new(request, now); - handler - .call(&mut request) - .map(|mut response| { - if let Some(pattern) = request.mut_extensions().remove::() { - response.extensions_mut().insert(pattern); - } - - conduit_into_axum(response) +impl AxumHandler<((),), S> for ConduitAxumHandler +where + S: Send + Sync + 'static, + H: Handler, +{ + type Future = Pin + Send>>; + + fn call(self, request: Request, _state: S) -> Self::Future { + Box::pin(async move { + if let Err(response) = check_content_length(&request) { + return response.into_response(); + } + + let (parts, body) = request.into_parts(); + let now = StartInstant::now(); + + let full_body = match hyper::body::to_bytes(body).await { + Ok(body) => body, + Err(err) => return server_error_response(&err), + }; + let request = Request::from_parts(parts, full_body); + + let Self(handler) = self; + spawn_blocking(move || { + let mut request = ConduitRequest::new(request, now); + handler + .call(&mut request) + .map(|mut response| { + if let Some(pattern) = request.mut_extensions().remove::() { + response.extensions_mut().insert(pattern); + } + + conduit_into_axum(response) + }) + .unwrap_or_else(|e| server_error_response(&*e)) }) - .unwrap_or_else(|e| server_error_response(&*e)) - }) - .await - .map_err(ServiceError::from) - .into_response() + .await + .map_err(ServiceError::from) + .into_response() + }) + } } #[derive(Clone, Debug)]