1+ use futures_core:: ready;
2+ use std:: future:: Future ;
3+ use std:: pin:: Pin ;
14use std:: task:: { Context , Poll } ;
25
36use http:: { header, HeaderMap , HeaderValue , Method , Request , Response , StatusCode , Version } ;
47use hyper:: Body ;
8+ use pin_project:: pin_project;
59use tonic:: body:: { empty_body, BoxBody } ;
610use tonic:: transport:: NamedService ;
711use tower_service:: Service ;
812use tracing:: { debug, trace} ;
913
1014use crate :: call:: content_types:: is_grpc_web;
1115use crate :: call:: { Encoding , GrpcWebCall } ;
12- use crate :: { BoxError , BoxFuture } ;
16+ use crate :: BoxError ;
1317
1418const GRPC : & str = "application/grpc" ;
1519
@@ -47,13 +51,17 @@ impl<S> GrpcWebService<S>
4751where
4852 S : Service < Request < Body > , Response = Response < BoxBody > > + Send + ' static ,
4953{
50- fn response ( & self , status : StatusCode ) -> BoxFuture < S :: Response , S :: Error > {
51- Box :: pin ( async move {
52- Ok ( Response :: builder ( )
53- . status ( status)
54- . body ( empty_body ( ) )
55- . unwrap ( ) )
56- } )
54+ fn response ( & self , status : StatusCode ) -> ResponseFuture < S :: Future > {
55+ ResponseFuture {
56+ case : Case :: ImmediateResponse {
57+ res : Some (
58+ Response :: builder ( )
59+ . status ( status)
60+ . body ( empty_body ( ) )
61+ . unwrap ( ) ,
62+ ) ,
63+ } ,
64+ }
5765 }
5866}
5967
6573{
6674 type Response = S :: Response ;
6775 type Error = S :: Error ;
68- type Future = BoxFuture < Self :: Response , Self :: Error > ;
76+ type Future = ResponseFuture < S :: Future > ;
6977
7078 fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
7179 self . inner . poll_ready ( cx)
8997 } => {
9098 trace ! ( kind = "simple" , path = ?req. uri( ) . path( ) , ?encoding, ?accept) ;
9199
92- let fut = self . inner . call ( coerce_request ( req, encoding) ) ;
93- Box :: pin ( async move { Ok ( coerce_response ( fut. await ?, accept) ) } )
100+ ResponseFuture {
101+ case : Case :: GrpcWeb {
102+ future : self . inner . call ( coerce_request ( req, encoding) ) ,
103+ accept,
104+ } ,
105+ }
94106 }
95107
96108 // The request's content-type matches one of the 4 supported grpc-web
@@ -105,7 +117,11 @@ where
105117 // whatever they are.
106118 RequestKind :: Other ( Version :: HTTP_2 ) => {
107119 debug ! ( kind = "other h2" , content_type = ?req. headers( ) . get( header:: CONTENT_TYPE ) ) ;
108- Box :: pin ( self . inner . call ( req) )
120+ ResponseFuture {
121+ case : Case :: Other {
122+ future : self . inner . call ( req) ,
123+ } ,
124+ }
109125 }
110126
111127 // Return HTTP 400 for all other requests.
@@ -117,6 +133,56 @@ where
117133 }
118134}
119135
136+ /// Response future for the [`GrpcWebService`].
137+ #[ allow( missing_debug_implementations) ]
138+ #[ pin_project]
139+ #[ must_use = "futures do nothing unless polled" ]
140+ pub struct ResponseFuture < F > {
141+ #[ pin]
142+ case : Case < F > ,
143+ }
144+
145+ #[ pin_project( project = CaseProj ) ]
146+ enum Case < F > {
147+ GrpcWeb {
148+ #[ pin]
149+ future : F ,
150+ accept : Encoding ,
151+ } ,
152+ Other {
153+ #[ pin]
154+ future : F ,
155+ } ,
156+ ImmediateResponse {
157+ res : Option < Response < BoxBody > > ,
158+ } ,
159+ }
160+
161+ impl < F , E > Future for ResponseFuture < F >
162+ where
163+ F : Future < Output = Result < Response < BoxBody > , E > > + Send + ' static ,
164+ E : Into < BoxError > + Send ,
165+ {
166+ type Output = Result < Response < BoxBody > , E > ;
167+
168+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
169+ let mut this = self . project ( ) ;
170+
171+ match this. case . as_mut ( ) . project ( ) {
172+ CaseProj :: GrpcWeb { future, accept } => {
173+ let res = match ready ! ( future. poll( cx) ) {
174+ Ok ( b) => b,
175+ Err ( e) => return Poll :: Ready ( Err ( e) ) ,
176+ } ;
177+
178+ Poll :: Ready ( Ok ( coerce_response ( res, * accept) ) )
179+ }
180+ CaseProj :: Other { future } => future. poll ( cx) ,
181+ CaseProj :: ImmediateResponse { res } => Poll :: Ready ( Ok ( res. take ( ) . unwrap ( ) ) ) ,
182+ }
183+ }
184+ }
185+
120186impl < S : NamedService > NamedService for GrpcWebService < S > {
121187 const NAME : & ' static str = S :: NAME ;
122188}
@@ -177,6 +243,8 @@ mod tests {
177243 ACCESS_CONTROL_REQUEST_HEADERS , ACCESS_CONTROL_REQUEST_METHOD , CONTENT_TYPE , ORIGIN ,
178244 } ;
179245
246+ type BoxFuture < T , E > = Pin < Box < dyn Future < Output = Result < T , E > > + Send > > ;
247+
180248 #[ derive( Debug , Clone ) ]
181249 struct Svc ;
182250
0 commit comments