File tree 2 files changed +32
-1
lines changed
2 files changed +32
-1
lines changed Original file line number Diff line number Diff line change 1
1
import asyncio
2
2
import sys
3
+ import copy
3
4
4
5
from .context import Context as _Context
5
6
@@ -108,6 +109,13 @@ def task_factory(loop, coro):
108
109
else :
109
110
current_task = asyncio .Task .current_task (loop = loop )
110
111
if current_task is not None and hasattr (current_task , 'context' ):
111
- setattr (task , 'context' , current_task .context )
112
+ if current_task .context .get ('entities' ):
113
+ # Defensively copying because recorder modifies the list in place.
114
+ new_context = copy .copy (current_task .context )
115
+ new_context ['entities' ] = [item for item in current_task .context ['entities' ]]
116
+ else :
117
+ # no reason to copy if there's no entities list.
118
+ new_context = current_task .context
119
+ setattr (task , 'context' , new_context )
112
120
113
121
return task
Original file line number Diff line number Diff line change 3
3
from .util import get_new_stubbed_recorder
4
4
from aws_xray_sdk .version import VERSION
5
5
from aws_xray_sdk .core .async_context import AsyncContext
6
+ import asyncio
6
7
7
8
8
9
xray_recorder = get_new_stubbed_recorder ()
@@ -43,6 +44,28 @@ async def test_capture(loop):
43
44
assert platform .python_implementation () == service .get ('runtime' )
44
45
assert platform .python_version () == service .get ('runtime_version' )
45
46
47
+ async def test_concurrent_calls (loop ):
48
+ xray_recorder .configure (service = 'test' , sampling = False , context = AsyncContext (loop = loop ))
49
+ async with xray_recorder .in_segment_async ('segment' ) as segment :
50
+ global counter
51
+ counter = 0
52
+ total_tasks = 10
53
+ event = asyncio .Event ()
54
+ async def assert_task ():
55
+ async with xray_recorder .in_subsegment_async ('segment' ) as subsegment :
56
+ global counter
57
+ counter += 1
58
+ # Ensure that the task subsegments overlap
59
+ if counter < total_tasks :
60
+ await event .wait ()
61
+ else :
62
+ event .set ()
63
+ return subsegment .parent_id
64
+ tasks = [assert_task () for task in range (total_tasks )]
65
+ results = await asyncio .gather (* tasks )
66
+ for result in results :
67
+ assert result == segment .id
68
+
46
69
47
70
async def test_async_context_managers (loop ):
48
71
xray_recorder .configure (service = 'test' , sampling = False , context = AsyncContext (loop = loop ))
You can’t perform that action at this time.
0 commit comments