Skip to content

Commit e325e1d

Browse files
committed
add propagate-span attribute to rpc_request
1 parent 1e73185 commit e325e1d

File tree

5 files changed

+212
-124
lines changed

5 files changed

+212
-124
lines changed

Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ n0-future = { workspace = true }
4646
futures-util = { workspace = true, optional = true }
4747
# for the derive reexport/feature
4848
irpc-derive = { version = "0.8.0", path = "./irpc-derive", optional = true }
49-
# for remote span propagation when both spans and rpc are enabled
49+
# for remote span propagation when use-tracing-opentelemetry feature is enabled
5050
opentelemetry = { version = "0.31", optional = true }
5151
tracing-opentelemetry = { version = "0.32", optional = true }
5252

@@ -67,14 +67,16 @@ testresult = "0.4.1"
6767

6868
[features]
6969
# enable the remote transport
70-
rpc = ["dep:quinn", "dep:postcard", "dep:anyhow", "dep:smallvec", "dep:tracing", "tokio/io-util", "dep:opentelemetry", "dep:tracing-opentelemetry", "irpc-derive?/rpc"]
70+
rpc = ["dep:quinn", "dep:postcard", "dep:anyhow", "dep:smallvec", "dep:tracing", "tokio/io-util"]
7171
# add test utilities
7272
quinn_endpoint_setup = ["rpc", "dep:rustls", "dep:rcgen", "dep:anyhow", "dep:futures-buffered", "quinn/rustls-ring"]
7373
# pick up parent span when creating channel messages
74-
spans = ["dep:tracing", "irpc-derive?/spans"]
74+
spans = ["dep:tracing"]
7575
stream = ["dep:futures-util"]
7676
derive = ["dep:irpc-derive"]
7777
varint-util = ["dep:postcard", "dep:smallvec", "tokio/io-util"]
78+
# enable OpenTelemetry span context propagation across remote connections
79+
use-tracing-opentelemetry = ["dep:opentelemetry", "dep:tracing-opentelemetry", "rpc"]
7880
default = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"]
7981

8082
[[example]]

irpc-derive/Cargo.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,4 @@ quote = "1"
1818
proc-macro2 = "1"
1919

2020
[features]
21-
# These features should match irpc's features for proper code generation
22-
spans = []
23-
rpc = []
24-
default = ["spans", "rpc"]
21+
default = []

irpc-derive/src/lib.rs

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,21 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
152152
let message_from_impls =
153153
generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name);
154154

155+
let span_propagation = args.span_propagation;
155156
let service_impl = quote! {
156157
impl ::irpc::Service for #enum_name {
157158
type Message = #message_enum_name;
159+
const SPAN_PROPAGATION: bool = #span_propagation;
158160
}
159161
};
160162

161163
let remote_service_impl = if !args.no_rpc {
162-
let block =
163-
generate_remote_service_impl(message_enum_name, enum_name, &variants_with_attr);
164+
let block = generate_remote_service_impl(
165+
message_enum_name,
166+
enum_name,
167+
&variants_with_attr,
168+
args.span_propagation,
169+
);
164170
quote! {
165171
#cfg_feature_rpc
166172
#block
@@ -280,39 +286,39 @@ fn generate_message_enum_from_impls(
280286
}
281287

282288
/// Generate `RemoteService` impl for message enums.
289+
///
290+
/// When `span_propagation` is true, the generated code will create spans for each
291+
/// request and set their parent from the propagated remote context.
283292
fn generate_remote_service_impl(
284293
message_enum_name: &Ident,
285294
proto_enum_name: &Ident,
286295
variants_with_attr: &[(Ident, Type)],
296+
span_propagation: bool,
287297
) -> TokenStream2 {
288-
// Generate match arms that set the span parent for each variant
289-
#[cfg(all(feature = "spans", feature = "rpc"))]
298+
// Generate match arms for each variant
290299
let variants = variants_with_attr
291300
.iter()
292301
.map(|(variant_name, _inner_type)| {
293302
let span_name = variant_name.to_string();
294-
quote! {
295-
#proto_enum_name::#variant_name(msg) => {
296-
// Create a span for this specific RPC operation
297-
let span = ::tracing::info_span!(#span_name);
298-
// Set its parent to the propagated remote context if available
299-
if let Some(ctx) = ::irpc::span_propagation::take_remote_span_context() {
300-
use ::irpc::span_propagation::OpenTelemetrySpanExt;
301-
let _ = span.set_parent(ctx);
303+
304+
if span_propagation {
305+
// When span_propagation is enabled, create spans and set parent from remote context
306+
quote! {
307+
#proto_enum_name::#variant_name(msg) => {
308+
// Create a span for this specific RPC operation
309+
let span = ::tracing::info_span!(#span_name);
310+
// Set its parent to the propagated remote context if available
311+
::irpc::span_propagation::set_span_parent_from_remote(&span);
312+
let _guard = span.enter();
313+
#message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
302314
}
303-
let _guard = span.enter();
304-
#message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
305315
}
306-
}
307-
});
308-
309-
#[cfg(not(all(feature = "spans", feature = "rpc")))]
310-
let variants = variants_with_attr
311-
.iter()
312-
.map(|(variant_name, _inner_type)| {
313-
quote! {
314-
#proto_enum_name::#variant_name(msg) => {
315-
#message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
316+
} else {
317+
// When span_propagation is disabled, just create the message
318+
quote! {
319+
#proto_enum_name::#variant_name(msg) => {
320+
#message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
321+
}
316322
}
317323
}
318324
});
@@ -361,6 +367,8 @@ struct MacroArgs {
361367
rpc_feature: Option<String>,
362368
no_rpc: bool,
363369
no_spans: bool,
370+
/// When true, includes span context in the wire format and enables span propagation.
371+
span_propagation: bool,
364372
}
365373

366374
impl Parse for MacroArgs {
@@ -396,6 +404,9 @@ impl Parse for MacroArgs {
396404
"no_spans" => {
397405
this.no_spans = true;
398406
}
407+
"span_propagation" => {
408+
this.span_propagation = true;
409+
}
399410
_ => {
400411
return syn_err(arg.span(), format!("Unknown parameter: {arg}"));
401412
}

irpc-iroh/src/lib.rs

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ use irpc::{
1414
MAX_MESSAGE_SIZE,
1515
},
1616
util::AsyncReadVarintExt,
17-
LocalSender, RequestError,
17+
LocalSender, RequestError, Service,
1818
};
1919
use n0_future::{future::Boxed as BoxFuture, TryFutureExt};
20-
use serde::de::DeserializeOwned;
2120
use tracing::{debug, error_span, trace, trace_span, warn, Instrument};
2221

2322
/// Returns a client that connects to a irpc service using an [`iroh::Endpoint`].
@@ -102,8 +101,8 @@ async fn connect_and_open_bi(
102101
/// A [`ProtocolHandler`] for an irpc protocol.
103102
///
104103
/// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string.
105-
pub struct IrohProtocol<R> {
106-
handler: Handler<R>,
104+
pub struct IrohProtocol<S> {
105+
handler: Handler<S>,
107106
request_id: AtomicU64,
108107
}
109108

@@ -113,25 +112,25 @@ impl<T> fmt::Debug for IrohProtocol<T> {
113112
}
114113
}
115114

116-
impl<R: DeserializeOwned + Send + 'static> IrohProtocol<R> {
117-
pub fn with_sender(local_sender: impl Into<LocalSender<R>>) -> Self
115+
impl<S: Service> IrohProtocol<S> {
116+
pub fn with_sender(local_sender: impl Into<LocalSender<S>>) -> Self
118117
where
119-
R: RemoteService,
118+
S: RemoteService,
120119
{
121-
let handler = R::remote_handler(local_sender.into());
120+
let handler = S::remote_handler(local_sender.into());
122121
Self::new(handler)
123122
}
124123

125124
/// Creates a new [`IrohProtocol`] for the `handler`.
126-
pub fn new(handler: Handler<R>) -> Self {
125+
pub fn new(handler: Handler<S>) -> Self {
127126
Self {
128127
handler,
129128
request_id: Default::default(),
130129
}
131130
}
132131
}
133132

134-
impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for IrohProtocol<R> {
133+
impl<S: Service> ProtocolHandler for IrohProtocol<S> {
135134
fn accept(
136135
&self,
137136
connection: Connection,
@@ -140,7 +139,7 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for IrohProtocol<R> {
140139
let request_id = self
141140
.request_id
142141
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
143-
let fut = handle_connection(connection, handler).map_err(AcceptError::from_err);
142+
let fut = handle_connection::<S>(connection, handler).map_err(AcceptError::from_err);
144143
let span = trace_span!("rpc", id = request_id);
145144
Box::pin(fut.instrument(span))
146145
}
@@ -151,8 +150,8 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for IrohProtocol<R> {
151150
/// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string.
152151
///
153152
/// For details about when it is safe to use 0rtt, see https://www.iroh.computer/blog/0rtt-api
154-
pub struct Iroh0RttProtocol<R> {
155-
handler: Handler<R>,
153+
pub struct Iroh0RttProtocol<S> {
154+
handler: Handler<S>,
156155
request_id: AtomicU64,
157156
}
158157

@@ -162,25 +161,25 @@ impl<T> fmt::Debug for Iroh0RttProtocol<T> {
162161
}
163162
}
164163

165-
impl<R: DeserializeOwned + Send + 'static> Iroh0RttProtocol<R> {
166-
pub fn with_sender(local_sender: impl Into<LocalSender<R>>) -> Self
164+
impl<S: Service> Iroh0RttProtocol<S> {
165+
pub fn with_sender(local_sender: impl Into<LocalSender<S>>) -> Self
167166
where
168-
R: RemoteService,
167+
S: RemoteService,
169168
{
170-
let handler = R::remote_handler(local_sender.into());
169+
let handler = S::remote_handler(local_sender.into());
171170
Self::new(handler)
172171
}
173172

174173
/// Creates a new [`Iroh0RttProtocol`] for the `handler`.
175-
pub fn new(handler: Handler<R>) -> Self {
174+
pub fn new(handler: Handler<S>) -> Self {
176175
Self {
177176
handler,
178177
request_id: Default::default(),
179178
}
180179
}
181180
}
182181

183-
impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for Iroh0RttProtocol<R> {
182+
impl<S: Service> ProtocolHandler for Iroh0RttProtocol<S> {
184183
async fn on_connecting(&self, connecting: Connecting) -> Result<Connection, AcceptError> {
185184
let (conn, _zero_rtt_accepted) = connecting
186185
.into_0rtt()
@@ -196,29 +195,34 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for Iroh0RttProtocol<
196195
let request_id = self
197196
.request_id
198197
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
199-
let fut = handle_connection(connection, handler).map_err(AcceptError::from_err);
198+
let fut = handle_connection::<S>(connection, handler).map_err(AcceptError::from_err);
200199
let span = trace_span!("rpc", id = request_id);
201200
Box::pin(fut.instrument(span))
202201
}
203202
}
204203

205204
/// Handles a single iroh connection with the provided `handler`.
206-
pub async fn handle_connection<R: DeserializeOwned + 'static>(
205+
///
206+
/// The wire format used depends on `S::SPAN_PROPAGATION` - if true, span context is expected.
207+
pub async fn handle_connection<S: Service>(
207208
connection: Connection,
208-
handler: Handler<R>,
209+
handler: Handler<S>,
209210
) -> io::Result<()> {
210211
if let Ok(remote) = connection.remote_id() {
211212
tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short()));
212213
}
213214
debug!("connection accepted");
214215
loop {
215-
let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
216+
let Some((msg, rx, tx)) = read_request_raw::<S>(&connection).await? else {
216217
return Ok(());
217218
};
218219
handler(msg, rx, tx).await?;
219220
}
220221
}
221222

223+
/// Reads a request from a connection and converts it to a message enum.
224+
///
225+
/// This combines `read_request_raw` with `RemoteService::with_remote_channels`.
222226
pub async fn read_request<S: RemoteService>(
223227
connection: &Connection,
224228
) -> std::io::Result<Option<S::Message>> {
@@ -231,12 +235,16 @@ pub async fn read_request<S: RemoteService>(
231235
///
232236
/// This accepts a bi-directional stream from the connection and reads and parses the request.
233237
///
238+
/// The wire format used depends on `S::SPAN_PROPAGATION`:
239+
/// - When `true`: expects `(Option<SpanContextCarrier>, Message)` tuple format
240+
/// - When `false`: expects plain `Message` format
241+
///
234242
/// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
235243
/// Returns None if the remote closed the connection with error code `0`.
236244
/// Returns an error for all other failure cases.
237-
pub async fn read_request_raw<R: DeserializeOwned + 'static>(
245+
pub async fn read_request_raw<S: Service>(
238246
connection: &Connection,
239-
) -> std::io::Result<Option<(R, RecvStream, SendStream)>> {
247+
) -> std::io::Result<Option<(S, RecvStream, SendStream)>> {
240248
let (send, mut recv) = match connection.accept_bi().await {
241249
Ok((s, r)) => (s, r),
242250
Err(ConnectionError::ApplicationClosed(cause)) if cause.error_code.into_inner() == 0 => {
@@ -264,24 +272,31 @@ pub async fn read_request_raw<R: DeserializeOwned + 'static>(
264272
.await
265273
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
266274

267-
// Deserialize the payload which includes optional span context
268-
// irpc-iroh uses irpc with default features, which include spans and rpc,
269-
// so span_propagation module always exists
270-
let (span_ctx, msg): (Option<irpc::span_propagation::SpanContextCarrier>, R) =
271-
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
275+
// Deserialize based on S::SPAN_PROPAGATION
276+
let msg: S = if S::SPAN_PROPAGATION {
277+
let (span_ctx, msg): (Option<irpc::span_propagation::SpanContextCarrier>, S) =
278+
postcard::from_bytes(&buf)
279+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
272280

273-
// Store span context in thread-local for use by with_remote_channels
274-
if let Some(ctx) = span_ctx {
275-
ctx.store_in_thread_local();
276-
}
281+
// Store span context in thread-local for use by with_remote_channels
282+
if let Some(ctx) = span_ctx {
283+
ctx.store_in_thread_local();
284+
}
285+
286+
msg
287+
} else {
288+
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
289+
};
277290

278291
let rx = recv;
279292
let tx = send;
280293
Ok(Some((msg, rx, tx)))
281294
}
282295

283-
/// Utility function to listen for incoming connections and handle them with the provided handler
284-
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
296+
/// Utility function to listen for incoming connections and handle them with the provided handler.
297+
///
298+
/// The wire format used depends on `S::SPAN_PROPAGATION` - if true, span context is expected.
299+
pub async fn listen<S: Service>(endpoint: iroh::Endpoint, handler: Handler<S>) {
285300
let mut request_id = 0u64;
286301
let mut tasks = n0_future::task::JoinSet::new();
287302
loop {
@@ -300,7 +315,7 @@ pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, han
300315
let handler = handler.clone();
301316
let fut = async move {
302317
match incoming.await {
303-
Ok(connection) => match handle_connection(connection, handler).await {
318+
Ok(connection) => match handle_connection::<S>(connection, handler).await {
304319
Err(err) => warn!("connection closed with error: {err:?}"),
305320
Ok(()) => debug!("connection closed"),
306321
},

0 commit comments

Comments
 (0)