Skip to content

Commit ed70ebb

Browse files
authored
Expose Flight SQL endpoints on VDBEs, various improvements (#523)
1 parent 3b2825e commit ed70ebb

File tree

8 files changed

+199
-15
lines changed

8 files changed

+199
-15
lines changed

config/system_config_demo_s3.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# listen on successive ports (e.g., 6784, 6785, etc.).
77
front_end_interface: "0.0.0.0"
88
front_end_port: 6783
9-
num_front_ends: 3
9+
num_front_ends: 2
1010

1111
# If installed and enabled, BRAD will serve its UI from a webserver that listens
1212
# for connections on this network interface and port.
@@ -127,6 +127,7 @@ std_datasets:
127127
bootstrap_vdbe_path: config/vdbe_demo/imdb_etl_vdbes.json
128128
disable_query_logging: true
129129
vdbe_start_port: 10076
130+
flight_sql_mode: "vdbe"
130131

131132
aurora_max_query_factor: 4.0
132133
aurora_max_query_factor_replace: 10000.0

cpp/server/brad_server_simple.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
9292
columns.push_back(values);
9393

9494
} else if (field_type->Equals(
95-
arrow::decimal(/*precision=*/10, /*scale=*/2))) {
95+
arrow::decimal128(/*precision=*/10, /*scale=*/2))) {
9696
arrow::Decimal128Builder decimalbuilder(
97-
arrow::decimal(/*precision=*/10, /*scale=*/2));
97+
arrow::decimal128(/*precision=*/10, /*scale=*/2));
9898
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
9999
const std::optional<std::string> val =
100100
py::cast<std::optional<std::string>>(
@@ -149,6 +149,11 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
149149
std::shared_ptr<arrow::Array> values;
150150
ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish());
151151
columns.push_back(values);
152+
} else {
153+
std::cerr << "ERROR: Unsupported field type: " << field_type->ToString()
154+
<< std::endl;
155+
return arrow::Status::NotImplemented("Unsupported field type: ",
156+
field_type->ToString());
152157
}
153158
}
154159

src/brad/config/file.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,13 @@ def vdbe_start_port(self) -> int:
315315
return 9876 # Default
316316
return int(self._raw["vdbe_start_port"])
317317

318+
def flight_sql_mode(self) -> Optional[str]:
319+
try:
320+
return self._raw["flight_sql_mode"]
321+
except KeyError:
322+
# FlightSQL mode is not set.
323+
return None
324+
318325
def _extract_log_path(self, config_key: str) -> Optional[pathlib.Path]:
319326
if config_key not in self._raw:
320327
return None

src/brad/connection/factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ async def connect_to(
2929
return cls.connect_to_stub(config)
3030

3131
# HACK: Schema aliasing for convenience.
32-
if schema_name is not None and schema_name == "imdb_editable_100g":
32+
if schema_name is not None and (
33+
schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g"
34+
):
3335
schema_name = "imdb_extended_100g"
3436

3537
connection_details = config.get_connection_details(engine)
@@ -158,7 +160,9 @@ async def connect_to_sidecar(
158160
return cls.connect_to_stub(config)
159161

160162
# HACK: Schema aliasing for convenience.
161-
if schema_name is not None and schema_name == "imdb_editable_100g":
163+
if schema_name is not None and (
164+
schema_name == "imdb_editable_100g" or schema_name == "imdb_etl_100g"
165+
):
162166
schema_name = "imdb_extended_100g"
163167

164168
connection_details = config.get_sidecar_db_details()

src/brad/exec/cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pathlib
33
import readline
44
import time
5+
import pyodbc
56
from typing import List, Tuple
67
from tabulate import tabulate
78

@@ -80,6 +81,11 @@ def run_query(client: BradGrpcClient | BradFlightSqlClientOdbc, query: str) -> N
8081
print("Query resulted in an error:")
8182
print(ex.message())
8283
print()
84+
except pyodbc.Error as ex:
85+
print()
86+
print("Query resulted in an error:")
87+
print(repr(ex))
88+
print()
8389

8490

8591
class BradShell(cmd.Cmd):
@@ -145,7 +151,10 @@ def main(args) -> None:
145151
host, port = parse_endpoint(args.endpoint)
146152
print("BRAD Interactive Shell v{}".format(brad.__version__))
147153
print()
148-
print("Connecting to BRAD VDBE at {}:{}...".format(host, port))
154+
if args.use_odbc:
155+
print("Connecting to BRAD VDBE at {}:{} (using ODBC)...".format(host, port))
156+
else:
157+
print("Connecting to BRAD VDBE at {}:{}...".format(host, port))
149158

150159
def run_shell(client: BradGrpcClient | BradFlightSqlClientOdbc) -> None:
151160
print("Connected!")

src/brad/front_end/front_end.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def __init__(
8787
input_queue: mp.Queue,
8888
output_queue: mp.Queue,
8989
):
90-
if BradFrontEnd.native_server_is_supported():
90+
if (
91+
BradFrontEnd.native_server_is_supported()
92+
and config.flight_sql_mode() == "front_end"
93+
):
9194
from brad.front_end.flight_sql_server import BradFlightSqlServer
9295

9396
self._flight_sql_server: Optional[BradFlightSqlServer] = (
@@ -98,8 +101,12 @@ def __init__(
98101
)
99102
)
100103
self._flight_sql_server_session_id: Optional[SessionId] = None
104+
logger.info(
105+
"FlightSQL server is enabled for the front end. Will listen on port 31337."
106+
)
101107
else:
102108
self._flight_sql_server = None
109+
logger.info("FlightSQL server is disabled for the front end.")
103110

104111
self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None
105112

src/brad/front_end/vdbe/vdbe_endpoint_manager.py

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
import asyncio
12
import grpc
23
import json
34
import logging
4-
from typing import Callable, Optional, Tuple, Dict, AsyncIterable, Any, Set, Awaitable
5+
import threading
6+
from typing import (
7+
Callable,
8+
Optional,
9+
Tuple,
10+
Dict,
11+
AsyncIterable,
12+
Any,
13+
Set,
14+
Awaitable,
15+
)
516

617
import brad.proto_gen.brad_pb2_grpc as brad_grpc
7-
from brad.connection.schema import Schema
18+
from brad.config.file import ConfigFile
19+
from brad.connection.schema import Schema, DataType
820
from brad.front_end.brad_interface import BradInterface
921
from brad.front_end.grpc import BradGrpc
1022
from brad.front_end.session import SessionManager, SessionId
@@ -15,9 +27,10 @@
1527
logger = logging.getLogger(__name__)
1628

1729

18-
# (query_string, vdbe_id, session_id, debug_info) -> (rows, schema)
30+
# (query_string, vdbe_id, session_id, debug_info, retrieve_schema) -> (rows, schema)
1931
QueryHandler = Callable[
20-
[str, int, SessionId, Dict[str, Any]], Awaitable[Tuple[RowList, Optional[Schema]]]
32+
[str, int, SessionId, Dict[str, Any], bool],
33+
Awaitable[Tuple[RowList, Optional[Schema]]],
2134
]
2235

2336

@@ -33,11 +46,30 @@ def __init__(
3346
vdbe_mgr: VdbeFrontEndManager,
3447
session_mgr: SessionManager,
3548
handler: QueryHandler,
49+
config: ConfigFile,
3650
) -> None:
3751
self._vdbe_mgr = vdbe_mgr
3852
self._session_mgr = session_mgr
3953
self._handler = handler
54+
self._config = config
4055
self._endpoints: Dict[int, Tuple[int, grpc.aio.Server, VdbeGrpcInterface]] = {}
56+
self._flight_sql_endpoints: Dict[int, Tuple[int, VdbeFlightSqlServer]] = {}
57+
58+
try:
59+
# pylint: disable-next=import-error,no-name-in-module,unused-import
60+
import brad.native.pybind_brad_server as brad_server
61+
62+
self._use_flight_sql = self._config.flight_sql_mode() == "vdbe"
63+
except ImportError:
64+
self._use_flight_sql = False
65+
66+
if self._use_flight_sql:
67+
logger.info("Will start Flight SQL endpoints for VDBEs.")
68+
else:
69+
logger.info(
70+
"Flight SQL endpoints for VDBEs are not available. "
71+
"Using gRPC endpoints only."
72+
)
4173

4274
async def initialize(self) -> None:
4375
for engine in self._vdbe_mgr.engines():
@@ -66,22 +98,59 @@ async def add_vdbe_endpoint(self, port: int, vdbe_id: int) -> None:
6698
grpc_server.add_insecure_port(f"0.0.0.0:{port}")
6799
await grpc_server.start()
68100
logger.info(
69-
"Added VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port
101+
"Added gRPC VDBE endpoint for ID %d. Listening on port %d.", vdbe_id, port
70102
)
71103
self._endpoints[vdbe_id] = (port, grpc_server, query_service)
72104

105+
if self._use_flight_sql:
106+
session_id, _ = await self._session_mgr.create_new_session()
107+
# The Flight SQL port is offset by 10,000 from the gRPC port.
108+
flight_sql_port = port + 10_000
109+
flight_sql_server = VdbeFlightSqlServer(
110+
vdbe_id=vdbe_id,
111+
port=flight_sql_port,
112+
main_loop=asyncio.get_event_loop(),
113+
handler=self._handler,
114+
session_id=session_id,
115+
)
116+
flight_sql_server.start()
117+
self._flight_sql_endpoints[vdbe_id] = (flight_sql_port, flight_sql_server)
118+
logger.info(
119+
"Added Flight SQL VDBE endpoint for ID %d. Listening on port %d.",
120+
vdbe_id,
121+
flight_sql_port,
122+
)
123+
73124
async def remove_vdbe_endpoint(self, vdbe_id: int) -> None:
74125
try:
75126
port, grpc_server, query_service = self._endpoints[vdbe_id]
76127
await query_service.end_all_sessions()
77128
# See `brad.front_end.BradFrontEnd.serve_forever`.
78129
grpc_server.__del__()
79130
del self._endpoints[vdbe_id]
80-
logger.info("Removed VDBE endpoint for ID %d (was port %d).", vdbe_id, port)
131+
logger.info(
132+
"Removed gRPC VDBE endpoint for ID %d (was port %d).", vdbe_id, port
133+
)
134+
135+
except KeyError:
136+
logger.error(
137+
"Tried to remove gRPC VDBE endpoint for ID %d, but it was not found.",
138+
vdbe_id,
139+
)
81140

141+
try:
142+
port, flight_sql_server = self._flight_sql_endpoints[vdbe_id]
143+
flight_sql_server.stop()
144+
del self._flight_sql_endpoints[vdbe_id]
145+
await self._session_mgr.end_session(flight_sql_server.session_id)
146+
logger.info(
147+
"Removed Flight SQL VDBE endpoint for ID %d (was port %d).",
148+
vdbe_id,
149+
port,
150+
)
82151
except KeyError:
83152
logger.error(
84-
"Tried to remove VDBE endpoint for ID %d, but it was not found.",
153+
"Tried to remove Flight SQL VDBE endpoint for ID %d, but it was not found.",
85154
vdbe_id,
86155
)
87156

@@ -144,7 +213,9 @@ async def run_query_json(
144213
145214
This method may throw an error to indicate a problem with the query.
146215
"""
147-
results, _ = await self._handler(query, self._vdbe_id, session_id, debug_info)
216+
results, _ = await self._handler(
217+
query, self._vdbe_id, session_id, debug_info, False
218+
)
148219
return json.dumps(results, cls=DecimalEncoder, default=str)
149220

150221
async def end_session(self, session_id: SessionId) -> None:
@@ -156,3 +227,82 @@ async def end_all_sessions(self) -> None:
156227
self._our_sessions.clear()
157228
for session_id in our_sessions:
158229
await self._session_mgr.end_session(session_id)
230+
231+
232+
class VdbeFlightSqlServer:
233+
def __init__(
234+
self,
235+
*,
236+
vdbe_id: int,
237+
port: int,
238+
main_loop: asyncio.AbstractEventLoop,
239+
handler: QueryHandler,
240+
session_id: SessionId,
241+
) -> None:
242+
# pylint: disable-next=import-error,no-name-in-module
243+
import brad.native.pybind_brad_server as brad_server
244+
245+
# pylint: disable-next=c-extension-no-member
246+
self._flight_sql_server = brad_server.BradFlightSqlServer()
247+
self._flight_sql_server.init("0.0.0.0", port, self._handle_query)
248+
self._thread = threading.Thread(
249+
name=f"FlightSqlServer-{vdbe_id}", target=self._serve
250+
)
251+
self._vdbe_id = vdbe_id
252+
self._port = port
253+
# Important: The endpoint manager is responsible for creating and
254+
# terminating the session.
255+
self.session_id = session_id
256+
257+
self._main_loop = main_loop
258+
self._handler = handler
259+
260+
def start(self) -> None:
261+
self._thread.start()
262+
263+
def stop(self) -> None:
264+
logger.info(
265+
"BRAD FlightSQL server stopping (port %d, VDBE %d)...",
266+
self._port,
267+
self._vdbe_id,
268+
)
269+
self._flight_sql_server.shutdown()
270+
self._thread.join()
271+
logger.info(
272+
"BRAD FlightSQL server stopped (port %d, VDBE %d).",
273+
self._port,
274+
self._vdbe_id,
275+
)
276+
277+
def _serve(self) -> None:
278+
self._flight_sql_server.serve()
279+
280+
def _handle_query(self, query: str) -> Tuple[RowList, Schema]:
281+
# This method is called from a separate thread, so it is very important
282+
# to schedule the handler on the main event loop thread.
283+
debug_info: Dict[str, Any] = {}
284+
future = asyncio.run_coroutine_threadsafe( # type: ignore
285+
self._handler(query, self._vdbe_id, self.session_id, debug_info, True), # type: ignore
286+
self._main_loop,
287+
)
288+
row_result, schema = future.result()
289+
assert schema is not None
290+
291+
# We need to do extra processing for decimal fields since our C++
292+
# backend expects them as strings.
293+
decimal_fields = []
294+
for idx, field in enumerate(schema.fields):
295+
if field.data_type == DataType.Decimal:
296+
decimal_fields.append(idx)
297+
298+
if len(decimal_fields) > 0:
299+
new_rows = []
300+
for row in row_result:
301+
new_row = tuple(
302+
str(value) if idx in decimal_fields else value
303+
for idx, value in enumerate(row)
304+
)
305+
new_rows.append(new_row)
306+
row_result = new_rows
307+
308+
return row_result, schema

src/brad/front_end/vdbe/vdbe_front_end.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__(
130130
vdbe_mgr=self._vdbe_mgr,
131131
session_mgr=self._sessions,
132132
handler=self._run_query_impl,
133+
config=self._config,
133134
)
134135
self._shutdown_event = asyncio.Event()
135136

0 commit comments

Comments
 (0)