1+ import asyncio
12import grpc
23import json
34import logging
4- from typing import Callable , Optional , Tuple , Dict , AsyncIterable , Any , Set , Awaitable
5+ import threading
6+ from typing import (
7+ Callable ,
8+ Optional ,
9+ Tuple ,
10+ Dict ,
11+ AsyncIterable ,
12+ Any ,
13+ Set ,
14+ Awaitable ,
15+ )
516
617import brad .proto_gen .brad_pb2_grpc as brad_grpc
7- from brad .connection .schema import Schema
18+ from brad .config .file import ConfigFile
19+ from brad .connection .schema import Schema , DataType
820from brad .front_end .brad_interface import BradInterface
921from brad .front_end .grpc import BradGrpc
1022from brad .front_end .session import SessionManager , SessionId
1527logger = logging .getLogger (__name__ )
1628
1729
18- # (query_string, vdbe_id, session_id, debug_info) -> (rows, schema)
30+ # (query_string, vdbe_id, session_id, debug_info, retrieve_schema ) -> (rows, schema)
1931QueryHandler = Callable [
20- [str , int , SessionId , Dict [str , Any ]], Awaitable [Tuple [RowList , Optional [Schema ]]]
32+ [str , int , SessionId , Dict [str , Any ], bool ],
33+ Awaitable [Tuple [RowList , Optional [Schema ]]],
2134]
2235
2336
@@ -33,11 +46,30 @@ def __init__(
3346 vdbe_mgr : VdbeFrontEndManager ,
3447 session_mgr : SessionManager ,
3548 handler : QueryHandler ,
49+ config : ConfigFile ,
3650 ) -> None :
3751 self ._vdbe_mgr = vdbe_mgr
3852 self ._session_mgr = session_mgr
3953 self ._handler = handler
54+ self ._config = config
4055 self ._endpoints : Dict [int , Tuple [int , grpc .aio .Server , VdbeGrpcInterface ]] = {}
56+ self ._flight_sql_endpoints : Dict [int , Tuple [int , VdbeFlightSqlServer ]] = {}
57+
58+ try :
59+ # pylint: disable-next=import-error,no-name-in-module,unused-import
60+ import brad .native .pybind_brad_server as brad_server
61+
62+ self ._use_flight_sql = self ._config .flight_sql_mode () == "vdbe"
63+ except ImportError :
64+ self ._use_flight_sql = False
65+
66+ if self ._use_flight_sql :
67+ logger .info ("Will start Flight SQL endpoints for VDBEs." )
68+ else :
69+ logger .info (
70+ "Flight SQL endpoints for VDBEs are not available. "
71+ "Using gRPC endpoints only."
72+ )
4173
4274 async def initialize (self ) -> None :
4375 for engine in self ._vdbe_mgr .engines ():
@@ -66,22 +98,59 @@ async def add_vdbe_endpoint(self, port: int, vdbe_id: int) -> None:
6698 grpc_server .add_insecure_port (f"0.0.0.0:{ port } " )
6799 await grpc_server .start ()
68100 logger .info (
69- "Added VDBE endpoint for ID %d. Listening on port %d." , vdbe_id , port
101+ "Added gRPC VDBE endpoint for ID %d. Listening on port %d." , vdbe_id , port
70102 )
71103 self ._endpoints [vdbe_id ] = (port , grpc_server , query_service )
72104
105+ if self ._use_flight_sql :
106+ session_id , _ = await self ._session_mgr .create_new_session ()
107+ # The flight SQL port is offset by 10,000 from the gRPC port.
108+ flight_sql_port = port + 10_000
109+ flight_sql_server = VdbeFlightSqlServer (
110+ vdbe_id = vdbe_id ,
111+ port = flight_sql_port ,
112+ main_loop = asyncio .get_event_loop (),
113+ handler = self ._handler ,
114+ session_id = session_id ,
115+ )
116+ flight_sql_server .start ()
117+ self ._flight_sql_endpoints [vdbe_id ] = (flight_sql_port , flight_sql_server )
118+ logger .info (
119+ "Added Flight SQL VDBE endpoint for ID %d. Listening on port %d." ,
120+ vdbe_id ,
121+ flight_sql_port ,
122+ )
123+
73124 async def remove_vdbe_endpoint (self , vdbe_id : int ) -> None :
74125 try :
75126 port , grpc_server , query_service = self ._endpoints [vdbe_id ]
76127 await query_service .end_all_sessions ()
77128 # See `brad.front_end.BradFrontEnd.serve_forever`.
78129 grpc_server .__del__ ()
79130 del self ._endpoints [vdbe_id ]
80- logger .info ("Removed VDBE endpoint for ID %d (was port %d)." , vdbe_id , port )
131+ logger .info (
132+ "Removed gRPC VDBE endpoint for ID %d (was port %d)." , vdbe_id , port
133+ )
134+
135+ except KeyError :
136+ logger .error (
137+ "Tried to remove gRPC VDBE endpoint for ID %d, but it was not found." ,
138+ vdbe_id ,
139+ )
81140
141+ try :
142+ port , flight_sql_server = self ._flight_sql_endpoints [vdbe_id ]
143+ flight_sql_server .stop ()
144+ del self ._flight_sql_endpoints [vdbe_id ]
145+ await self ._session_mgr .end_session (flight_sql_server .session_id )
146+ logger .info (
147+ "Removed Flight SQL VDBE endpoint for ID %d (was port %d)." ,
148+ vdbe_id ,
149+ port ,
150+ )
82151 except KeyError :
83152 logger .error (
84- "Tried to remove VDBE endpoint for ID %d, but it was not found." ,
153+ "Tried to remove Flight SQL VDBE endpoint for ID %d, but it was not found." ,
85154 vdbe_id ,
86155 )
87156
@@ -144,7 +213,9 @@ async def run_query_json(
144213
145214 This method may throw an error to indicate a problem with the query.
146215 """
147- results , _ = await self ._handler (query , self ._vdbe_id , session_id , debug_info )
216+ results , _ = await self ._handler (
217+ query , self ._vdbe_id , session_id , debug_info , False
218+ )
148219 return json .dumps (results , cls = DecimalEncoder , default = str )
149220
150221 async def end_session (self , session_id : SessionId ) -> None :
@@ -156,3 +227,82 @@ async def end_all_sessions(self) -> None:
156227 self ._our_sessions .clear ()
157228 for session_id in our_sessions :
158229 await self ._session_mgr .end_session (session_id )
230+
231+
232+ class VdbeFlightSqlServer :
233+ def __init__ (
234+ self ,
235+ * ,
236+ vdbe_id : int ,
237+ port : int ,
238+ main_loop : asyncio .AbstractEventLoop ,
239+ handler : QueryHandler ,
240+ session_id : SessionId ,
241+ ) -> None :
242+ # pylint: disable-next=import-error,no-name-in-module
243+ import brad .native .pybind_brad_server as brad_server
244+
245+ # pylint: disable-next=c-extension-no-member
246+ self ._flight_sql_server = brad_server .BradFlightSqlServer ()
247+ self ._flight_sql_server .init ("0.0.0.0" , port , self ._handle_query )
248+ self ._thread = threading .Thread (
249+ name = f"FlightSqlServer-{ vdbe_id } " , target = self ._serve
250+ )
251+ self ._vdbe_id = vdbe_id
252+ self ._port = port
253+ # Important: The endpoint manager is responsible for creating and
254+ # terminating the session.
255+ self .session_id = session_id
256+
257+ self ._main_loop = main_loop
258+ self ._handler = handler
259+
260+ def start (self ) -> None :
261+ self ._thread .start ()
262+
263+ def stop (self ) -> None :
264+ logger .info (
265+ "BRAD FlightSQL server stopping (port %d, VDBE %d)..." ,
266+ self ._port ,
267+ self ._vdbe_id ,
268+ )
269+ self ._flight_sql_server .shutdown ()
270+ self ._thread .join ()
271+ logger .info (
272+ "BRAD FlightSQL server stopped (port %d, VDBE %d)." ,
273+ self ._port ,
274+ self ._vdbe_id ,
275+ )
276+
277+ def _serve (self ) -> None :
278+ self ._flight_sql_server .serve ()
279+
280+ def _handle_query (self , query : str ) -> Tuple [RowList , Schema ]:
281+ # This method is called from a separate thread. So it's very important
282+ # to schedule the handler on the main event loop thread.
283+ debug_info : Dict [str , Any ] = {}
284+ future = asyncio .run_coroutine_threadsafe ( # type: ignore
285+ self ._handler (query , self ._vdbe_id , self .session_id , debug_info , True ), # type: ignore
286+ self ._main_loop ,
287+ )
288+ row_result , schema = future .result ()
289+ assert schema is not None
290+
291+ # We need to do extra processing for decimal fields since our C++
292+ # backend expects them as strings.
293+ decimal_fields = []
294+ for idx , field in enumerate (schema .fields ):
295+ if field .data_type == DataType .Decimal :
296+ decimal_fields .append (idx )
297+
298+ if len (decimal_fields ) > 0 :
299+ new_rows = []
300+ for row in row_result :
301+ new_row = tuple (
302+ str (value ) if idx in decimal_fields else value
303+ for idx , value in enumerate (row )
304+ )
305+ new_rows .append (new_row )
306+ row_result = new_rows
307+
308+ return row_result , schema
0 commit comments