Skip to content

Commit 67741fc

Browse files
authored
RPC cancellation support (#174)
Fixes #168
1 parent cde31d9 commit 67741fc

26 files changed

+1109
-1276
lines changed

temporalio/Cargo.lock

Lines changed: 16 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

temporalio/Rakefile

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,15 @@ namespace :proto do
193193
# Calls #{class_name}.#{method.name} API call.
194194
#
195195
# @param request [#{method.input_type.msgclass}] API request.
196-
# @param rpc_retry [Boolean] Whether to implicitly retry known retryable errors.
197-
# @param rpc_metadata [Hash<String, String>, nil] Headers to include on the RPC call.
198-
# @param rpc_timeout [Float, nil] Number of seconds before timeout.
196+
# @param rpc_options [RPCOptions, nil] Advanced RPC options.
199197
# @return [#{method.output_type.msgclass}] API response.
200-
def #{rpc}(request, rpc_retry: false, rpc_metadata: nil, rpc_timeout: nil)
198+
def #{rpc}(request, rpc_options: nil)
201199
invoke_rpc(
202200
rpc: '#{rpc}',
203201
request_class: #{method.input_type.msgclass},
204202
response_class: #{method.output_type.msgclass},
205203
request:,
206-
rpc_retry:,
207-
rpc_metadata:,
208-
rpc_timeout:
204+
rpc_options:
209205
)
210206
end
211207
TEXT
@@ -236,7 +232,10 @@ namespace :proto do
236232
# Camel case to snake case
237233
rpc = method.name.gsub(/([A-Z])/, '_\1').downcase.delete_prefix('_')
238234
file.puts <<-TEXT
239-
def #{rpc}: (untyped request, ?rpc_retry: bool, ?rpc_metadata: Hash[String, String]?, ?rpc_timeout: Float?) -> untyped
235+
def #{rpc}: (
236+
untyped request,
237+
?rpc_options: RPCOptions?
238+
) -> untyped
240239
TEXT
241240
end
242241

temporalio/ext/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ temporal-sdk-core = { version = "0.1.0", path = "./sdk-core/core", features = ["
2020
temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
2121
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
2222
tokio = "1.26"
23+
tokio-util = "0.7"
2324
tonic = "0.12"
2425
tracing = "0.1"
2526
url = "2.2"

temporalio/ext/src/client.rs

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ pub fn init(ruby: &Ruby) -> Result<(), Error> {
4343
inner_class.define_method("code", method!(RpcFailure::code, 0))?;
4444
inner_class.define_method("message", method!(RpcFailure::message, 0))?;
4545
inner_class.define_method("details", method!(RpcFailure::details, 0))?;
46+
47+
let inner_class = class.define_class("CancellationToken", class::object())?;
48+
inner_class.define_singleton_method("new", function!(CancellationToken::new, 0))?;
49+
inner_class.define_method("cancel", method!(CancellationToken::cancel, 0))?;
4650
Ok(())
4751
}
4852

@@ -58,16 +62,17 @@ pub struct Client {
5862
#[macro_export]
5963
macro_rules! rpc_call {
6064
($client:ident, $callback:ident, $call:ident, $trait:tt, $call_name:ident) => {{
65+
let cancel_token = $call.cancel_token.clone();
6166
if $call.retry {
6267
let mut core_client = $client.core.clone();
6368
let req = $call.into_request()?;
64-
$crate::client::rpc_resp($client, $callback, async move {
69+
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
6570
$trait::$call_name(&mut core_client, req).await
6671
})
6772
} else {
6873
let mut core_client = $client.core.clone().into_inner();
6974
let req = $call.into_request()?;
70-
$crate::client::rpc_resp($client, $callback, async move {
75+
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
7176
$trait::$call_name(&mut core_client, req).await
7277
})
7378
}
@@ -176,39 +181,43 @@ impl Client {
176181

177182
pub fn async_invoke_rpc(&self, args: &[Value]) -> Result<(), Error> {
178183
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
179-
let (service, rpc, request, retry, metadata, timeout, queue) = scan_args::get_kwargs::<
180-
_,
181-
(
182-
u8,
183-
String,
184-
RString,
185-
bool,
186-
Option<HashMap<String, String>>,
187-
Option<f64>,
188-
Value,
189-
),
190-
(),
191-
(),
192-
>(
193-
args.keywords,
194-
&[
195-
id!("service"),
196-
id!("rpc"),
197-
id!("request"),
198-
id!("rpc_retry"),
199-
id!("rpc_metadata"),
200-
id!("rpc_timeout"),
201-
id!("queue"),
202-
],
203-
&[],
204-
)?
205-
.required;
184+
let (service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
185+
scan_args::get_kwargs::<
186+
_,
187+
(
188+
u8,
189+
String,
190+
RString,
191+
bool,
192+
Option<HashMap<String, String>>,
193+
Option<f64>,
194+
Option<&CancellationToken>,
195+
Value,
196+
),
197+
(),
198+
(),
199+
>(
200+
args.keywords,
201+
&[
202+
id!("service"),
203+
id!("rpc"),
204+
id!("request"),
205+
id!("rpc_retry"),
206+
id!("rpc_metadata"),
207+
id!("rpc_timeout"),
208+
id!("rpc_cancellation_token"),
209+
id!("queue"),
210+
],
211+
&[],
212+
)?
213+
.required;
206214
let call = RpcCall {
207215
rpc,
208216
request: unsafe { request.as_slice() },
209217
retry,
210218
metadata,
211219
timeout,
220+
cancel_token: cancel_token.map(|c| c.token.clone()),
212221
_not_send_sync: PhantomData,
213222
};
214223
let callback = AsyncCallback::from_queue(queue);
@@ -249,6 +258,7 @@ pub(crate) struct RpcCall<'a> {
249258
pub retry: bool,
250259
pub metadata: Option<HashMap<String, String>>,
251260
pub timeout: Option<f64>,
261+
pub cancel_token: Option<tokio_util::sync::CancellationToken>,
252262

253263
// This RPC call contains an unsafe reference to Ruby bytes that does not
254264
// outlive the call, so we prevent it from being sent to another thread.
@@ -280,14 +290,25 @@ impl RpcCall<'_> {
280290
pub(crate) fn rpc_resp<P>(
281291
client: &Client,
282292
callback: AsyncCallback,
293+
cancel_token: Option<tokio_util::sync::CancellationToken>,
283294
fut: impl Future<Output = Result<tonic::Response<P>, tonic::Status>> + Send + 'static,
284295
) -> Result<(), Error>
285296
where
286297
P: prost::Message,
287298
P: Default,
288299
{
289300
client.runtime_handle.spawn(
290-
async move { fut.await.map(|msg| msg.get_ref().encode_to_vec()) },
301+
async move {
302+
let res = if let Some(cancel_token) = cancel_token {
303+
tokio::select! {
304+
_ = cancel_token.cancelled() => Err(tonic::Status::new(tonic::Code::Cancelled, "<__user_canceled__>")),
305+
v = fut => v,
306+
}
307+
} else {
308+
fut.await
309+
};
310+
res.map(|msg| msg.get_ref().encode_to_vec())
311+
},
291312
move |_, result| {
292313
match result {
293314
// TODO(cretz): Any reasonable way to prevent byte copy that is just going to get decoded into proto
@@ -299,3 +320,25 @@ where
299320
);
300321
Ok(())
301322
}
323+
324+
#[derive(DataTypeFunctions, TypedData)]
325+
#[magnus(
326+
class = "Temporalio::Internal::Bridge::Client::CancellationToken",
327+
free_immediately
328+
)]
329+
pub struct CancellationToken {
330+
pub(crate) token: tokio_util::sync::CancellationToken,
331+
}
332+
333+
impl CancellationToken {
334+
pub fn new() -> Result<Self, Error> {
335+
Ok(Self {
336+
token: tokio_util::sync::CancellationToken::new(),
337+
})
338+
}
339+
340+
pub fn cancel(&self) -> Result<(), Error> {
341+
self.token.cancel();
342+
Ok(())
343+
}
344+
}

0 commit comments

Comments
 (0)