Skip to content

conduit-axum: Use ConduitAxumHandler struct to save conduit::Handler #5798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions conduit-axum/src/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use crate::file_stream::FileStream;
use crate::{spawn_blocking, AxumResponse, ConduitResponse};

use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use axum::body::{Body, HttpBody};
Expand All @@ -29,41 +31,60 @@ pub trait ConduitFallback {

impl ConduitFallback for axum::Router {
fn conduit_fallback(self, handler: impl Handler) -> Self {
let handler: Arc<dyn Handler> = Arc::new(handler);
self.fallback(fallback_to_conduit.layer(Extension(handler)))
self.fallback(ConduitAxumHandler(Arc::new(handler)))
}
}

async fn fallback_to_conduit(
handler: Extension<Arc<dyn Handler>>,
request: Request<Body>,
) -> Result<AxumResponse, ServiceError> {
if let Err(response) = check_content_length(&request) {
return Ok(response);
}

let (parts, body) = request.into_parts();
let now = StartInstant::now();
#[derive(Debug)]
pub struct ConduitAxumHandler<H>(pub Arc<H>);

let full_body = hyper::body::to_bytes(body).await?;
let request = Request::from_parts(parts, full_body);

let handler = handler.clone();
spawn_blocking(move || {
let mut request = ConduitRequest::new(request, now);
handler
.call(&mut request)
.map(|mut response| {
if let Some(pattern) = request.mut_extensions().remove::<RoutePattern>() {
response.extensions_mut().insert(pattern);
}
impl<H> Clone for ConduitAxumHandler<H> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

conduit_into_axum(response)
impl<S, H> AxumHandler<((),), S> for ConduitAxumHandler<H>
where
S: Send + Sync + 'static,
H: Handler,
{
type Future = Pin<Box<dyn Future<Output = AxumResponse> + Send>>;

fn call(self, request: Request<Body>, _state: S) -> Self::Future {
Box::pin(async move {
if let Err(response) = check_content_length(&request) {
return response.into_response();
}

let (parts, body) = request.into_parts();
let now = StartInstant::now();

let full_body = match hyper::body::to_bytes(body).await {
Ok(body) => body,
Err(err) => return server_error_response(&err),
};
let request = Request::from_parts(parts, full_body);

let Self(handler) = self;
spawn_blocking(move || {
let mut request = ConduitRequest::new(request, now);
handler
.call(&mut request)
.map(|mut response| {
if let Some(pattern) = request.mut_extensions().remove::<RoutePattern>() {
response.extensions_mut().insert(pattern);
}

conduit_into_axum(response)
})
.unwrap_or_else(|e| server_error_response(&*e))
})
.unwrap_or_else(|e| server_error_response(&*e))
})
.await
.map_err(Into::into)
.await
.map_err(ServiceError::from)
.into_response()
})
}
}

#[derive(Clone, Debug)]
Expand Down