Skip to content

Interrupt heartbeating activity on pause #854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
overload,
)

import temporalio.bridge
import temporalio.bridge.proto
import temporalio.bridge.proto.activity_task
import temporalio.common
import temporalio.converter

Expand Down Expand Up @@ -135,6 +138,34 @@ def _logger_details(self) -> Mapping[str, Any]:
_current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity")


@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be frozen

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We mutate the fields in this object to reflect changes across running activity & _context

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to frozen class + holder

class _ActivityCancellationDetailsHolder:
details: Optional[ActivityCancellationDetails] = None


@dataclass(frozen=True)
class ActivityCancellationDetails:
"""Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set."""

not_found: bool = False
cancel_requested: bool = False
paused: bool = False
timed_out: bool = False
worker_shutdown: bool = False

@staticmethod
def _from_proto(
proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails,
) -> ActivityCancellationDetails:
return ActivityCancellationDetails(
not_found=proto.is_not_found,
cancel_requested=proto.is_cancelled,
paused=proto.is_paused,
timed_out=proto.is_timed_out,
worker_shutdown=proto.is_worker_shutdown,
)


@dataclass
class _Context:
info: Callable[[], Info]
Expand All @@ -148,6 +179,7 @@ class _Context:
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
cancellation_details: _ActivityCancellationDetailsHolder
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
_metric_meter: Optional[temporalio.common.MetricMeter] = None
Expand Down Expand Up @@ -260,6 +292,11 @@ def info() -> Info:
return _Context.current().info()


def cancellation_details() -> Optional[ActivityCancellationDetails]:
"""Cancellation details of the current activity, if any. Once set, cancellation details do not change."""
return _Context.current().cancellation_details.details


def heartbeat(*details: Any) -> None:
"""Send a heartbeat for the current activity.

Expand Down
9 changes: 8 additions & 1 deletion temporalio/bridge/proto/activity_task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from .activity_task_pb2 import ActivityCancelReason, ActivityTask, Cancel, Start
from .activity_task_pb2 import (
ActivityCancellationDetails,
ActivityCancelReason,
ActivityTask,
Cancel,
Start,
)

__all__ = [
"ActivityCancelReason",
"ActivityCancellationDetails",
"ActivityTask",
"Cancel",
"Start",
Expand Down
27 changes: 22 additions & 5 deletions temporalio/bridge/proto/activity_task/activity_task_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 57 additions & 1 deletion temporalio/bridge/proto/activity_task/activity_task_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class _ActivityCancelReasonEnumTypeWrapper(
"""Activity timed out"""
WORKER_SHUTDOWN: _ActivityCancelReason.ValueType # 3
"""Core is shutting down and the graceful timeout has elapsed"""
PAUSED: _ActivityCancelReason.ValueType # 4
"""Activity was paused"""

class ActivityCancelReason(
_ActivityCancelReason, metaclass=_ActivityCancelReasonEnumTypeWrapper
Expand All @@ -58,6 +60,8 @@ TIMED_OUT: ActivityCancelReason.ValueType # 2
"""Activity timed out"""
WORKER_SHUTDOWN: ActivityCancelReason.ValueType # 3
"""Core is shutting down and the graceful timeout has elapsed"""
PAUSED: ActivityCancelReason.ValueType # 4
"""Activity was paused"""
global___ActivityCancelReason = ActivityCancelReason

class ActivityTask(google.protobuf.message.Message):
Expand Down Expand Up @@ -320,14 +324,66 @@ class Cancel(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

REASON_FIELD_NUMBER: builtins.int
DETAILS_FIELD_NUMBER: builtins.int
reason: global___ActivityCancelReason.ValueType
"""Primary cancellation reason"""
@property
def details(self) -> global___ActivityCancellationDetails:
"""Activity cancellation details, surfaces all cancellation reasons."""
def __init__(
self,
*,
reason: global___ActivityCancelReason.ValueType = ...,
details: global___ActivityCancellationDetails | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["details", b"details"]
) -> builtins.bool: ...
def ClearField(
self, field_name: typing_extensions.Literal["reason", b"reason"]
self,
field_name: typing_extensions.Literal[
"details", b"details", "reason", b"reason"
],
) -> None: ...

global___Cancel = Cancel

class ActivityCancellationDetails(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

IS_NOT_FOUND_FIELD_NUMBER: builtins.int
IS_CANCELLED_FIELD_NUMBER: builtins.int
IS_PAUSED_FIELD_NUMBER: builtins.int
IS_TIMED_OUT_FIELD_NUMBER: builtins.int
IS_WORKER_SHUTDOWN_FIELD_NUMBER: builtins.int
is_not_found: builtins.bool
is_cancelled: builtins.bool
is_paused: builtins.bool
is_timed_out: builtins.bool
is_worker_shutdown: builtins.bool
def __init__(
self,
*,
is_not_found: builtins.bool = ...,
is_cancelled: builtins.bool = ...,
is_paused: builtins.bool = ...,
is_timed_out: builtins.bool = ...,
is_worker_shutdown: builtins.bool = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"is_cancelled",
b"is_cancelled",
"is_not_found",
b"is_not_found",
"is_paused",
b"is_paused",
"is_timed_out",
b"is_timed_out",
"is_worker_shutdown",
b"is_worker_shutdown",
],
) -> None: ...

global___ActivityCancellationDetails = ActivityCancellationDetails
6 changes: 6 additions & 0 deletions temporalio/bridge/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ impl ClientRef {
"patch_schedule" => {
rpc_call!(retry_client, call, patch_schedule)
}
"pause_activity" => {
rpc_call!(retry_client, call, pause_activity)
}
"poll_activity_task_queue" => {
rpc_call!(retry_client, call, poll_activity_task_queue)
}
Expand Down Expand Up @@ -325,6 +328,9 @@ impl ClientRef {
"trigger_workflow_rule" => {
rpc_call!(retry_client, call, trigger_workflow_rule)
}
"unpause_activity" => {
rpc_call!(retry_client, call, unpause_activity)
}
"update_namespace" => {
rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace)
}
Expand Down
5 changes: 3 additions & 2 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6264,8 +6264,9 @@ async def heartbeat_async_activity(
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if resp_by_id.cancel_requested:
if resp_by_id.cancel_requested or resp_by_id.activity_paused:
raise AsyncActivityCancelledError()

else:
resp = await self._client.workflow_service.record_activity_task_heartbeat(
temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest(
Expand All @@ -6278,7 +6279,7 @@ async def heartbeat_async_activity(
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if resp.cancel_requested:
if resp.cancel_requested or resp.activity_paused:
raise AsyncActivityCancelledError()

async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
Expand Down
17 changes: 16 additions & 1 deletion temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,29 @@ def __init__(self) -> None:
self._cancelled = False
self._worker_shutdown = False
self._activities: Set[_Activity] = set()
self._cancellation_details = (
temporalio.activity._ActivityCancellationDetailsHolder()
)

def cancel(self) -> None:
def cancel(
self,
cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails(
cancel_requested=True
),
) -> None:
"""Cancel the activity.

Args:
cancellation_details: details about the cancellation. These will
be accessible through temporalio.activity.cancellation_details()
in the activity after cancellation.

This only has an effect on the first call.
"""
if self._cancelled:
return
self._cancelled = True
self._cancellation_details.details = cancellation_details
for act in self._activities:
act.cancel()

Expand Down Expand Up @@ -154,6 +168,7 @@ def __init__(
else self.cancel_thread_raiser.shielded,
payload_converter_class_or_instance=env.payload_converter,
runtime_metric_meter=env.metric_meter,
cancellation_details=env._cancellation_details,
)
self.task: Optional[asyncio.Task] = None

Expand Down
Loading
Loading