Skip to content

Commit 74dae01

Browse files
committed
Cancel remaining fields on exceptions
gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236
1 parent 0107e30 commit 74dae01

File tree

2 files changed

+56
-7
lines changed

2 files changed

+56
-7
lines changed

src/graphql/execution/execute.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
from __future__ import annotations
44

5-
from asyncio import ensure_future, gather, shield, wait_for
5+
from asyncio import (
6+
FIRST_EXCEPTION,
7+
create_task,
8+
ensure_future,
9+
gather,
10+
shield,
11+
wait,
12+
wait_for,
13+
)
614
from contextlib import suppress
715
from copy import copy
816
from typing import (
@@ -459,12 +467,16 @@ async def get_results() -> dict[str, Any]:
459467
field = awaitable_fields[0]
460468
results[field] = await results[field]
461469
else:
462-
results.update(
463-
zip(
464-
awaitable_fields,
465-
await gather(*(results[field] for field in awaitable_fields)),
466-
)
467-
)
470+
tasks = {}
471+
for field in awaitable_fields:
472+
tasks[create_task(results[field])] = field # type: ignore[arg-type]
473+
474+
done, pending = await wait(tasks, return_when=FIRST_EXCEPTION)
475+
for task in pending:
476+
task.cancel()
477+
478+
results.update((tasks[task], task.result()) for task in done)
479+
468480
return results
469481

470482
return get_results()

tests/execution/test_parallel.py

+37
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GraphQLInt,
1212
GraphQLInterfaceType,
1313
GraphQLList,
14+
GraphQLNonNull,
1415
GraphQLObjectType,
1516
GraphQLSchema,
1617
GraphQLString,
@@ -193,3 +194,39 @@ async def is_type_of_baz(obj, *_args):
193194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
194195
None,
195196
)
197+
198+
@pytest.mark.asyncio
199+
async def cancel_on_exception():
200+
barrier = Barrier(2)
201+
completed = asyncio.Event()
202+
203+
async def succeed(*_args):
204+
await barrier.wait()
205+
completed.set()
206+
207+
async def fail(*_args):
208+
raise Exception
209+
210+
schema = GraphQLSchema(
211+
GraphQLObjectType(
212+
"Query",
213+
{
214+
"foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail),
215+
"bar": GraphQLField(GraphQLBoolean, resolve=succeed),
216+
},
217+
)
218+
)
219+
220+
ast = parse("{foo, bar}")
221+
222+
awaitable_result = execute(schema, ast)
223+
assert isinstance(awaitable_result, Awaitable)
224+
result = await asyncio.wait_for(awaitable_result, 1.0)
225+
226+
assert result.errors
227+
assert not result.data
228+
229+
# Unblock succeed() and check that it does not complete
230+
await barrier.wait()
231+
await asyncio.sleep(0)
232+
assert not completed.is_set()

0 commit comments

Comments
 (0)