@@ -43,6 +43,10 @@ pub fn init(ruby: &Ruby) -> Result<(), Error> {
4343 inner_class. define_method ( "code" , method ! ( RpcFailure :: code, 0 ) ) ?;
4444 inner_class. define_method ( "message" , method ! ( RpcFailure :: message, 0 ) ) ?;
4545 inner_class. define_method ( "details" , method ! ( RpcFailure :: details, 0 ) ) ?;
46+
47+ let inner_class = class. define_class ( "CancellationToken" , class:: object ( ) ) ?;
48+ inner_class. define_singleton_method ( "new" , function ! ( CancellationToken :: new, 0 ) ) ?;
49+ inner_class. define_method ( "cancel" , method ! ( CancellationToken :: cancel, 0 ) ) ?;
4650 Ok ( ( ) )
4751}
4852
@@ -58,16 +62,17 @@ pub struct Client {
5862#[ macro_export]
5963macro_rules! rpc_call {
6064 ( $client: ident, $callback: ident, $call: ident, $trait: tt, $call_name: ident) => { {
65+ let cancel_token = $call. cancel_token. clone( ) ;
6166 if $call. retry {
6267 let mut core_client = $client. core. clone( ) ;
6368 let req = $call. into_request( ) ?;
64- $crate:: client:: rpc_resp( $client, $callback, async move {
69+ $crate:: client:: rpc_resp( $client, $callback, cancel_token , async move {
6570 $trait:: $call_name( & mut core_client, req) . await
6671 } )
6772 } else {
6873 let mut core_client = $client. core. clone( ) . into_inner( ) ;
6974 let req = $call. into_request( ) ?;
70- $crate:: client:: rpc_resp( $client, $callback, async move {
75+ $crate:: client:: rpc_resp( $client, $callback, cancel_token , async move {
7176 $trait:: $call_name( & mut core_client, req) . await
7277 } )
7378 }
@@ -176,39 +181,43 @@ impl Client {
176181
177182 pub fn async_invoke_rpc ( & self , args : & [ Value ] ) -> Result < ( ) , Error > {
178183 let args = scan_args:: scan_args :: < ( ) , ( ) , ( ) , ( ) , _ , ( ) > ( args) ?;
179- let ( service, rpc, request, retry, metadata, timeout, queue) = scan_args:: get_kwargs :: <
180- _ ,
181- (
182- u8 ,
183- String ,
184- RString ,
185- bool ,
186- Option < HashMap < String , String > > ,
187- Option < f64 > ,
188- Value ,
189- ) ,
190- ( ) ,
191- ( ) ,
192- > (
193- args. keywords ,
194- & [
195- id ! ( "service" ) ,
196- id ! ( "rpc" ) ,
197- id ! ( "request" ) ,
198- id ! ( "rpc_retry" ) ,
199- id ! ( "rpc_metadata" ) ,
200- id ! ( "rpc_timeout" ) ,
201- id ! ( "queue" ) ,
202- ] ,
203- & [ ] ,
204- ) ?
205- . required ;
184+ let ( service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
185+ scan_args:: get_kwargs :: <
186+ _ ,
187+ (
188+ u8 ,
189+ String ,
190+ RString ,
191+ bool ,
192+ Option < HashMap < String , String > > ,
193+ Option < f64 > ,
194+ Option < & CancellationToken > ,
195+ Value ,
196+ ) ,
197+ ( ) ,
198+ ( ) ,
199+ > (
200+ args. keywords ,
201+ & [
202+ id ! ( "service" ) ,
203+ id ! ( "rpc" ) ,
204+ id ! ( "request" ) ,
205+ id ! ( "rpc_retry" ) ,
206+ id ! ( "rpc_metadata" ) ,
207+ id ! ( "rpc_timeout" ) ,
208+ id ! ( "rpc_cancellation_token" ) ,
209+ id ! ( "queue" ) ,
210+ ] ,
211+ & [ ] ,
212+ ) ?
213+ . required ;
206214 let call = RpcCall {
207215 rpc,
208216 request : unsafe { request. as_slice ( ) } ,
209217 retry,
210218 metadata,
211219 timeout,
220+ cancel_token : cancel_token. map ( |c| c. token . clone ( ) ) ,
212221 _not_send_sync : PhantomData ,
213222 } ;
214223 let callback = AsyncCallback :: from_queue ( queue) ;
@@ -249,6 +258,7 @@ pub(crate) struct RpcCall<'a> {
249258 pub retry : bool ,
250259 pub metadata : Option < HashMap < String , String > > ,
251260 pub timeout : Option < f64 > ,
261+ pub cancel_token : Option < tokio_util:: sync:: CancellationToken > ,
252262
253263 // This RPC call contains an unsafe reference to Ruby bytes that does not
254264 // outlive the call, so we prevent it from being sent to another thread.
@@ -280,14 +290,25 @@ impl RpcCall<'_> {
280290pub ( crate ) fn rpc_resp < P > (
281291 client : & Client ,
282292 callback : AsyncCallback ,
293+ cancel_token : Option < tokio_util:: sync:: CancellationToken > ,
283294 fut : impl Future < Output = Result < tonic:: Response < P > , tonic:: Status > > + Send + ' static ,
284295) -> Result < ( ) , Error >
285296where
286297 P : prost:: Message ,
287298 P : Default ,
288299{
289300 client. runtime_handle . spawn (
290- async move { fut. await . map ( |msg| msg. get_ref ( ) . encode_to_vec ( ) ) } ,
301+ async move {
302+ let res = if let Some ( cancel_token) = cancel_token {
303+ tokio:: select! {
304+ _ = cancel_token. cancelled( ) => Err ( tonic:: Status :: new( tonic:: Code :: Cancelled , "<__user_canceled__>" ) ) ,
305+ v = fut => v,
306+ }
307+ } else {
308+ fut. await
309+ } ;
310+ res. map ( |msg| msg. get_ref ( ) . encode_to_vec ( ) )
311+ } ,
291312 move |_, result| {
292313 match result {
293314 // TODO(cretz): Any reasonable way to prevent byte copy that is just going to get decoded into proto
@@ -299,3 +320,25 @@ where
299320 ) ;
300321 Ok ( ( ) )
301322}
323+
324+ #[ derive( DataTypeFunctions , TypedData ) ]
325+ #[ magnus(
326+ class = "Temporalio::Internal::Bridge::Client::CancellationToken" ,
327+ free_immediately
328+ ) ]
329+ pub struct CancellationToken {
330+ pub ( crate ) token : tokio_util:: sync:: CancellationToken ,
331+ }
332+
333+ impl CancellationToken {
334+ pub fn new ( ) -> Result < Self , Error > {
335+ Ok ( Self {
336+ token : tokio_util:: sync:: CancellationToken :: new ( ) ,
337+ } )
338+ }
339+
340+ pub fn cancel ( & self ) -> Result < ( ) , Error > {
341+ self . token . cancel ( ) ;
342+ Ok ( ( ) )
343+ }
344+ }
0 commit comments