@@ -14,10 +14,9 @@ use irpc::{
1414 MAX_MESSAGE_SIZE ,
1515 } ,
1616 util:: AsyncReadVarintExt ,
17- LocalSender , RequestError ,
17+ LocalSender , RequestError , Service ,
1818} ;
1919use n0_future:: { future:: Boxed as BoxFuture , TryFutureExt } ;
20- use serde:: de:: DeserializeOwned ;
2120use tracing:: { debug, error_span, trace, trace_span, warn, Instrument } ;
2221
2322/// Returns a client that connects to a irpc service using an [`iroh::Endpoint`].
@@ -102,8 +101,8 @@ async fn connect_and_open_bi(
102101/// A [`ProtocolHandler`] for an irpc protocol.
103102///
104103/// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string.
105- pub struct IrohProtocol < R > {
106- handler : Handler < R > ,
104+ pub struct IrohProtocol < S > {
105+ handler : Handler < S > ,
107106 request_id : AtomicU64 ,
108107}
109108
@@ -113,25 +112,25 @@ impl<T> fmt::Debug for IrohProtocol<T> {
113112 }
114113}
115114
116- impl < R : DeserializeOwned + Send + ' static > IrohProtocol < R > {
117- pub fn with_sender ( local_sender : impl Into < LocalSender < R > > ) -> Self
115+ impl < S : Service > IrohProtocol < S > {
116+ pub fn with_sender ( local_sender : impl Into < LocalSender < S > > ) -> Self
118117 where
119- R : RemoteService ,
118+ S : RemoteService ,
120119 {
121- let handler = R :: remote_handler ( local_sender. into ( ) ) ;
120+ let handler = S :: remote_handler ( local_sender. into ( ) ) ;
122121 Self :: new ( handler)
123122 }
124123
125124 /// Creates a new [`IrohProtocol`] for the `handler`.
126- pub fn new ( handler : Handler < R > ) -> Self {
125+ pub fn new ( handler : Handler < S > ) -> Self {
127126 Self {
128127 handler,
129128 request_id : Default :: default ( ) ,
130129 }
131130 }
132131}
133132
134- impl < R : DeserializeOwned + Send + ' static > ProtocolHandler for IrohProtocol < R > {
133+ impl < S : Service > ProtocolHandler for IrohProtocol < S > {
135134 fn accept (
136135 & self ,
137136 connection : Connection ,
@@ -140,7 +139,7 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for IrohProtocol<R> {
140139 let request_id = self
141140 . request_id
142141 . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: AcqRel ) ;
143- let fut = handle_connection ( connection, handler) . map_err ( AcceptError :: from_err) ;
142+ let fut = handle_connection :: < S > ( connection, handler) . map_err ( AcceptError :: from_err) ;
144143 let span = trace_span ! ( "rpc" , id = request_id) ;
145144 Box :: pin ( fut. instrument ( span) )
146145 }
@@ -151,8 +150,8 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for IrohProtocol<R> {
151150/// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string.
152151///
153152/// For details about when it is safe to use 0rtt, see https://www.iroh.computer/blog/0rtt-api
154- pub struct Iroh0RttProtocol < R > {
155- handler : Handler < R > ,
153+ pub struct Iroh0RttProtocol < S > {
154+ handler : Handler < S > ,
156155 request_id : AtomicU64 ,
157156}
158157
@@ -162,25 +161,25 @@ impl<T> fmt::Debug for Iroh0RttProtocol<T> {
162161 }
163162}
164163
165- impl < R : DeserializeOwned + Send + ' static > Iroh0RttProtocol < R > {
166- pub fn with_sender ( local_sender : impl Into < LocalSender < R > > ) -> Self
164+ impl < S : Service > Iroh0RttProtocol < S > {
165+ pub fn with_sender ( local_sender : impl Into < LocalSender < S > > ) -> Self
167166 where
168- R : RemoteService ,
167+ S : RemoteService ,
169168 {
170- let handler = R :: remote_handler ( local_sender. into ( ) ) ;
169+ let handler = S :: remote_handler ( local_sender. into ( ) ) ;
171170 Self :: new ( handler)
172171 }
173172
174173 /// Creates a new [`Iroh0RttProtocol`] for the `handler`.
175- pub fn new ( handler : Handler < R > ) -> Self {
174+ pub fn new ( handler : Handler < S > ) -> Self {
176175 Self {
177176 handler,
178177 request_id : Default :: default ( ) ,
179178 }
180179 }
181180}
182181
183- impl < R : DeserializeOwned + Send + ' static > ProtocolHandler for Iroh0RttProtocol < R > {
182+ impl < S : Service > ProtocolHandler for Iroh0RttProtocol < S > {
184183 async fn on_connecting ( & self , connecting : Connecting ) -> Result < Connection , AcceptError > {
185184 let ( conn, _zero_rtt_accepted) = connecting
186185 . into_0rtt ( )
@@ -196,29 +195,34 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for Iroh0RttProtocol<
196195 let request_id = self
197196 . request_id
198197 . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: AcqRel ) ;
199- let fut = handle_connection ( connection, handler) . map_err ( AcceptError :: from_err) ;
198+ let fut = handle_connection :: < S > ( connection, handler) . map_err ( AcceptError :: from_err) ;
200199 let span = trace_span ! ( "rpc" , id = request_id) ;
201200 Box :: pin ( fut. instrument ( span) )
202201 }
203202}
204203
205204/// Handles a single iroh connection with the provided `handler`.
206- pub async fn handle_connection < R : DeserializeOwned + ' static > (
205+ ///
206+ /// The wire format used depends on `S::SPAN_PROPAGATION` - if true, span context is expected.
207+ pub async fn handle_connection < S : Service > (
207208 connection : Connection ,
208- handler : Handler < R > ,
209+ handler : Handler < S > ,
209210) -> io:: Result < ( ) > {
210211 if let Ok ( remote) = connection. remote_id ( ) {
211212 tracing:: Span :: current ( ) . record ( "remote" , tracing:: field:: display ( remote. fmt_short ( ) ) ) ;
212213 }
213214 debug ! ( "connection accepted" ) ;
214215 loop {
215- let Some ( ( msg, rx, tx) ) = read_request_raw ( & connection) . await ? else {
216+ let Some ( ( msg, rx, tx) ) = read_request_raw :: < S > ( & connection) . await ? else {
216217 return Ok ( ( ) ) ;
217218 } ;
218219 handler ( msg, rx, tx) . await ?;
219220 }
220221}
221222
223+ /// Reads a request from a connection and converts it to a message enum.
224+ ///
225+ /// This combines `read_request_raw` with `RemoteService::with_remote_channels`.
222226pub async fn read_request < S : RemoteService > (
223227 connection : & Connection ,
224228) -> std:: io:: Result < Option < S :: Message > > {
@@ -231,12 +235,16 @@ pub async fn read_request<S: RemoteService>(
231235///
232236/// This accepts a bi-directional stream from the connection and reads and parses the request.
233237///
238+ /// The wire format used depends on `S::SPAN_PROPAGATION`:
239+ /// - When `true`: expects `(Option<SpanContextCarrier>, Message)` tuple format
240+ /// - When `false`: expects plain `Message` format
241+ ///
234242/// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
235243/// Returns None if the remote closed the connection with error code `0`.
236244/// Returns an error for all other failure cases.
237- pub async fn read_request_raw < R : DeserializeOwned + ' static > (
245+ pub async fn read_request_raw < S : Service > (
238246 connection : & Connection ,
239- ) -> std:: io:: Result < Option < ( R , RecvStream , SendStream ) > > {
247+ ) -> std:: io:: Result < Option < ( S , RecvStream , SendStream ) > > {
240248 let ( send, mut recv) = match connection. accept_bi ( ) . await {
241249 Ok ( ( s, r) ) => ( s, r) ,
242250 Err ( ConnectionError :: ApplicationClosed ( cause) ) if cause. error_code . into_inner ( ) == 0 => {
@@ -264,24 +272,31 @@ pub async fn read_request_raw<R: DeserializeOwned + 'static>(
264272 . await
265273 . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: UnexpectedEof , e) ) ?;
266274
267- // Deserialize the payload which includes optional span context
268- // irpc-iroh uses irpc with default features, which include spans and rpc,
269- // so span_propagation module always exists
270- let ( span_ctx , msg ) : ( Option < irpc :: span_propagation :: SpanContextCarrier > , R ) =
271- postcard :: from_bytes ( & buf ) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
275+ // Deserialize based on S::SPAN_PROPAGATION
276+ let msg : S = if S :: SPAN_PROPAGATION {
277+ let ( span_ctx , msg ) : ( Option < irpc :: span_propagation :: SpanContextCarrier > , S ) =
278+ postcard :: from_bytes ( & buf )
279+ . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
272280
273- // Store span context in thread-local for use by with_remote_channels
274- if let Some ( ctx) = span_ctx {
275- ctx. store_in_thread_local ( ) ;
276- }
281+ // Store span context in thread-local for use by with_remote_channels
282+ if let Some ( ctx) = span_ctx {
283+ ctx. store_in_thread_local ( ) ;
284+ }
285+
286+ msg
287+ } else {
288+ postcard:: from_bytes ( & buf) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?
289+ } ;
277290
278291 let rx = recv;
279292 let tx = send;
280293 Ok ( Some ( ( msg, rx, tx) ) )
281294}
282295
283- /// Utility function to listen for incoming connections and handle them with the provided handler
284- pub async fn listen < R : DeserializeOwned + ' static > ( endpoint : iroh:: Endpoint , handler : Handler < R > ) {
296+ /// Utility function to listen for incoming connections and handle them with the provided handler.
297+ ///
298+ /// The wire format used depends on `S::SPAN_PROPAGATION` - if true, span context is expected.
299+ pub async fn listen < S : Service > ( endpoint : iroh:: Endpoint , handler : Handler < S > ) {
285300 let mut request_id = 0u64 ;
286301 let mut tasks = n0_future:: task:: JoinSet :: new ( ) ;
287302 loop {
@@ -300,7 +315,7 @@ pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, han
300315 let handler = handler. clone ( ) ;
301316 let fut = async move {
302317 match incoming. await {
303- Ok ( connection) => match handle_connection ( connection, handler) . await {
318+ Ok ( connection) => match handle_connection :: < S > ( connection, handler) . await {
304319 Err ( err) => warn ! ( "connection closed with error: {err:?}" ) ,
305320 Ok ( ( ) ) => debug ! ( "connection closed" ) ,
306321 } ,
0 commit comments