Skip to content

Commit 316a528

Browse files
authored
Handle case where in_workflow is called in a synchronous activity (#884)
* Handle case where in_workflow is called in a synchronous activity * Formatting * Move the catch down into * Add value validation * Format after merge
1 parent 785aca6 commit 316a528

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

temporalio/workflow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,12 @@ def current() -> _Runtime:
642642

643643
@staticmethod
644644
def maybe_current() -> Optional[_Runtime]:
645-
return getattr(asyncio.get_running_loop(), "__temporal_workflow_runtime", None)
645+
try:
646+
return getattr(
647+
asyncio.get_running_loop(), "__temporal_workflow_runtime", None
648+
)
649+
except RuntimeError:
650+
return None
646651

647652
@staticmethod
648653
def set_on_loop(

tests/worker/test_workflow.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7975,6 +7975,37 @@ async def test_quick_activity_swallows_cancellation(client: Client):
79757975
temporalio.worker._workflow_instance._raise_on_cancelling_completed_activity_override = False
79767976

79777977

7978+
@activity.defn
7979+
def use_in_workflow() -> bool:
7980+
return workflow.in_workflow()
7981+
7982+
7983+
@workflow.defn
7984+
class UseInWorkflow:
7985+
@workflow.run
7986+
async def run(self):
7987+
res = await workflow.execute_activity(
7988+
use_in_workflow, schedule_to_close_timeout=timedelta(seconds=10)
7989+
)
7990+
return res
7991+
7992+
7993+
async def test_in_workflow_sync(client: Client):
7994+
async with new_worker(
7995+
client,
7996+
UseInWorkflow,
7997+
activities=[use_in_workflow],
7998+
activity_executor=concurrent.futures.ThreadPoolExecutor(max_workers=1),
7999+
) as worker:
8000+
res = await client.execute_workflow(
8001+
UseInWorkflow.run,
8002+
id=f"test_in_workflow_sync",
8003+
task_queue=worker.task_queue,
8004+
execution_timeout=timedelta(minutes=1),
8005+
)
8006+
assert not res
8007+
8008+
79788009
class SignalInterceptor(temporalio.worker.Interceptor):
79798010
def workflow_interceptor_class(
79808011
self, input: temporalio.worker.WorkflowInterceptorClassInput

0 commit comments

Comments
 (0)