Skip to content

Commit a311b86

Browse files
syrusakbaryCito
authored andcommitted
Added support for middleware (#7)
A useful addition taken over from GraphQL-core (not in GraphQL.js).
1 parent 62a2c97 commit a311b86

File tree

6 files changed

+266
-39
lines changed

6 files changed

+266
-39
lines changed

.flake8

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[flake8]
22
exclude = .git,.mypy_cache,.pytest_cache,.tox,.venv,__pycache__,build,dist,docs
3+
max-line-length = 88

graphql/execution/__init__.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@
55
"""
66

77
from .execute import (
8-
execute, default_field_resolver, response_path_as_list,
9-
ExecutionContext, ExecutionResult)
8+
execute,
9+
default_field_resolver,
10+
response_path_as_list,
11+
ExecutionContext,
12+
ExecutionResult,
13+
)
14+
from .middleware import MiddlewareManager
1015
from .values import get_directive_values
1116

1217
__all__ = [
13-
'execute', 'default_field_resolver', 'response_path_as_list',
14-
'ExecutionContext', 'ExecutionResult',
15-
'get_directive_values']
18+
"execute",
19+
"default_field_resolver",
20+
"response_path_as_list",
21+
"ExecutionContext",
22+
"ExecutionResult",
23+
"MiddlewareManager",
24+
"get_directive_values",
25+
]

graphql/execution/execute.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
is_non_null_type, is_object_type)
2121
from .values import (
2222
get_argument_values, get_directive_values, get_variable_values)
23+
from .middleware import MiddlewareManager
24+
2325

2426
__all__ = [
2527
'add_path', 'assert_valid_execution_arguments', 'default_field_resolver',
@@ -64,7 +66,8 @@ def execute(
6466
schema: GraphQLSchema, document: DocumentNode,
6567
root_value: Any=None, context_value: Any=None,
6668
variable_values: Dict[str, Any]=None,
67-
operation_name: str=None, field_resolver: GraphQLFieldResolver=None
69+
operation_name: str=None, field_resolver: GraphQLFieldResolver=None,
70+
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
6871
) -> MaybeAwaitable[ExecutionResult]:
6972
"""Execute a GraphQL operation.
7073
@@ -84,7 +87,7 @@ def execute(
8487
# arguments, a "Response" with only errors is returned.
8588
exe_context = ExecutionContext.build(
8689
schema, document, root_value, context_value,
87-
variable_values, operation_name, field_resolver)
90+
variable_values, operation_name, field_resolver, middleware)
8891

8992
# Return early errors if execution context failed.
9093
if isinstance(exe_context, list):
@@ -116,6 +119,7 @@ class ExecutionContext:
116119
operation: OperationDefinitionNode
117120
variable_values: Dict[str, Any]
118121
field_resolver: GraphQLFieldResolver
122+
middleware_manager: Optional[MiddlewareManager]
119123
errors: List[GraphQLError]
120124

121125
def __init__(
@@ -125,6 +129,7 @@ def __init__(
125129
operation: OperationDefinitionNode,
126130
variable_values: Dict[str, Any],
127131
field_resolver: GraphQLFieldResolver,
132+
middleware_manager: Optional[MiddlewareManager],
128133
errors: List[GraphQLError]) -> None:
129134
self.schema = schema
130135
self.fragments = fragments
@@ -133,6 +138,7 @@ def __init__(
133138
self.operation = operation
134139
self.variable_values = variable_values
135140
self.field_resolver = field_resolver # type: ignore
141+
self.middleware_manager = middleware_manager
136142
self.errors = errors
137143
self._subfields_cache: Dict[
138144
Tuple[GraphQLObjectType, Tuple[FieldNode, ...]],
@@ -144,7 +150,8 @@ def build(
144150
root_value: Any=None, context_value: Any=None,
145151
raw_variable_values: Dict[str, Any]=None,
146152
operation_name: str=None,
147-
field_resolver: GraphQLFieldResolver=None
153+
field_resolver: GraphQLFieldResolver=None,
154+
middleware: Optional[Union[Iterable[Any], MiddlewareManager]]=None
148155
) -> Union[List[GraphQLError], 'ExecutionContext']:
149156
"""Build an execution context
150157
@@ -157,6 +164,18 @@ def build(
157164
operation: Optional[OperationDefinitionNode] = None
158165
has_multiple_assumed_operations = False
159166
fragments: Dict[str, FragmentDefinitionNode] = {}
167+
middleware_manager: Optional[MiddlewareManager] = None
168+
if middleware:
169+
if isinstance(middleware, Iterable):
170+
middleware_manager = MiddlewareManager(*middleware)
171+
elif isinstance(middleware, MiddlewareManager):
172+
middleware_manager = middleware
173+
else:
174+
raise TypeError(
175+
f"middlewares have to be an instance"
176+
"of MiddlewareManager. Received \"{middleware}\""
177+
)
178+
160179
for definition in document.definitions:
161180
if isinstance(definition, OperationDefinitionNode):
162181
if not operation_name and operation:
@@ -201,7 +220,8 @@ def build(
201220

202221
return cls(
203222
schema, fragments, root_value, context_value, operation,
204-
variable_values, field_resolver or default_field_resolver, errors)
223+
variable_values, field_resolver or default_field_resolver,
224+
middleware_manager, errors)
205225

206226
def build_response(
207227
self, data: MaybeAwaitable[Optional[Dict[str, Any]]]
@@ -447,6 +467,9 @@ def resolve_field(
447467

448468
resolve_fn = field_def.resolve or self.field_resolver
449469

470+
if self.middleware_manager:
471+
resolve_fn = self.middleware_manager.get_field_resolver(resolve_fn)
472+
450473
info = self.build_resolve_info(
451474
field_def, field_nodes, parent_type, path)
452475

graphql/execution/middleware.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast
2+
3+
from inspect import isfunction
4+
from functools import partial
5+
from itertools import chain
6+
7+
8+
from ..type import GraphQLFieldResolver
9+
10+
11+
__all__ = ["MiddlewareManager", "middlewares"]
12+
13+
# If the provided middleware is a class, this is the attribute we will look at
14+
MIDDLEWARE_RESOLVER_FUNCTION = "resolve"
15+
16+
17+
class MiddlewareManager:
18+
"""MiddlewareManager helps to chain resolver functions with the provided
19+
middleware functions and classes
20+
"""
21+
22+
__slots__ = ("middlewares", "_middleware_resolvers", "_cached_resolvers")
23+
24+
_cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver]
25+
_middleware_resolvers: Optional[Tuple[Callable, ...]]
26+
27+
def __init__(self, *middlewares: Any) -> None:
28+
self.middlewares = middlewares
29+
if middlewares:
30+
self._middleware_resolvers = tuple(get_middleware_resolvers(middlewares))
31+
else:
32+
self.__middleware_resolvers = None
33+
self._cached_resolvers = {}
34+
35+
def get_field_resolver(
36+
self, field_resolver: GraphQLFieldResolver
37+
) -> GraphQLFieldResolver:
38+
"""Wraps the provided resolver returning a function that
39+
executes chains the middleware functions with the resolver function"""
40+
if self._middleware_resolvers is None:
41+
return field_resolver
42+
if field_resolver not in self._cached_resolvers:
43+
self._cached_resolvers[field_resolver] = middleware_chain(
44+
field_resolver, self._middleware_resolvers
45+
)
46+
47+
return self._cached_resolvers[field_resolver]
48+
49+
50+
middlewares = MiddlewareManager
51+
52+
53+
def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]:
54+
"""Returns the functions related to the middleware classes or functions"""
55+
for middleware in middlewares:
56+
# If the middleware is a function instead of a class
57+
if isfunction(middleware):
58+
yield middleware
59+
resolver_func = getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION, None)
60+
if resolver_func is not None:
61+
yield resolver_func
62+
63+
64+
def middleware_chain(
65+
func: GraphQLFieldResolver, middlewares: Iterable[Callable]
66+
) -> GraphQLFieldResolver:
67+
"""Reduces the current function with the provided middlewares,
68+
returning a new resolver function"""
69+
if not middlewares:
70+
return func
71+
middlewares = chain((func,), middlewares)
72+
last_func: Optional[GraphQLFieldResolver] = None
73+
for middleware in middlewares:
74+
last_func = partial(middleware, last_func) if last_func else middleware
75+
76+
return cast(GraphQLFieldResolver, last_func)

graphql/graphql.py

+43-30
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
from asyncio import ensure_future
22
from inspect import isawaitable
3-
from typing import Any, Awaitable, Callable, Dict, Union, cast
3+
from typing import Any, Awaitable, Callable, Dict, Union, Optional, Iterable, cast
44

55
from .error import GraphQLError
66
from .execution import execute
77
from .language import parse, Source
88
from .pyutils import MaybeAwaitable
99
from .type import GraphQLSchema, validate_schema
10-
from .execution.execute import ExecutionResult
10+
from .execution import ExecutionResult, MiddlewareManager
1111

12-
__all__ = ['graphql', 'graphql_sync']
12+
__all__ = ["graphql", "graphql_sync"]
1313

1414

1515
async def graphql(
16-
schema: GraphQLSchema,
17-
source: Union[str, Source],
18-
root_value: Any=None,
19-
context_value: Any=None,
20-
variable_values: Dict[str, Any]=None,
21-
operation_name: str=None,
22-
field_resolver: Callable=None) -> ExecutionResult:
16+
schema: GraphQLSchema,
17+
source: Union[str, Source],
18+
root_value: Any = None,
19+
context_value: Any = None,
20+
variable_values: Dict[str, Any] = None,
21+
operation_name: str = None,
22+
field_resolver: Callable = None,
23+
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
24+
) -> ExecutionResult:
2325
"""Execute a GraphQL operation asynchronously.
2426
2527
This is the primary entry point function for fulfilling GraphQL operations
@@ -56,6 +58,8 @@ async def graphql(
5658
A resolver function to use when one is not provided by the schema.
5759
If not provided, the default field resolver is used (which looks for
5860
a value or method on the source value with the field's name).
61+
:arg middleware:
62+
The middleware to wrap the resolvers with
5963
"""
6064
# Always return asynchronously for a consistent API.
6165
result = graphql_impl(
@@ -65,7 +69,9 @@ async def graphql(
6569
context_value,
6670
variable_values,
6771
operation_name,
68-
field_resolver)
72+
field_resolver,
73+
middleware,
74+
)
6975

7076
if isawaitable(result):
7177
return await cast(Awaitable[ExecutionResult], result)
@@ -74,13 +80,15 @@ async def graphql(
7480

7581

7682
def graphql_sync(
77-
schema: GraphQLSchema,
78-
source: Union[str, Source],
79-
root_value: Any = None,
80-
context_value: Any = None,
81-
variable_values: Dict[str, Any] = None,
82-
operation_name: str = None,
83-
field_resolver: Callable = None) -> ExecutionResult:
83+
schema: GraphQLSchema,
84+
source: Union[str, Source],
85+
root_value: Any = None,
86+
context_value: Any = None,
87+
variable_values: Dict[str, Any] = None,
88+
operation_name: str = None,
89+
field_resolver: Callable = None,
90+
middleware: Optional[Union[Iterable[Any], MiddlewareManager]] = None,
91+
) -> ExecutionResult:
8492
"""Execute a GraphQL operation synchronously.
8593
8694
The graphql_sync function also fulfills GraphQL operations by parsing,
@@ -95,26 +103,28 @@ def graphql_sync(
95103
context_value,
96104
variable_values,
97105
operation_name,
98-
field_resolver)
106+
field_resolver,
107+
middleware,
108+
)
99109

100110
# Assert that the execution was synchronous.
101111
if isawaitable(result):
102112
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
103-
raise RuntimeError(
104-
'GraphQL execution failed to complete synchronously.')
113+
raise RuntimeError("GraphQL execution failed to complete synchronously.")
105114

106115
return cast(ExecutionResult, result)
107116

108117

109118
def graphql_impl(
110-
schema,
111-
source,
112-
root_value,
113-
context_value,
114-
variable_values,
115-
operation_name,
116-
field_resolver
117-
) -> MaybeAwaitable[ExecutionResult]:
119+
schema,
120+
source,
121+
root_value,
122+
context_value,
123+
variable_values,
124+
operation_name,
125+
field_resolver,
126+
middleware,
127+
) -> MaybeAwaitable[ExecutionResult]:
118128
"""Execute a query, return asynchronously only if necessary."""
119129
# Validate Schema
120130
schema_validation_errors = validate_schema(schema)
@@ -132,6 +142,7 @@ def graphql_impl(
132142

133143
# Validate
134144
from .validation import validate
145+
135146
validation_errors = validate(schema, document)
136147
if validation_errors:
137148
return ExecutionResult(data=None, errors=validation_errors)
@@ -144,4 +155,6 @@ def graphql_impl(
144155
context_value,
145156
variable_values,
146157
operation_name,
147-
field_resolver)
158+
field_resolver,
159+
middleware,
160+
)

0 commit comments

Comments
 (0)