Skip to content

Commit eb7d75a

Browse files
Bonus: removed unnecessary boxing
Signed-off-by: slinkydeveloper <[email protected]>
1 parent 5cefddf commit eb7d75a

File tree

2 files changed

+81
-16
lines changed

2 files changed

+81
-16
lines changed

tonic-web/src/lib.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,13 @@
8888
#![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")]
8989

9090
pub use layer::GrpcWebLayer;
91-
pub use service::GrpcWebService;
91+
pub use service::{GrpcWebService, ResponseFuture};
9292

9393
mod call;
9494
mod layer;
9595
mod service;
9696

9797
use http::header::HeaderName;
98-
use std::future::Future;
99-
use std::pin::Pin;
10098
use std::time::Duration;
10199
use tonic::body::BoxBody;
102100
use tower_http::cors::{AllowOrigin, Cors, CorsLayer};
@@ -110,7 +108,6 @@ const DEFAULT_ALLOW_HEADERS: [&str; 4] =
110108
["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"];
111109

112110
type BoxError = Box<dyn std::error::Error + Send + Sync>;
113-
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
114111

115112
/// Enable a tonic service to handle grpc-web requests with the default configuration.
116113
///

tonic-web/src/service.rs

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
use futures_core::ready;
2+
use std::future::Future;
3+
use std::pin::Pin;
14
use std::task::{Context, Poll};
25

36
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
47
use hyper::Body;
8+
use pin_project::pin_project;
59
use tonic::body::{empty_body, BoxBody};
610
use tonic::transport::NamedService;
711
use tower_service::Service;
812
use tracing::{debug, trace};
913

1014
use crate::call::content_types::is_grpc_web;
1115
use crate::call::{Encoding, GrpcWebCall};
12-
use crate::{BoxError, BoxFuture};
16+
use crate::BoxError;
1317

1418
const GRPC: &str = "application/grpc";
1519

@@ -47,13 +51,17 @@ impl<S> GrpcWebService<S>
4751
where
4852
S: Service<Request<Body>, Response = Response<BoxBody>> + Send + 'static,
4953
{
50-
fn response(&self, status: StatusCode) -> BoxFuture<S::Response, S::Error> {
51-
Box::pin(async move {
52-
Ok(Response::builder()
53-
.status(status)
54-
.body(empty_body())
55-
.unwrap())
56-
})
54+
fn response(&self, status: StatusCode) -> ResponseFuture<S::Future> {
55+
ResponseFuture {
56+
case: Case::ImmediateResponse {
57+
res: Some(
58+
Response::builder()
59+
.status(status)
60+
.body(empty_body())
61+
.unwrap(),
62+
),
63+
},
64+
}
5765
}
5866
}
5967

@@ -65,7 +73,7 @@ where
6573
{
6674
type Response = S::Response;
6775
type Error = S::Error;
68-
type Future = BoxFuture<Self::Response, Self::Error>;
76+
type Future = ResponseFuture<S::Future>;
6977

7078
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
7179
self.inner.poll_ready(cx)
@@ -89,8 +97,12 @@ where
8997
} => {
9098
trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
9199

92-
let fut = self.inner.call(coerce_request(req, encoding));
93-
Box::pin(async move { Ok(coerce_response(fut.await?, accept)) })
100+
ResponseFuture {
101+
case: Case::GrpcWeb {
102+
future: self.inner.call(coerce_request(req, encoding)),
103+
accept,
104+
},
105+
}
94106
}
95107

96108
// The request's content-type matches one of the 4 supported grpc-web
@@ -105,7 +117,11 @@ where
105117
// whatever they are.
106118
RequestKind::Other(Version::HTTP_2) => {
107119
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
108-
Box::pin(self.inner.call(req))
120+
ResponseFuture {
121+
case: Case::Other {
122+
future: self.inner.call(req),
123+
},
124+
}
109125
}
110126

111127
// Return HTTP 400 for all other requests.
@@ -117,6 +133,56 @@ where
117133
}
118134
}
119135

136+
/// Response future for the [`GrpcWebService`].
137+
#[allow(missing_debug_implementations)]
138+
#[pin_project]
139+
#[must_use = "futures do nothing unless polled"]
140+
pub struct ResponseFuture<F> {
141+
#[pin]
142+
case: Case<F>,
143+
}
144+
145+
#[pin_project(project = CaseProj)]
146+
enum Case<F> {
147+
GrpcWeb {
148+
#[pin]
149+
future: F,
150+
accept: Encoding,
151+
},
152+
Other {
153+
#[pin]
154+
future: F,
155+
},
156+
ImmediateResponse {
157+
res: Option<Response<BoxBody>>,
158+
},
159+
}
160+
161+
impl<F, E> Future for ResponseFuture<F>
162+
where
163+
F: Future<Output = Result<Response<BoxBody>, E>> + Send + 'static,
164+
E: Into<BoxError> + Send,
165+
{
166+
type Output = Result<Response<BoxBody>, E>;
167+
168+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
169+
let mut this = self.project();
170+
171+
match this.case.as_mut().project() {
172+
CaseProj::GrpcWeb { future, accept } => {
173+
let res = match ready!(future.poll(cx)) {
174+
Ok(b) => b,
175+
Err(e) => return Poll::Ready(Err(e)),
176+
};
177+
178+
Poll::Ready(Ok(coerce_response(res, *accept)))
179+
}
180+
CaseProj::Other { future } => future.poll(cx),
181+
CaseProj::ImmediateResponse { res } => Poll::Ready(Ok(res.take().unwrap())),
182+
}
183+
}
184+
}
185+
120186
impl<S: NamedService> NamedService for GrpcWebService<S> {
121187
const NAME: &'static str = S::NAME;
122188
}
@@ -177,6 +243,8 @@ mod tests {
177243
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
178244
};
179245

246+
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
247+
180248
#[derive(Debug, Clone)]
181249
struct Svc;
182250

0 commit comments

Comments
 (0)