Skip to content

Commit d36c024

Browse files
committed
refactor: add aiohttp-graphql as optional feature
1 parent 6c13ef6 commit d36c024

File tree

9 files changed

+1375
-1
lines changed

9 files changed

+1375
-1
lines changed

graphql_server/aiohttp/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .graphqlview import GraphQLView
2+
3+
__all__ = ["GraphQLView"]

graphql_server/aiohttp/graphqlview.py

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import copy
2+
from collections.abc import MutableMapping
3+
from functools import partial
4+
5+
from aiohttp import web
6+
from graphql import GraphQLError
7+
from graphql.type.schema import GraphQLSchema
8+
9+
from graphql_server import (
10+
HttpQueryError,
11+
encode_execution_results,
12+
format_error_default,
13+
json_encode,
14+
load_json_body,
15+
run_http_query,
16+
)
17+
18+
from .render_graphiql import render_graphiql
19+
20+
21+
class GraphQLView:
22+
schema = None
23+
root_value = None
24+
context = None
25+
pretty = False
26+
graphiql = False
27+
graphiql_version = None
28+
graphiql_template = None
29+
middleware = None
30+
batch = False
31+
jinja_env = None
32+
max_age = 86400
33+
enable_async = False
34+
subscriptions = None
35+
36+
accepted_methods = ["GET", "POST", "PUT", "DELETE"]
37+
38+
format_error = staticmethod(format_error_default)
39+
encode = staticmethod(json_encode)
40+
41+
def __init__(self, **kwargs):
42+
super(GraphQLView, self).__init__()
43+
for key, value in kwargs.items():
44+
if hasattr(self, key):
45+
setattr(self, key, value)
46+
47+
assert isinstance(
48+
self.schema, GraphQLSchema
49+
), "A Schema is required to be provided to GraphQLView."
50+
51+
def get_root_value(self):
52+
return self.root_value
53+
54+
def get_context(self, request):
55+
context = (
56+
copy.copy(self.context)
57+
if self.context and isinstance(self.context, MutableMapping)
58+
else {}
59+
)
60+
if isinstance(context, MutableMapping) and "request" not in context:
61+
context.update({"request": request})
62+
return context
63+
64+
def get_middleware(self):
65+
return self.middleware
66+
67+
# This method can be static
68+
async def parse_body(self, request):
69+
content_type = request.content_type
70+
# request.text() is the aiohttp equivalent to
71+
# request.body.decode("utf8")
72+
if content_type == "application/graphql":
73+
r_text = await request.text()
74+
return {"query": r_text}
75+
76+
if content_type == "application/json":
77+
text = await request.text()
78+
return load_json_body(text)
79+
80+
if content_type in (
81+
"application/x-www-form-urlencoded",
82+
"multipart/form-data",
83+
):
84+
# TODO: seems like a multidict would be more appropriate
85+
# than casting it and de-duping variables. Alas, it's what
86+
# graphql-python wants.
87+
return dict(await request.post())
88+
89+
return {}
90+
91+
def render_graphiql(self, params, result):
92+
return render_graphiql(
93+
jinja_env=self.jinja_env,
94+
params=params,
95+
result=result,
96+
graphiql_version=self.graphiql_version,
97+
graphiql_template=self.graphiql_template,
98+
subscriptions=self.subscriptions,
99+
)
100+
101+
# TODO:
102+
# use this method to replace flask and sanic
103+
# checks as this is equivalent to `should_display_graphiql` and
104+
# `request_wants_html` methods.
105+
def is_graphiql(self, request):
106+
return all(
107+
[
108+
self.graphiql,
109+
request.method.lower() == "get",
110+
"raw" not in request.query,
111+
any(
112+
[
113+
"text/html" in request.headers.get("accept", {}),
114+
"*/*" in request.headers.get("accept", {}),
115+
]
116+
),
117+
]
118+
)
119+
120+
# TODO: Same stuff as above method.
121+
def is_pretty(self, request):
122+
return any(
123+
[self.pretty, self.is_graphiql(request), request.query.get("pretty")]
124+
)
125+
126+
async def __call__(self, request):
127+
try:
128+
data = await self.parse_body(request)
129+
request_method = request.method.lower()
130+
is_graphiql = self.is_graphiql(request)
131+
is_pretty = self.is_pretty(request)
132+
133+
# TODO: way better than if-else so better
134+
# implement this too on flask and sanic
135+
if request_method == "options":
136+
return self.process_preflight(request)
137+
138+
execution_results, all_params = run_http_query(
139+
self.schema,
140+
request_method,
141+
data,
142+
query_data=request.query,
143+
batch_enabled=self.batch,
144+
catch=is_graphiql,
145+
# Execute options
146+
run_sync=not self.enable_async,
147+
root_value=self.get_root_value(),
148+
context_value=self.get_context(request),
149+
middleware=self.get_middleware(),
150+
)
151+
152+
exec_res = (
153+
[await ex for ex in execution_results]
154+
if self.enable_async
155+
else execution_results
156+
)
157+
result, status_code = encode_execution_results(
158+
exec_res,
159+
is_batch=isinstance(data, list),
160+
format_error=self.format_error,
161+
encode=partial(self.encode, pretty=is_pretty), # noqa: ignore
162+
)
163+
164+
if is_graphiql:
165+
return await self.render_graphiql(params=all_params[0], result=result)
166+
167+
return web.Response(
168+
text=result, status=status_code, content_type="application/json",
169+
)
170+
171+
except HttpQueryError as err:
172+
parsed_error = GraphQLError(err.message)
173+
return web.Response(
174+
body=self.encode(dict(errors=[self.format_error(parsed_error)])),
175+
status=err.status_code,
176+
headers=err.headers,
177+
content_type="application/json",
178+
)
179+
180+
def process_preflight(self, request):
181+
"""
182+
Preflight request support for apollo-client
183+
https://www.w3.org/TR/cors/#resource-preflight-requests
184+
"""
185+
headers = request.headers
186+
origin = headers.get("Origin", "")
187+
method = headers.get("Access-Control-Request-Method", "").upper()
188+
189+
if method and method in self.accepted_methods:
190+
return web.Response(
191+
status=200,
192+
headers={
193+
"Access-Control-Allow-Origin": origin,
194+
"Access-Control-Allow-Methods": ", ".join(self.accepted_methods),
195+
"Access-Control-Max-Age": str(self.max_age),
196+
},
197+
)
198+
return web.Response(status=400)
199+
200+
@classmethod
201+
def attach(cls, app, *, route_path="/graphql", route_name="graphql", **kwargs):
202+
view = cls(**kwargs)
203+
app.router.add_route("*", route_path, _asyncify(view), name=route_name)
204+
205+
206+
def _asyncify(handler):
207+
"""Return an async version of the given handler.
208+
209+
This is mainly here because ``aiohttp`` can't infer the async definition of
210+
:py:meth:`.GraphQLView.__call__` and raises a :py:class:`DeprecationWarning`
211+
in tests. Wrapping it into an async function avoids the noisy warning.
212+
"""
213+
214+
async def _dispatch(request):
215+
return await handler(request)
216+
217+
return _dispatch

0 commit comments

Comments
 (0)