From 6b17c3421e0592f53457eaaef87929469331e1da Mon Sep 17 00:00:00 2001 From: Michael Gorven Date: Mon, 12 May 2025 15:48:10 -0700 Subject: [PATCH] Cancel remaining fields on exceptions gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236 --- src/graphql/execution/execute.py | 33 +++++++++++++++++++++------ tests/execution/test_parallel.py | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 1097e80f..db5e11a6 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -2,7 +2,16 @@ from __future__ import annotations -from asyncio import ensure_future, gather, shield, wait_for +from asyncio import ( + FIRST_EXCEPTION, + CancelledError, + create_task, + ensure_future, + gather, + shield, + wait, + wait_for, +) from contextlib import suppress from copy import copy from typing import ( @@ -459,12 +468,18 @@ async def get_results() -> dict[str, Any]: field = awaitable_fields[0] results[field] = await results[field] else: - results.update( - zip( - awaitable_fields, - await gather(*(results[field] for field in awaitable_fields)), - ) - ) + tasks = {} + for field in awaitable_fields: + tasks[create_task(results[field])] = field # type: ignore[arg-type] + + done, pending = await wait(tasks, return_when=FIRST_EXCEPTION) + if pending: + for task in pending: + task.cancel() + await wait(pending) + + results.update((tasks[task], task.result()) for task in done) + return results return get_results() @@ -538,6 +553,10 @@ async def await_completed() -> Any: try: return await completed except Exception as raw_error: + # Before Python 3.8 CancelledError inherits Exception and + # so gets caught here. + if isinstance(raw_error, CancelledError): + raise self.handle_field_error( raw_error, return_type, diff --git a/tests/execution/test_parallel.py b/tests/execution/test_parallel.py index f4dc86b1..fdf379a7 100644 --- a/tests/execution/test_parallel.py +++ b/tests/execution/test_parallel.py @@ -11,6 +11,7 @@ GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString, @@ -193,3 +194,40 @@ async def is_type_of_baz(obj, *_args): {"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]}, None, ) + + @pytest.mark.asyncio + async def cancel_on_exception(): + barrier = Barrier(2) + completed = False + + async def succeed(*_args): + nonlocal completed + await barrier.wait() + completed = True + + async def fail(*_args): + raise Exception + + schema = GraphQLSchema( + GraphQLObjectType( + "Query", + { + "foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail), + "bar": GraphQLField(GraphQLBoolean, resolve=succeed), + }, + ) + ) + + ast = parse("{foo, bar}") + + awaitable_result = execute(schema, ast) + assert isinstance(awaitable_result, Awaitable) + result = await asyncio.wait_for(awaitable_result, 1.0) + + assert result.errors + assert not result.data + + # Unblock succeed() and check that it does not complete + await barrier.wait() + await asyncio.sleep(0) + assert not completed