Skip to content

Commit 124c2d9

Browse files
authored
Merge pull request #27 from graphql-python/features/graphql-server
Use graphql_server package for reusability
2 parents 724695a + 12c3e30 commit 124c2d9

File tree

5 files changed

+305
-205
lines changed

5 files changed

+305
-205
lines changed

flask_graphql/graphqlview.py

+60-180
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
1-
import json
1+
from functools import partial
22

3-
import six
43
from flask import Response, request
54
from flask.views import View
6-
from werkzeug.exceptions import BadRequest, MethodNotAllowed
75

8-
from graphql import Source, execute, parse, validate
9-
from graphql.error import format_error as format_graphql_error
10-
from graphql.error import GraphQLError
11-
from graphql.execution import ExecutionResult
126
from graphql.type.schema import GraphQLSchema
13-
from graphql.utils.get_operation_ast import get_operation_ast
7+
from graphql_server import (HttpQueryError, default_format_error,
8+
encode_execution_results, json_encode,
9+
load_json_body, run_http_query)
1410

1511
from .render_graphiql import render_graphiql
1612

1713

18-
class HttpError(Exception):
19-
def __init__(self, response, message=None, *args, **kwargs):
20-
self.response = response
21-
self.message = message = message or response.description
22-
super(HttpError, self).__init__(message, *args, **kwargs)
23-
24-
2514
class GraphQLView(View):
2615
schema = None
2716
executor = None
@@ -42,52 +31,68 @@ def __init__(self, **kwargs):
4231
if hasattr(self, key):
4332
setattr(self, key, value)
4433

45-
assert not all((self.graphiql, self.batch)), 'Use either graphiql or batch processing'
4634
assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
4735

4836
# noinspection PyUnusedLocal
49-
def get_root_value(self, request):
37+
def get_root_value(self):
5038
return self.root_value
5139

52-
def get_context(self, request):
40+
def get_context(self):
5341
if self.context is not None:
5442
return self.context
5543
return request
5644

57-
def get_middleware(self, request):
45+
def get_middleware(self):
5846
return self.middleware
5947

60-
def get_executor(self, request):
48+
def get_executor(self):
6149
return self.executor
6250

63-
def render_graphiql(self, **kwargs):
51+
def render_graphiql(self, params, result):
6452
return render_graphiql(
53+
params=params,
54+
result=result,
6555
graphiql_version=self.graphiql_version,
6656
graphiql_template=self.graphiql_template,
67-
**kwargs
6857
)
6958

59+
format_error = staticmethod(default_format_error)
60+
encode = staticmethod(json_encode)
61+
7062
def dispatch_request(self):
7163
try:
72-
if request.method.lower() not in ('get', 'post'):
73-
raise HttpError(MethodNotAllowed(['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))
74-
75-
data = self.parse_body(request)
76-
show_graphiql = self.graphiql and self.can_display_graphiql(data)
77-
78-
if self.batch:
79-
responses = [self.get_response(request, entry) for entry in data]
80-
result = '[{}]'.format(','.join([response[0] for response in responses]))
81-
status_code = max(responses, key=lambda response: response[1])[1]
82-
else:
83-
result, status_code = self.get_response(request, data, show_graphiql)
64+
request_method = request.method.lower()
65+
data = self.parse_body()
66+
67+
show_graphiql = request_method == 'get' and self.should_display_graphiql()
68+
catch = HttpQueryError if show_graphiql else None
69+
70+
pretty = self.pretty or show_graphiql or request.args.get('pretty')
71+
72+
execution_results, all_params = run_http_query(
73+
self.schema,
74+
request_method,
75+
data,
76+
query_data=request.args,
77+
batch_enabled=self.batch,
78+
catch=catch,
79+
80+
# Execute options
81+
root_value=self.get_root_value(),
82+
context_value=self.get_context(),
83+
middleware=self.get_middleware(),
84+
executor=self.get_executor(),
85+
)
86+
result, status_code = encode_execution_results(
87+
execution_results,
88+
is_batch=isinstance(data, list),
89+
format_error=self.format_error,
90+
encode=partial(self.encode, pretty=pretty)
91+
)
8492

8593
if show_graphiql:
86-
query, variables, operation_name, id = self.get_graphql_params(request, data)
8794
return self.render_graphiql(
88-
query=query,
89-
variables=variables,
90-
operation_name=operation_name,
95+
params=all_params[0],
9196
result=result
9297
)
9398

@@ -97,167 +102,42 @@ def dispatch_request(self):
97102
content_type='application/json'
98103
)
99104

100-
except HttpError as e:
105+
except HttpQueryError as e:
101106
return Response(
102-
self.json_encode(request, {
107+
self.encode({
103108
'errors': [self.format_error(e)]
104109
}),
105-
status=e.response.code,
106-
headers={'Allow': ['GET, POST']},
110+
status=e.status_code,
111+
headers=e.headers,
107112
content_type='application/json'
108113
)
109114

110-
def get_response(self, request, data, show_graphiql=False):
111-
query, variables, operation_name, id = self.get_graphql_params(request, data)
112-
113-
execution_result = self.execute_graphql_request(
114-
data,
115-
query,
116-
variables,
117-
operation_name,
118-
show_graphiql
119-
)
120-
121-
status_code = 200
122-
if execution_result:
123-
response = {}
124-
125-
if execution_result.errors:
126-
response['errors'] = [self.format_error(e) for e in execution_result.errors]
127-
128-
if execution_result.invalid:
129-
status_code = 400
130-
else:
131-
status_code = 200
132-
response['data'] = execution_result.data
133-
134-
if self.batch:
135-
response = {
136-
'id': id,
137-
'payload': response,
138-
'status': status_code,
139-
}
140-
141-
result = self.json_encode(request, response, show_graphiql)
142-
else:
143-
result = None
144-
145-
return result, status_code
146-
147-
def json_encode(self, request, d, show_graphiql=False):
148-
pretty = self.pretty or show_graphiql or request.args.get('pretty')
149-
if not pretty:
150-
return json.dumps(d, separators=(',', ':'))
151-
152-
return json.dumps(d, sort_keys=True,
153-
indent=2, separators=(',', ': '))
154-
115+
# Flask
155116
# noinspection PyBroadException
156-
def parse_body(self, request):
157-
content_type = self.get_content_type(request)
117+
def parse_body(self):
118+
# We use mimetype here since we don't need the other
119+
# information provided by content_type
120+
content_type = request.mimetype
158121
if content_type == 'application/graphql':
159-
return {'query': request.data.decode()}
122+
return {'query': request.data.decode('utf8')}
160123

161124
elif content_type == 'application/json':
162-
try:
163-
request_json = json.loads(request.data.decode('utf8'))
164-
if self.batch:
165-
assert isinstance(request_json, list)
166-
else:
167-
assert isinstance(request_json, dict)
168-
return request_json
169-
except:
170-
raise HttpError(BadRequest('POST body sent invalid JSON.'))
125+
return load_json_body(request.data.decode('utf8'))
171126

172-
elif content_type == 'application/x-www-form-urlencoded':
173-
return request.form
174-
175-
elif content_type == 'multipart/form-data':
127+
elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'):
176128
return request.form
177129

178130
return {}
179131

180-
def execute(self, *args, **kwargs):
181-
return execute(self.schema, *args, **kwargs)
182-
183-
def execute_graphql_request(self, data, query, variables, operation_name, show_graphiql=False):
184-
if not query:
185-
if show_graphiql:
186-
return None
187-
raise HttpError(BadRequest('Must provide query string.'))
188-
189-
try:
190-
source = Source(query, name='GraphQL request')
191-
ast = parse(source)
192-
validation_errors = validate(self.schema, ast)
193-
if validation_errors:
194-
return ExecutionResult(
195-
errors=validation_errors,
196-
invalid=True,
197-
)
198-
except Exception as e:
199-
return ExecutionResult(errors=[e], invalid=True)
200-
201-
if request.method.lower() == 'get':
202-
operation_ast = get_operation_ast(ast, operation_name)
203-
if operation_ast and operation_ast.operation != 'query':
204-
if show_graphiql:
205-
return None
206-
raise HttpError(MethodNotAllowed(
207-
['POST'], 'Can only perform a {} operation from a POST request.'.format(operation_ast.operation)
208-
))
209-
210-
try:
211-
return self.execute(
212-
ast,
213-
root_value=self.get_root_value(request),
214-
variable_values=variables or {},
215-
operation_name=operation_name,
216-
context_value=self.get_context(request),
217-
middleware=self.get_middleware(request),
218-
executor=self.get_executor(request)
219-
)
220-
except Exception as e:
221-
return ExecutionResult(errors=[e], invalid=True)
132+
def should_display_graphiql(self):
133+
if not self.graphiql or 'raw' in request.args:
134+
return False
222135

223-
@classmethod
224-
def can_display_graphiql(cls, data):
225-
raw = 'raw' in request.args or 'raw' in data
226-
return not raw and cls.request_wants_html(request)
136+
return self.request_wants_html()
227137

228-
@classmethod
229-
def request_wants_html(cls, request):
138+
def request_wants_html(self):
230139
best = request.accept_mimetypes \
231140
.best_match(['application/json', 'text/html'])
232141
return best == 'text/html' and \
233142
request.accept_mimetypes[best] > \
234143
request.accept_mimetypes['application/json']
235-
236-
@staticmethod
237-
def get_graphql_params(request, data):
238-
query = request.args.get('query') or data.get('query')
239-
variables = request.args.get('variables') or data.get('variables')
240-
id = request.args.get('id') or data.get('id')
241-
242-
if variables and isinstance(variables, six.text_type):
243-
try:
244-
variables = json.loads(variables)
245-
except:
246-
raise HttpError(BadRequest('Variables are invalid JSON.'))
247-
248-
operation_name = request.args.get('operationName') or data.get('operationName')
249-
250-
return query, variables, operation_name, id
251-
252-
@staticmethod
253-
def format_error(error):
254-
if isinstance(error, GraphQLError):
255-
return format_graphql_error(error)
256-
257-
return {'message': six.text_type(error)}
258-
259-
@staticmethod
260-
def get_content_type(request):
261-
# We use mimetype here since we don't need the other
262-
# information provided by content_type
263-
return request.mimetype

flask_graphql/render_graphiql.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from flask import render_template_string
22

3-
43
GRAPHIQL_VERSION = '0.7.1'
54

65
TEMPLATE = '''<!--
@@ -112,10 +111,10 @@
112111
onEditQuery: onEditQuery,
113112
onEditVariables: onEditVariables,
114113
onEditOperationName: onEditOperationName,
115-
query: {{ query|tojson }},
114+
query: {{ params.query|tojson }},
116115
response: {{ result|tojson }},
117-
variables: {{ variables|tojson }},
118-
operationName: {{ operation_name|tojson }},
116+
variables: {{ params.variables|tojson }},
117+
operationName: {{ params.operation_name|tojson }},
119118
}),
120119
document.body
121120
);
@@ -124,9 +123,13 @@
124123
</html>'''
125124

126125

127-
def render_graphiql(graphiql_version=None, graphiql_template=None, **kwargs):
126+
def render_graphiql(params, result, graphiql_version=None, graphiql_template=None):
128127
graphiql_version = graphiql_version or GRAPHIQL_VERSION
129128
template = graphiql_template or TEMPLATE
130129

131130
return render_template_string(
132-
template, graphiql_version=graphiql_version, **kwargs)
131+
template,
132+
graphiql_version=graphiql_version,
133+
result=result,
134+
params=params
135+
)

0 commit comments

Comments
 (0)