1
1
from functools import partial
2
+ from typing import List
2
3
3
- from flask import Response , request
4
+ from flask import Response , render_template_string , request
4
5
from flask .views import View
5
6
from graphql .error import GraphQLError
6
7
from graphql .type .schema import GraphQLSchema
7
8
8
9
from graphql_server import (
10
+ GraphQLParams ,
9
11
HttpQueryError ,
10
12
encode_execution_results ,
11
13
format_error_default ,
12
14
json_encode ,
13
15
load_json_body ,
14
16
run_http_query ,
15
17
)
16
-
17
- from .render_graphiql import render_graphiql
18
+ from graphql_server .render_graphiql import (
19
+ GraphiQLConfig ,
20
+ GraphiQLData ,
21
+ render_graphiql_sync ,
22
+ )
18
23
19
24
20
25
class GraphQLView (View ):
@@ -27,6 +32,8 @@ class GraphQLView(View):
27
32
graphiql_html_title = None
28
33
middleware = None
29
34
batch = False
35
+ subscriptions = None
36
+ headers = None
30
37
31
38
methods = ["GET" , "POST" , "PUT" , "DELETE" ]
32
39
@@ -50,15 +57,6 @@ def get_context_value(self):
50
57
def get_middleware (self ):
51
58
return self .middleware
52
59
53
- def render_graphiql (self , params , result ):
54
- return render_graphiql (
55
- params = params ,
56
- result = result ,
57
- graphiql_version = self .graphiql_version ,
58
- graphiql_template = self .graphiql_template ,
59
- graphiql_html_title = self .graphiql_html_title ,
60
- )
61
-
62
60
format_error = staticmethod (format_error_default )
63
61
encode = staticmethod (json_encode )
64
62
@@ -72,6 +70,7 @@ def dispatch_request(self):
72
70
73
71
pretty = self .pretty or show_graphiql or request .args .get ("pretty" )
74
72
73
+ all_params : List [GraphQLParams ]
75
74
execution_results , all_params = run_http_query (
76
75
self .schema ,
77
76
request_method ,
@@ -88,11 +87,28 @@ def dispatch_request(self):
88
87
execution_results ,
89
88
is_batch = isinstance (data , list ),
90
89
format_error = self .format_error ,
91
- encode = partial (self .encode , pretty = pretty ),
90
+ encode = partial (self .encode , pretty = pretty ), # noqa
92
91
)
93
92
94
93
if show_graphiql :
95
- return self .render_graphiql (params = all_params [0 ], result = result )
94
+ graphiql_data = GraphiQLData (
95
+ result = result ,
96
+ query = getattr (all_params [0 ], "query" ),
97
+ variables = getattr (all_params [0 ], "variables" ),
98
+ operation_name = getattr (all_params [0 ], "operation_name" ),
99
+ subscription_url = self .subscriptions ,
100
+ headers = self .headers ,
101
+ )
102
+ graphiql_config = GraphiQLConfig (
103
+ graphiql_version = self .graphiql_version ,
104
+ graphiql_template = self .graphiql_template ,
105
+ graphiql_html_title = self .graphiql_html_title ,
106
+ jinja_env = None ,
107
+ )
108
+ source = render_graphiql_sync (
109
+ data = graphiql_data , config = graphiql_config
110
+ )
111
+ return render_template_string (source )
96
112
97
113
return Response (result , status = status_code , content_type = "application/json" )
98
114
0 commit comments