Skip to content

Commit 90aaf2d

Browse files
committed
extmod/uasyncio: Fix gather cancelling and handling of exceptions.
The following fixes are made: - cancelling a gather now cancels all sub-tasks of the gather (previously it would only cancel the first) - if any sub-task of a gather raises an exception then the gather finishes (previously it would only finish if the first sub-task raised) Fixes issues #5798, #7807, #7901. Signed-off-by: Damien George <[email protected]>
1 parent 335002a commit 90aaf2d

File tree

5 files changed

+202
-25
lines changed

5 files changed

+202
-25
lines changed

extmod/uasyncio/funcs.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,68 @@ def wait_for_ms(aw, timeout):
5353
return wait_for(aw, timeout, core.sleep_ms)
5454

5555

56+
class _Remove:
57+
@staticmethod
58+
def remove(t):
59+
pass
60+
61+
5662
async def gather(*aws, return_exceptions=False):
63+
def done(t, er):
64+
nonlocal state
65+
if type(state) is not int:
66+
# A sub-task already raised an exception, so do nothing.
67+
return
68+
elif not return_exceptions and not isinstance(er, StopIteration):
69+
# A sub-task raised an exception, indicate that to the gather task.
70+
state = er
71+
else:
72+
state -= 1
73+
if state:
74+
# Still some sub-tasks running.
75+
return
76+
# Gather waiting is done, schedule the main gather task.
77+
core._task_queue.push_head(gather_task)
78+
5779
ts = [core._promote_to_task(aw) for aw in aws]
5880
for i in range(len(ts)):
59-
try:
60-
# TODO handle cancel of gather itself
61-
# if ts[i].coro:
62-
# iter(ts[i]).waiting.push_head(cur_task)
63-
# try:
64-
# yield
65-
# except CancelledError as er:
66-
# # cancel all waiting tasks
67-
# raise er
68-
ts[i] = await ts[i]
69-
except (core.CancelledError, Exception) as er:
70-
if return_exceptions:
71-
ts[i] = er
72-
else:
73-
raise er
81+
if ts[i].state is not True:
82+
# Task is not running, gather not currently supported for this case.
83+
raise RuntimeError("can't gather")
84+
# Register the callback to call when the task is done.
85+
ts[i].state = done
86+
87+
# Set the state for execution of the gather.
88+
gather_task = core.cur_task
89+
state = len(ts)
90+
cancel_all = False
91+
92+
# Wait for the a sub-task to need attention.
93+
gather_task.data = _Remove
94+
try:
95+
yield
96+
except core.CancelledError as er:
97+
cancel_all = True
98+
state = er
99+
100+
# Clean up tasks.
101+
for i in range(len(ts)):
102+
if ts[i].state is done:
103+
# Sub-task is still running, deregister the callback and cancel if needed.
104+
ts[i].state = True
105+
if cancel_all:
106+
ts[i].cancel()
107+
elif isinstance(ts[i].data, StopIteration):
108+
# Sub-task ran to completion, get its return value.
109+
ts[i] = ts[i].data.value
110+
else:
111+
# Sub-task had an exception with return_exceptions==True, so get its exception.
112+
ts[i] = ts[i].data
113+
114+
# Either this gather was cancelled, or one of the sub-tasks raised an exception with
115+
# return_exceptions==False, so reraise the exception here.
116+
if state is not 0:
117+
raise state
118+
119+
# Return the list of return values of each sub-task.
74120
return ts

tests/extmod/uasyncio_gather.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,72 @@ async def task(id):
2727
return id
2828

2929

30-
async def gather_task():
30+
async def task_loop(id):
31+
print("task_loop start", id)
32+
while True:
33+
await asyncio.sleep(0.02)
34+
print("task_loop loop", id)
35+
36+
37+
async def task_raise(id):
38+
print("task_raise start", id)
39+
await asyncio.sleep(0.02)
40+
raise ValueError(id)
41+
42+
43+
async def gather_task(t0, t1):
3144
print("gather_task")
32-
await asyncio.gather(task(1), task(2))
45+
await asyncio.gather(t0, t1)
3346
print("gather_task2")
3447

3548

3649
async def main():
3750
# Simple gather with return values
3851
print(await asyncio.gather(factorial("A", 2), factorial("B", 3), factorial("C", 4)))
3952

53+
print("====")
54+
4055
# Test return_exceptions, where one task is cancelled and the other finishes normally
4156
tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
4257
tasks[0].cancel()
4358
print(await asyncio.gather(*tasks, return_exceptions=True))
4459

45-
# Cancel a multi gather
46-
# TODO doesn't work, Task should not forward cancellation from gather to sub-task
47-
# but rather CancelledError should cancel the gather directly, which will then cancel
48-
# all sub-tasks explicitly
49-
# t = asyncio.create_task(gather_task())
50-
# await asyncio.sleep(0.01)
51-
# t.cancel()
52-
# await asyncio.sleep(0.01)
60+
print("====")
61+
62+
# Test return_exceptions, where one task raises an exception and the other finishes normally.
63+
tasks = [asyncio.create_task(task(1)), asyncio.create_task(task_raise(2))]
64+
print(await asyncio.gather(*tasks, return_exceptions=True))
65+
66+
print("====")
67+
68+
# Test case where one task raises an exception and other task keeps running.
69+
tasks = [asyncio.create_task(task_loop(1)), asyncio.create_task(task_raise(2))]
70+
try:
71+
await asyncio.gather(*tasks)
72+
except ValueError as er:
73+
print(repr(er))
74+
print(tasks[0].done(), tasks[1].done())
75+
for t in tasks:
76+
t.cancel()
77+
await asyncio.sleep(0.04)
78+
79+
print("====")
80+
81+
# Test case where both tasks raise an exception.
82+
tasks = [asyncio.create_task(task_raise(1)), asyncio.create_task(task_raise(2))]
83+
try:
84+
await asyncio.gather(*tasks)
85+
except ValueError as er:
86+
print(repr(er))
87+
print(tasks[0].done(), tasks[1].done())
88+
89+
print("====")
90+
91+
# Cancel a multi gather.
92+
t = asyncio.create_task(gather_task(task(1), task(2)))
93+
await asyncio.sleep(0.01)
94+
t.cancel()
95+
await asyncio.sleep(0.04)
5396

5497

5598
asyncio.run(main())

tests/extmod/uasyncio_gather.py.exp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,27 @@ Task B: factorial(3) = 6
88
Task C: Compute factorial(4)...
99
Task C: factorial(4) = 24
1010
[2, 6, 24]
11+
====
1112
start 2
1213
end 2
1314
[CancelledError(), 2]
15+
====
16+
start 1
17+
task_raise start 2
18+
end 1
19+
[1, ValueError(2,)]
20+
====
21+
task_loop start 1
22+
task_raise start 2
23+
task_loop loop 1
24+
ValueError(2,)
25+
False True
26+
====
27+
task_raise start 1
28+
task_raise start 2
29+
ValueError(1,)
30+
True True
31+
====
32+
gather_task
33+
start 1
34+
start 2
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Test uasyncio.gather() function, features that are not implemented.
2+
3+
try:
4+
import uasyncio as asyncio
5+
except ImportError:
6+
try:
7+
import asyncio
8+
except ImportError:
9+
print("SKIP")
10+
raise SystemExit
11+
12+
13+
def custom_handler(loop, context):
14+
print(repr(context["exception"]))
15+
16+
17+
async def task(id):
18+
print("task start", id)
19+
await asyncio.sleep(0.01)
20+
print("task end", id)
21+
return id
22+
23+
24+
async def gather_task(t0, t1):
25+
print("gather_task start")
26+
await asyncio.gather(t0, t1)
27+
print("gather_task end")
28+
29+
30+
async def main():
31+
loop = asyncio.get_event_loop()
32+
loop.set_exception_handler(custom_handler)
33+
34+
# Test case where can't wait on a task being gathered.
35+
tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
36+
gt = asyncio.create_task(gather_task(tasks[0], tasks[1]))
37+
await asyncio.sleep(0) # let the gather start
38+
try:
39+
await tasks[0] # can't await because this task is part of the gather
40+
except RuntimeError as er:
41+
print(repr(er))
42+
await gt
43+
44+
print("====")
45+
46+
# Test case where can't gather on a task being waited.
47+
tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))]
48+
asyncio.create_task(gather_task(tasks[0], tasks[1]))
49+
await tasks[0] # wait on this task before the gather starts
50+
await tasks[1]
51+
52+
53+
asyncio.run(main())
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
task start 1
2+
task start 2
3+
gather_task start
4+
RuntimeError("can't wait",)
5+
task end 1
6+
task end 2
7+
gather_task end
8+
====
9+
task start 1
10+
task start 2
11+
gather_task start
12+
RuntimeError("can't gather",)
13+
task end 1
14+
task end 2

0 commit comments

Comments
 (0)