Skip to content

Commit 6c39940

Browse files
desertaxle and claude authored
Add with_context() for logging from subprocesses (#21304)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b1dea50 commit 6c39940

File tree

4 files changed

+265
-1
lines changed

4 files changed

+265
-1
lines changed

.github/workflows/python-tests.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ jobs:
131131
tests/test_schedules.py
132132
tests/test_serializers.py
133133
tests/test_states.py
134+
tests/test_subprocess_logging.py
134135
tests/test_task_engine.py
135136
tests/test_task_runs.py
136137
tests/test_task_worker.py

docs/v3/how-to-guides/workflows/add-logging.mdx

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
title: How to add logging to a workflow
33
sidebarTitle: Add logging
4-
keywords: ["logging", "log", "print", "debug", "log_prints", "get_run_logger"]
4+
keywords: ["logging", "log", "print", "debug", "log_prints", "get_run_logger", "subprocess", "multiprocessing", "with_context"]
55
---
66

77
### Emit custom logs
@@ -88,6 +88,43 @@ You can configure the default `log_prints` setting for all Prefect flow and task
8888
prefect config set PREFECT_LOGGING_LOG_PRINTS=True
8989
```
9090

91+
## Log from subprocesses
92+
93+
When you spawn subprocesses inside a flow or task — for example, with `multiprocessing.Pool`
94+
or `concurrent.futures.ProcessPoolExecutor` — the Prefect run context is not automatically
95+
available in the child process. This means `get_run_logger()` raises a `MissingContextError`.
96+
97+
Use `with_context` from `prefect.context` to propagate the current run context into
98+
subprocess workers. Logs emitted with `get_run_logger()` in the child process are
99+
associated with the parent flow run and task run and appear in the Prefect UI.
100+
101+
{/* pmd-metadata: notest */}
102+
```python
103+
import multiprocessing
104+
from prefect import flow, task
105+
from prefect.context import with_context
106+
from prefect.logging import get_run_logger
107+
108+
109+
def process_item(item):
110+
logger = get_run_logger()
111+
logger.info(f"Processing {item}")
112+
return item * 2
113+
114+
115+
@task
116+
def parallel_task(items):
117+
with multiprocessing.Pool() as pool:
118+
return pool.map(with_context(process_item), items)
119+
120+
121+
@flow
122+
def my_flow():
123+
results = parallel_task([1, 2, 3, 4])
124+
```
125+
126+
`with_context` also works with `concurrent.futures.ProcessPoolExecutor` and `multiprocessing.Process`.
127+
91128
## Access logs from the command line
92129

93130
You can retrieve logs for a specific flow run ID using Prefect's CLI:

src/prefect/context.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,53 @@
6060
from prefect.tasks import Task
6161

6262

63+
class _ContextWrappedCallable:
    """A picklable callable that restores Prefect context before invoking
    the wrapped function.

    The serialized context is stored as cloudpickle bytes so that the
    standard pickle protocol (which `multiprocessing` uses to ship the
    callable to a child process) can transport it.
    """

    def __init__(
        self, fn: Callable[..., Any], serialized_context: dict[str, Any]
    ) -> None:
        # Imported lazily to keep cloudpickle off the module import path.
        import cloudpickle

        self.fn = fn
        self._ctx_bytes = cloudpickle.dumps(serialized_context)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        import cloudpickle

        context = cloudpickle.loads(self._ctx_bytes)
        # Re-enter the captured Prefect context so context-dependent APIs
        # (e.g. get_run_logger) work inside the child process.
        with hydrated_context(context):
            return self.fn(*args, **kwargs)
83+
84+
def with_context(fn: Callable[..., Any]) -> _ContextWrappedCallable:
    """Wrap `fn` so that, when called in a subprocess, it runs with the
    Prefect context that is active at wrapping time.

    This enables `get_run_logger()` and other context-dependent APIs in
    functions executed via `multiprocessing.Pool`, `ProcessPoolExecutor`,
    or `multiprocessing.Process`.

    Example:
        ```python
        from prefect.context import with_context

        def worker(item):
            logger = get_run_logger()
            logger.info(f"Processing {item}")

        @task
        def my_task():
            with multiprocessing.Pool() as pool:
                pool.map(with_context(worker), items)
        ```
    """
    # Capture the caller's current context eagerly; the wrapper rehydrates
    # it inside the child process on each call.
    return _ContextWrappedCallable(fn, serialize_context())
108+
109+
63110
def serialize_context(
64111
asset_ctx_kwargs: Union[dict[str, Any], None] = None,
65112
) -> dict[str, Any]:

tests/test_subprocess_logging.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""Tests for `with_context` — propagating Prefect context to subprocesses."""
2+
3+
import multiprocessing
4+
from concurrent.futures import ProcessPoolExecutor
5+
6+
import pytest
7+
8+
from prefect import flow, task
9+
from prefect.context import with_context
10+
from prefect.exceptions import MissingContextError
11+
from prefect.logging import get_run_logger
12+
13+
# ---------------------------------------------------------------------------
14+
# Helper functions executed in child processes
15+
# ---------------------------------------------------------------------------
16+
17+
18+
def _log_and_return_ids(item: int) -> dict:
    """Worker that logs via get_run_logger() and reports the run IDs it sees."""
    from prefect.context import FlowRunContext, TaskRunContext

    logger = get_run_logger()
    logger.info(f"Processing {item}")

    flow_ctx = FlowRunContext.get()
    task_ctx = TaskRunContext.get()

    flow_run_id = None
    if flow_ctx and flow_ctx.flow_run:
        flow_run_id = str(flow_ctx.flow_run.id)
    task_run_id = str(task_ctx.task_run.id) if task_ctx else None

    return {"item": item, "flow_run_id": flow_run_id, "task_run_id": task_run_id}
34+
35+
36+
def _just_get_logger(_: object = None) -> str:
    """Worker that obtains the run logger and returns its name."""
    return get_run_logger().name
40+
41+
42+
# ---------------------------------------------------------------------------
43+
# Tests
44+
# ---------------------------------------------------------------------------
45+
46+
47+
class TestWithContextPoolMap:
    """multiprocessing.Pool.map with with_context."""

    def test_get_run_logger_works_in_pool_worker(self):
        @task
        def my_task():
            worker = with_context(_just_get_logger)
            spawn_ctx = multiprocessing.get_context("spawn")
            with spawn_ctx.Pool(1) as pool:
                return pool.map(worker, [1])

        @flow
        def my_flow():
            return my_task()

        outcome = my_flow()
        assert len(outcome) == 1
        assert isinstance(outcome[0], str)

    def test_correct_run_ids_in_pool_worker(self):
        @task
        def my_task():
            from prefect.context import FlowRunContext, TaskRunContext

            # Record the IDs visible to the parent task so the child's view
            # can be compared against them.
            flow_ctx = FlowRunContext.get()
            task_ctx = TaskRunContext.get()
            parent_flow_run_id = (
                str(flow_ctx.flow_run.id) if flow_ctx and flow_ctx.flow_run else None
            )
            parent_task_run_id = str(task_ctx.task_run.id) if task_ctx else None

            worker = with_context(_log_and_return_ids)
            with multiprocessing.get_context("spawn").Pool(1) as pool:
                child_results = pool.map(worker, [1, 2])
            return child_results, parent_flow_run_id, parent_task_run_id

        @flow
        def my_flow():
            return my_task()

        results, parent_flow_id, parent_task_id = my_flow()
        assert len(results) == 2
        assert all(r["flow_run_id"] == parent_flow_id for r in results)
        assert all(r["task_run_id"] == parent_task_id for r in results)
92+
93+
94+
class TestWithContextProcess:
    """multiprocessing.Process with with_context."""

    def test_get_run_logger_works_in_process(self):
        @task
        def my_task():
            ctx = multiprocessing.get_context("spawn")
            q: multiprocessing.Queue = ctx.Queue()  # type: ignore[type-arg]

            wrapped = with_context(_target_with_queue)
            p = ctx.Process(target=wrapped, args=(q,))
            p.start()
            try:
                # Block on the queue rather than join-then-get_nowait: if the
                # child hangs or dies before putting, join(timeout=30) returns
                # silently and get_nowait() raises queue.Empty, masking the
                # real failure and leaving a live process behind.
                assert q.get(timeout=30) == "ok"
            finally:
                p.join(timeout=30)
            # A zero exit code confirms the child finished cleanly after
            # logging through the hydrated context.
            assert p.exitcode == 0

        @flow
        def my_flow():
            return my_task()

        my_flow()
114+
115+
116+
def _target_with_queue(queue: multiprocessing.Queue) -> None:  # type: ignore[type-arg]
    """Process target: log via get_run_logger, then signal success on the queue."""
    get_run_logger().info("hello from subprocess")
    queue.put("ok")
121+
122+
123+
class TestWithContextProcessPoolExecutor:
    """concurrent.futures.ProcessPoolExecutor with with_context."""

    def test_get_run_logger_works_in_executor(self):
        @task
        def my_task():
            worker = with_context(_just_get_logger)
            spawn_ctx = multiprocessing.get_context("spawn")
            with ProcessPoolExecutor(max_workers=1, mp_context=spawn_ctx) as executor:
                return executor.submit(worker).result(timeout=30)

        @flow
        def my_flow():
            return my_task()

        logger_name = my_flow()
        assert isinstance(logger_name, str)
141+
142+
143+
class TestWithContextOutsideRun:
    """Calling a with_context-wrapped function without a run context raises."""

    def test_raises_missing_context_error(self):
        """Wrapping succeeds anywhere — with_context serializes whatever
        context currently exists, which may be empty. The failure surfaces
        only when the wrapped function calls get_run_logger() with no flow
        or task run context available."""
        wrapped = with_context(_just_get_logger)
        with pytest.raises(MissingContextError):
            wrapped()
155+
156+
157+
class TestWithContextFlowOnly:
    """with_context inside a flow (no task) propagates flow context."""

    def test_flow_context_propagated(self):
        @flow
        def my_flow():
            from prefect.context import FlowRunContext

            flow_ctx = FlowRunContext.get()
            parent_flow_run_id = (
                str(flow_ctx.flow_run.id) if flow_ctx and flow_ctx.flow_run else None
            )

            worker = with_context(_log_and_return_ids)
            spawn_ctx = multiprocessing.get_context("spawn")
            with ProcessPoolExecutor(max_workers=1, mp_context=spawn_ctx) as executor:
                child_result = executor.submit(worker, 42).result(timeout=30)
            return child_result, parent_flow_run_id

        result, parent_flow_id = my_flow()
        assert result["flow_run_id"] == parent_flow_id
        # No task ran in the parent, so the child must not see a task context.
        assert result["task_run_id"] is None

0 commit comments

Comments
 (0)