Skip to content

bpo-46752: Taskgroup tweaks #31559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,31 +66,28 @@ async def __aexit__(self, et, exc, tb):
self._base_error is None):
self._base_error = exc

if et is exceptions.CancelledError:
if self._parent_cancel_requested:
# Only if we did request task to cancel ourselves
# we mark it as no longer cancelled.
self._parent_task.uncancel()
else:
propagate_cancellation_error = et

if et is not None and not self._aborting:
# Our parent task is being cancelled:
#
# async with TaskGroup() as g:
# g.create_task(...)
# await ... # <- CancelledError
#
if et is not None:
if et is exceptions.CancelledError:
propagate_cancellation_error = et

# or there's an exception in "async with":
#
# async with TaskGroup() as g:
# g.create_task(...)
# 1 / 0
#
self._abort()
if self._parent_cancel_requested and not self._parent_task.uncancel():
# Do nothing, i.e. swallow the error.
pass
else:
propagate_cancellation_error = exc

if not self._aborting:
# Our parent task is being cancelled:
#
# async with TaskGroup() as g:
# g.create_task(...)
# await ... # <- CancelledError
#
# or there's an exception in "async with":
#
# async with TaskGroup() as g:
# g.create_task(...)
# 1 / 0
#
self._abort()

# We use while-loop here because "self._on_completed_fut"
# can be cancelled multiple times if our parent task
Expand Down Expand Up @@ -118,7 +115,6 @@ async def __aexit__(self, et, exc, tb):
self._on_completed_fut = None

assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed

if self._base_error is not None:
raise self._base_error
Expand Down Expand Up @@ -199,8 +195,7 @@ def _on_task_done(self, task):
})
return

self._abort()
if not self._parent_task.cancelling():
if not self._aborting and not self._parent_cancel_requested:
# If parent task *is not* being cancelled, it means that we want
# to manually cancel it to abort whatever is being run right now
# in the TaskGroup. But we want to mark parent task as
Expand All @@ -219,5 +214,6 @@ def _on_task_done(self, task):
# pass
# await something_else # this line has to be called
# # after TaskGroup is finished.
self._abort()
self._parent_cancel_requested = True
self._parent_task.cancel()
26 changes: 19 additions & 7 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ async def runner():
self.assertTrue(t2_cancel)
self.assertTrue(t2.cancelled())

async def test_taskgroup_05(self):
async def test_cancel_children_on_child_error(self):
"""
When a child task raises an error, the rest of the children
are cancelled and the errors are gathered into an EG.
"""

NUM = 0
t2_cancel = False
Expand Down Expand Up @@ -165,7 +169,7 @@ async def runner():
self.assertTrue(t2_cancel)
self.assertTrue(runner_cancel)

async def test_taskgroup_06(self):
async def test_cancellation(self):

NUM = 0

Expand All @@ -186,10 +190,12 @@ async def runner():
await asyncio.sleep(0.1)

self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
r.cancel("test")
with self.assertRaises(asyncio.CancelledError) as cm:
await r

self.assertEqual(cm.exception.args, ('test',))

self.assertEqual(NUM, 5)

async def test_taskgroup_07(self):
Expand Down Expand Up @@ -226,7 +232,7 @@ async def runner():

self.assertEqual(NUM, 15)

async def test_taskgroup_08(self):
async def test_cancellation_in_body(self):

async def foo():
await asyncio.sleep(0.1)
Expand All @@ -246,10 +252,12 @@ async def runner():
await asyncio.sleep(0.1)

self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
r.cancel("test")
with self.assertRaises(asyncio.CancelledError) as cm:
await r

self.assertEqual(cm.exception.args, ('test',))

async def test_taskgroup_09(self):

t1 = t2 = None
Expand Down Expand Up @@ -699,3 +707,7 @@ async def coro():
async with taskgroups.TaskGroup() as g:
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")


if __name__ == "__main__":
unittest.main()