diff --git a/graphql/execution/middleware.py b/graphql/execution/middleware.py index fa7fc5ad..22bcfd32 100644 --- a/graphql/execution/middleware.py +++ b/graphql/execution/middleware.py @@ -9,14 +9,19 @@ class MiddlewareManager(object): - def __init__(self, *middlewares): + def __init__(self, *middlewares, **kwargs): self.middlewares = middlewares + self.wrap_in_promise = kwargs.get('wrap_in_promise', True) self._middleware_resolvers = list(get_middleware_resolvers(middlewares)) self._cached_resolvers = {} def get_field_resolver(self, field_resolver): if field_resolver not in self._cached_resolvers: - self._cached_resolvers[field_resolver] = middleware_chain(field_resolver, self._middleware_resolvers) + self._cached_resolvers[field_resolver] = middleware_chain( + field_resolver, + self._middleware_resolvers, + wrap_in_promise=self.wrap_in_promise, + ) return self._cached_resolvers[field_resolver] @@ -34,10 +39,13 @@ def get_middleware_resolvers(middlewares): yield getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION) -def middleware_chain(func, middlewares): +def middleware_chain(func, middlewares, wrap_in_promise): if not middlewares: return func - middlewares = chain((func, make_it_promise), middlewares) + if wrap_in_promise: + middlewares = chain((func, make_it_promise), middlewares) + else: + middlewares = chain((func,), middlewares) last_func = None for middleware in middlewares: last_func = partial(middleware, last_func) if last_func else middleware diff --git a/graphql/execution/tests/test_executor.py b/graphql/execution/tests/test_executor.py index 2bbaf2c2..7d9b5e7f 100644 --- a/graphql/execution/tests/test_executor.py +++ b/graphql/execution/tests/test_executor.py @@ -8,6 +8,7 @@ from graphql.type import (GraphQLArgument, GraphQLBoolean, GraphQLField, GraphQLInt, GraphQLList, GraphQLObjectType, GraphQLSchema, GraphQLString) +from promise import Promise def test_executes_arbitary_code(): @@ -607,3 +608,40 @@ def resolve(self, next, *args, **kwargs): middlewares = MiddlewareManager(MyMiddleware()) result = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares) assert result.data == {'ok': 'ko', 'not_ok': 'ko_ton'} + + +def test_middleware_skip_promise_wrap(): + doc = '''{ + ok + not_ok + }''' + + class Data(object): + + def ok(self): + return 'ok' + + def not_ok(self): + return 'not_ok' + + doc_ast = parse(doc) + + Type = GraphQLObjectType('Type', { + 'ok': GraphQLField(GraphQLString), + 'not_ok': GraphQLField(GraphQLString), + }) + + class MyPromiseMiddleware(object): + def resolve(self, next, *args, **kwargs): + return Promise.resolve(next(*args, **kwargs)) + + class MyEmptyMiddleware(object): + def resolve(self, next, *args, **kwargs): + return next(*args, **kwargs) + + middlewares_with_promise = MiddlewareManager(MyPromiseMiddleware(), wrap_in_promise=False) + middlewares_without_promise = MiddlewareManager(MyEmptyMiddleware(), wrap_in_promise=False) + + result1 = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares_with_promise) + result2 = execute(GraphQLSchema(Type), doc_ast, Data(), middleware=middlewares_without_promise) + assert result1.data == result2.data and result1.data == {'ok': 'ok', 'not_ok': 'not_ok'}