Skip to content

Commit 016ce91

Browse files
committed
Bug Fix: Defensively copy context entities
Before this change, concurrent async tasks would all share the same instance of the entities list. This change makes it so they each get their own copy of the list. This matters because the recorder modifies the list in place, which makes it so concurrent subtasks have the wrong parent subsegment.
1 parent 816ab26 commit 016ce91

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

aws_xray_sdk/core/async_context.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import sys
3+
import copy
34

45
from .context import Context as _Context
56

@@ -108,6 +109,13 @@ def task_factory(loop, coro):
108109
else:
109110
current_task = asyncio.Task.current_task(loop=loop)
110111
if current_task is not None and hasattr(current_task, 'context'):
111-
setattr(task, 'context', current_task.context)
112+
if current_task.context.get('entities'):
113+
# Defensively copying because recorder modifies the list in place.
114+
new_context = copy.copy(current_task.context)
115+
new_context['entities'] = [item for item in current_task.context['entities']]
116+
else:
117+
# no reason to copy if there's no entities list.
118+
new_context = current_task.context
119+
setattr(task, 'context', new_context)
112120

113121
return task

tests/test_async_recorder.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .util import get_new_stubbed_recorder
44
from aws_xray_sdk.version import VERSION
55
from aws_xray_sdk.core.async_context import AsyncContext
6+
import asyncio
67

78

89
xray_recorder = get_new_stubbed_recorder()
@@ -43,6 +44,28 @@ async def test_capture(loop):
4344
assert platform.python_implementation() == service.get('runtime')
4445
assert platform.python_version() == service.get('runtime_version')
4546

47+
async def test_concurrent_calls(loop):
48+
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
49+
async with xray_recorder.in_segment_async('segment') as segment:
50+
global counter
51+
counter = 0
52+
total_tasks = 10
53+
event = asyncio.Event()
54+
async def assert_task():
55+
async with xray_recorder.in_subsegment_async('segment') as subsegment:
56+
global counter
57+
counter += 1
58+
# Ensure that the task subsegments overlap
59+
if counter < total_tasks:
60+
await event.wait()
61+
else:
62+
event.set()
63+
return subsegment.parent_id
64+
tasks = [assert_task() for task in range(total_tasks)]
65+
results = await asyncio.gather(*tasks)
66+
for result in results:
67+
assert result == segment.id
68+
4669

4770
async def test_async_context_managers(loop):
4871
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))

0 commit comments

Comments
 (0)