diff --git a/conduit-axum/src/fallback.rs b/conduit-axum/src/fallback.rs index f0faab0a585..91bb0a3c7e6 100644 --- a/conduit-axum/src/fallback.rs +++ b/conduit-axum/src/fallback.rs @@ -4,6 +4,8 @@ use crate::file_stream::FileStream; use crate::{spawn_blocking, AxumResponse, ConduitResponse}; use std::error::Error; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use axum::body::{Body, HttpBody}; @@ -29,41 +31,60 @@ pub trait ConduitFallback { impl ConduitFallback for axum::Router { fn conduit_fallback(self, handler: impl Handler) -> Self { - let handler: Arc = Arc::new(handler); - self.fallback(fallback_to_conduit.layer(Extension(handler))) + self.fallback(ConduitAxumHandler(Arc::new(handler))) } } -async fn fallback_to_conduit( - handler: Extension>, - request: Request, -) -> Result { - if let Err(response) = check_content_length(&request) { - return Ok(response); - } - - let (parts, body) = request.into_parts(); - let now = StartInstant::now(); +#[derive(Debug)] +pub struct ConduitAxumHandler(pub Arc); - let full_body = hyper::body::to_bytes(body).await?; - let request = Request::from_parts(parts, full_body); - - let handler = handler.clone(); - spawn_blocking(move || { - let mut request = ConduitRequest::new(request, now); - handler - .call(&mut request) - .map(|mut response| { - if let Some(pattern) = request.mut_extensions().remove::() { - response.extensions_mut().insert(pattern); - } +impl Clone for ConduitAxumHandler { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} - conduit_into_axum(response) +impl AxumHandler<((),), S> for ConduitAxumHandler +where + S: Send + Sync + 'static, + H: Handler, +{ + type Future = Pin + Send>>; + + fn call(self, request: Request, _state: S) -> Self::Future { + Box::pin(async move { + if let Err(response) = check_content_length(&request) { + return response.into_response(); + } + + let (parts, body) = request.into_parts(); + let now = StartInstant::now(); + + let full_body = match hyper::body::to_bytes(body).await { + Ok(body) => body, + Err(err) => return server_error_response(&err), + }; + let request = Request::from_parts(parts, full_body); + + let Self(handler) = self; + spawn_blocking(move || { + let mut request = ConduitRequest::new(request, now); + handler + .call(&mut request) + .map(|mut response| { + if let Some(pattern) = request.mut_extensions().remove::() { + response.extensions_mut().insert(pattern); + } + + conduit_into_axum(response) + }) + .unwrap_or_else(|e| server_error_response(&*e)) }) - .unwrap_or_else(|e| server_error_response(&*e)) - }) - .await - .map_err(Into::into) + .await + .map_err(ServiceError::from) + .into_response() + }) + } } #[derive(Clone, Debug)]