Skip to content

Commit 6b17c34

Browse files
committed
Cancel remaining fields on exceptions
gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236
1 parent 0107e30 commit 6b17c34

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

src/graphql/execution/execute.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
from __future__ import annotations
44

5-
from asyncio import ensure_future, gather, shield, wait_for
5+
from asyncio import (
6+
FIRST_EXCEPTION,
7+
CancelledError,
8+
create_task,
9+
ensure_future,
10+
gather,
11+
shield,
12+
wait,
13+
wait_for,
14+
)
615
from contextlib import suppress
716
from copy import copy
817
from typing import (
@@ -459,12 +468,18 @@ async def get_results() -> dict[str, Any]:
459468
field = awaitable_fields[0]
460469
results[field] = await results[field]
461470
else:
462-
results.update(
463-
zip(
464-
awaitable_fields,
465-
await gather(*(results[field] for field in awaitable_fields)),
466-
)
467-
)
471+
tasks = {}
472+
for field in awaitable_fields:
473+
tasks[create_task(results[field])] = field # type: ignore[arg-type]
474+
475+
done, pending = await wait(tasks, return_when=FIRST_EXCEPTION)
476+
if pending:
477+
for task in pending:
478+
task.cancel()
479+
await wait(pending)
480+
481+
results.update((tasks[task], task.result()) for task in done)
482+
468483
return results
469484

470485
return get_results()
@@ -538,6 +553,10 @@ async def await_completed() -> Any:
538553
try:
539554
return await completed
540555
except Exception as raw_error:
556+
# Before Python 3.8 CancelledError inherits Exception and
557+
# so gets caught here.
558+
if isinstance(raw_error, CancelledError):
559+
raise
541560
self.handle_field_error(
542561
raw_error,
543562
return_type,

tests/execution/test_parallel.py

+38
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GraphQLInt,
1212
GraphQLInterfaceType,
1313
GraphQLList,
14+
GraphQLNonNull,
1415
GraphQLObjectType,
1516
GraphQLSchema,
1617
GraphQLString,
@@ -193,3 +194,40 @@ async def is_type_of_baz(obj, *_args):
193194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
194195
None,
195196
)
197+
198+
@pytest.mark.asyncio
199+
async def cancel_on_exception():
200+
barrier = Barrier(2)
201+
completed = False
202+
203+
async def succeed(*_args):
204+
nonlocal completed
205+
await barrier.wait()
206+
completed = True
207+
208+
async def fail(*_args):
209+
raise Exception
210+
211+
schema = GraphQLSchema(
212+
GraphQLObjectType(
213+
"Query",
214+
{
215+
"foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail),
216+
"bar": GraphQLField(GraphQLBoolean, resolve=succeed),
217+
},
218+
)
219+
)
220+
221+
ast = parse("{foo, bar}")
222+
223+
awaitable_result = execute(schema, ast)
224+
assert isinstance(awaitable_result, Awaitable)
225+
result = await asyncio.wait_for(awaitable_result, 1.0)
226+
227+
assert result.errors
228+
assert not result.data
229+
230+
# Unblock succeed() and check that it does not complete
231+
await barrier.wait()
232+
await asyncio.sleep(0)
233+
assert not completed

0 commit comments

Comments
 (0)