Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions temporalio/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions temporalio/Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,15 @@ namespace :proto do
# Calls #{class_name}.#{method.name} API call.
#
# @param request [#{method.input_type.msgclass}] API request.
# @param rpc_retry [Boolean] Whether to implicitly retry known retryable errors.
# @param rpc_metadata [Hash<String, String>, nil] Headers to include on the RPC call.
# @param rpc_timeout [Float, nil] Number of seconds before timeout.
# @param rpc_options [RPCOptions, nil] Advanced RPC options.
# @return [#{method.output_type.msgclass}] API response.
def #{rpc}(request, rpc_retry: false, rpc_metadata: nil, rpc_timeout: nil)
def #{rpc}(request, rpc_options: nil)
invoke_rpc(
rpc: '#{rpc}',
request_class: #{method.input_type.msgclass},
response_class: #{method.output_type.msgclass},
request:,
rpc_retry:,
rpc_metadata:,
rpc_timeout:
rpc_options:
)
end
TEXT
Expand Down Expand Up @@ -236,7 +232,10 @@ namespace :proto do
# Camel case to snake case
rpc = method.name.gsub(/([A-Z])/, '_\1').downcase.delete_prefix('_')
file.puts <<-TEXT
def #{rpc}: (untyped request, ?rpc_retry: bool, ?rpc_metadata: Hash[String, String]?, ?rpc_timeout: Float?) -> untyped
def #{rpc}: (
untyped request,
?rpc_options: RPCOptions?
) -> untyped
TEXT
end

Expand Down
1 change: 1 addition & 0 deletions temporalio/ext/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ temporal-sdk-core = { version = "0.1.0", path = "./sdk-core/core", features = ["
temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
tokio = "1.26"
tokio-util = "0.7"
tonic = "0.12"
tracing = "0.1"
url = "2.2"
103 changes: 73 additions & 30 deletions temporalio/ext/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub fn init(ruby: &Ruby) -> Result<(), Error> {
inner_class.define_method("code", method!(RpcFailure::code, 0))?;
inner_class.define_method("message", method!(RpcFailure::message, 0))?;
inner_class.define_method("details", method!(RpcFailure::details, 0))?;

let inner_class = class.define_class("CancellationToken", class::object())?;
inner_class.define_singleton_method("new", function!(CancellationToken::new, 0))?;
inner_class.define_method("cancel", method!(CancellationToken::cancel, 0))?;
Ok(())
}

Expand All @@ -58,16 +62,17 @@ pub struct Client {
#[macro_export]
macro_rules! rpc_call {
($client:ident, $callback:ident, $call:ident, $trait:tt, $call_name:ident) => {{
let cancel_token = $call.cancel_token.clone();
if $call.retry {
let mut core_client = $client.core.clone();
let req = $call.into_request()?;
$crate::client::rpc_resp($client, $callback, async move {
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
$trait::$call_name(&mut core_client, req).await
})
} else {
let mut core_client = $client.core.clone().into_inner();
let req = $call.into_request()?;
$crate::client::rpc_resp($client, $callback, async move {
$crate::client::rpc_resp($client, $callback, cancel_token, async move {
$trait::$call_name(&mut core_client, req).await
})
}
Expand Down Expand Up @@ -176,39 +181,43 @@ impl Client {

pub fn async_invoke_rpc(&self, args: &[Value]) -> Result<(), Error> {
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
let (service, rpc, request, retry, metadata, timeout, queue) = scan_args::get_kwargs::<
_,
(
u8,
String,
RString,
bool,
Option<HashMap<String, String>>,
Option<f64>,
Value,
),
(),
(),
>(
args.keywords,
&[
id!("service"),
id!("rpc"),
id!("request"),
id!("rpc_retry"),
id!("rpc_metadata"),
id!("rpc_timeout"),
id!("queue"),
],
&[],
)?
.required;
let (service, rpc, request, retry, metadata, timeout, cancel_token, queue) =
scan_args::get_kwargs::<
_,
(
u8,
String,
RString,
bool,
Option<HashMap<String, String>>,
Option<f64>,
Option<&CancellationToken>,
Value,
),
(),
(),
>(
args.keywords,
&[
id!("service"),
id!("rpc"),
id!("request"),
id!("rpc_retry"),
id!("rpc_metadata"),
id!("rpc_timeout"),
id!("rpc_cancellation_token"),
id!("queue"),
],
&[],
)?
.required;
let call = RpcCall {
rpc,
request: unsafe { request.as_slice() },
retry,
metadata,
timeout,
cancel_token: cancel_token.map(|c| c.token.clone()),
_not_send_sync: PhantomData,
};
let callback = AsyncCallback::from_queue(queue);
Expand Down Expand Up @@ -249,6 +258,7 @@ pub(crate) struct RpcCall<'a> {
pub retry: bool,
pub metadata: Option<HashMap<String, String>>,
pub timeout: Option<f64>,
pub cancel_token: Option<tokio_util::sync::CancellationToken>,

// This RPC call contains an unsafe reference to Ruby bytes that does not
// outlive the call, so we prevent it from being sent to another thread.
Expand Down Expand Up @@ -280,14 +290,25 @@ impl RpcCall<'_> {
pub(crate) fn rpc_resp<P>(
client: &Client,
callback: AsyncCallback,
cancel_token: Option<tokio_util::sync::CancellationToken>,
fut: impl Future<Output = Result<tonic::Response<P>, tonic::Status>> + Send + 'static,
) -> Result<(), Error>
where
P: prost::Message,
P: Default,
{
client.runtime_handle.spawn(
async move { fut.await.map(|msg| msg.get_ref().encode_to_vec()) },
async move {
let res = if let Some(cancel_token) = cancel_token {
tokio::select! {
_ = cancel_token.cancelled() => Err(tonic::Status::new(tonic::Code::Cancelled, "<__user_canceled__>")),
v = fut => v,
}
} else {
fut.await
};
res.map(|msg| msg.get_ref().encode_to_vec())
},
move |_, result| {
match result {
// TODO(cretz): Any reasonable way to prevent byte copy that is just going to get decoded into proto
Expand All @@ -299,3 +320,25 @@ where
);
Ok(())
}

#[derive(DataTypeFunctions, TypedData)]
#[magnus(
class = "Temporalio::Internal::Bridge::Client::CancellationToken",
free_immediately
)]
pub struct CancellationToken {
pub(crate) token: tokio_util::sync::CancellationToken,
}

impl CancellationToken {
pub fn new() -> Result<Self, Error> {
Ok(Self {
token: tokio_util::sync::CancellationToken::new(),
})
}

pub fn cancel(&self) -> Result<(), Error> {
self.token.cancel();
Ok(())
}
}
Loading
Loading