Skip to content

Commit 63daa8a

Browse files
authored
feat: Implementation for partitioned query in dbapi (#1067)
* feat: Implementation for partitioned query in dbapi
* Incorporated review comments and added more tests
* Small fix
* Test fix
* Removed the ClientSideStatementParamKey enum
* Incorporated further review comments
1 parent c4210b2 commit 63daa8a

File tree

11 files changed

+324
-34
lines changed

11 files changed

+324
-34
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
5050
:param parsed_statement: parsed_statement based on the sql query
5151
"""
5252
connection = cursor.connection
53+
column_values = []
5354
if connection.is_closed:
5455
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
5556
statement_type = parsed_statement.client_side_statement_type
@@ -63,24 +64,26 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
6364
connection.rollback()
6465
return None
6566
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
66-
if connection._transaction is None:
67-
committed_timestamp = None
68-
else:
69-
committed_timestamp = connection._transaction.committed
67+
if (
68+
connection._transaction is not None
69+
and connection._transaction.committed is not None
70+
):
71+
column_values.append(connection._transaction.committed)
7072
return _get_streamed_result_set(
7173
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
7274
TypeCode.TIMESTAMP,
73-
committed_timestamp,
75+
column_values,
7476
)
7577
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
76-
if connection._snapshot is None:
77-
read_timestamp = None
78-
else:
79-
read_timestamp = connection._snapshot._transaction_read_timestamp
78+
if (
79+
connection._snapshot is not None
80+
and connection._snapshot._transaction_read_timestamp is not None
81+
):
82+
column_values.append(connection._snapshot._transaction_read_timestamp)
8083
return _get_streamed_result_set(
8184
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
8285
TypeCode.TIMESTAMP,
83-
read_timestamp,
86+
column_values,
8487
)
8588
if statement_type == ClientSideStatementType.START_BATCH_DML:
8689
connection.start_batch_dml(cursor)
@@ -89,14 +92,28 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
8992
return connection.run_batch()
9093
if statement_type == ClientSideStatementType.ABORT_BATCH:
9194
return connection.abort_batch()
95+
if statement_type == ClientSideStatementType.PARTITION_QUERY:
96+
partition_ids = connection.partition_query(parsed_statement)
97+
return _get_streamed_result_set(
98+
"PARTITION",
99+
TypeCode.STRING,
100+
partition_ids,
101+
)
102+
if statement_type == ClientSideStatementType.RUN_PARTITION:
103+
return connection.run_partition(
104+
parsed_statement.client_side_statement_params[0]
105+
)
92106

93107

def _get_streamed_result_set(column_name, type_code, column_values):
    """Build a single-column, in-memory :class:`StreamedResultSet`.

    Used by client-side statements (SHOW ... TIMESTAMP, PARTITION) to
    return locally-computed values through the normal result-set path.

    :param column_name: name of the sole result column.
    :param type_code: Cloud Spanner ``TypeCode`` of the column.
    :param column_values: possibly-empty list of values; one row is
        emitted per value, and no rows when the list is empty.
    :returns: a :class:`StreamedResultSet` over a single partial result.
    """
    struct_type_pb = StructType(
        fields=[StructType.Field(name=column_name, type_=Type(code=type_code))]
    )

    result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
    if column_values:
        # Convert each Python value to its protobuf representation in one pass.
        result_set.values.extend(_make_value_pb(value) for value in column_values)
    return StreamedResultSet(iter([result_set]))

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
3434
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
3535
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
36+
RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
37+
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)
3638

3739

3840
def parse_stmt(query):
@@ -48,6 +50,7 @@ def parse_stmt(query):
4850
:returns: ParsedStatement object.
4951
"""
5052
client_side_statement_type = None
53+
client_side_statement_params = []
5154
if RE_COMMIT.match(query):
5255
client_side_statement_type = ClientSideStatementType.COMMIT
5356
if RE_BEGIN.match(query):
@@ -64,8 +67,19 @@ def parse_stmt(query):
6467
client_side_statement_type = ClientSideStatementType.RUN_BATCH
6568
if RE_ABORT_BATCH.match(query):
6669
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
70+
if RE_PARTITION_QUERY.match(query):
71+
match = re.search(RE_PARTITION_QUERY, query)
72+
client_side_statement_params.append(match.group(2))
73+
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
74+
if RE_RUN_PARTITION.match(query):
75+
match = re.search(RE_RUN_PARTITION, query)
76+
client_side_statement_params.append(match.group(3))
77+
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
6778
if client_side_statement_type is not None:
6879
return ParsedStatement(
69-
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
80+
StatementType.CLIENT_SIDE,
81+
Statement(query),
82+
client_side_statement_type,
83+
client_side_statement_params,
7084
)
7185
return None

google/cloud/spanner_dbapi/connection.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,15 @@
1919
from google.api_core.exceptions import Aborted
2020
from google.api_core.gapic_v1.client_info import ClientInfo
2121
from google.cloud import spanner_v1 as spanner
22+
from google.cloud.spanner_dbapi import partition_helper
2223
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
23-
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
24+
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
25+
from google.cloud.spanner_dbapi.parsed_statement import (
26+
ParsedStatement,
27+
Statement,
28+
StatementType,
29+
)
30+
from google.cloud.spanner_dbapi.partition_helper import PartitionId
2431
from google.cloud.spanner_v1 import RequestOptions
2532
from google.cloud.spanner_v1.session import _get_retry_delay
2633
from google.cloud.spanner_v1.snapshot import Snapshot
@@ -585,6 +592,54 @@ def abort_batch(self):
585592
self._batch_dml_executor = None
586593
self._batch_mode = BatchMode.NONE
587594

595+
@check_not_closed
def partition_query(
    self,
    parsed_statement: ParsedStatement,
    query_options=None,
):
    """Partition a query for parallel execution (PARTITION <sql>).

    :param parsed_statement: the parsed ``PARTITION`` client-side
        statement; the SQL to partition is its first
        ``client_side_statement_params`` entry.
    :param query_options: optional query options forwarded to
        ``generate_query_batches``.
    :returns: list of encoded partition id strings, one per partition,
        each consumable by :meth:`run_partition`.
    :raises ProgrammingError: if the partitioned statement is not a
        query, or if the connection is neither read-only nor free of a
        started read-write transaction.
    """
    statement = parsed_statement.statement
    partitioned_query = parsed_statement.client_side_statement_params[0]
    if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
        raise ProgrammingError(
            "Only queries can be partitioned. Invalid statement: " + statement.sql
        )
    if self.read_only is not True and self._client_transaction_started is True:
        raise ProgrammingError(
            "Partitioned query not supported as the connection is not in "
            "read only mode or ReadWrite transaction started"
        )

    batch_snapshot = self._database.batch_snapshot()
    partitions = list(
        batch_snapshot.generate_query_batches(
            partitioned_query,
            statement.params,
            statement.param_types,
            query_options=query_options,
        )
    )
    # The transaction id belongs to the snapshot, not to an individual
    # partition, so fetch it once instead of once per partition.
    batch_transaction_id = batch_snapshot.get_batch_transaction_id()
    return [
        partition_helper.encode_to_string(batch_transaction_id, partition)
        for partition in partitions
    ]
629+
630+
@check_not_closed
def run_partition(self, batch_transaction_id):
    """Execute one partition previously produced by :meth:`partition_query`.

    :param batch_transaction_id: encoded partition id string as returned
        by :meth:`partition_query`.
    :returns: the result of processing the decoded partition.
    """
    decoded: PartitionId = partition_helper.decode_from_string(
        batch_transaction_id
    )
    # Re-attach to the batch transaction that produced the partition.
    txn_id = decoded.batch_transaction_id
    snapshot = self._database.batch_snapshot(
        read_timestamp=txn_id.read_timestamp,
        session_id=txn_id.session_id,
        transaction_id=txn_id.transaction_id,
    )
    return snapshot.process(decoded.partition_result)
642+
588643
def __enter__(self):
589644
return self
590645

google/cloud/spanner_dbapi/parse_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,23 @@ def classify_statement(query, args=None):
232232
get_param_types(args or None),
233233
ResultsChecksum(),
234234
)
235-
if RE_DDL.match(query):
236-
return ParsedStatement(StatementType.DDL, statement)
235+
statement_type = _get_statement_type(statement)
236+
return ParsedStatement(statement_type, statement)
237237

238-
if RE_IS_INSERT.match(query):
239-
return ParsedStatement(StatementType.INSERT, statement)
240238

239+
def _get_statement_type(statement):
    """Classify *statement* as DDL, INSERT, QUERY or UPDATE.

    NOTE: when the statement falls through to UPDATE, this mutates
    ``statement.sql`` in place by ensuring it carries a WHERE clause.
    """
    sql = statement.sql
    if RE_DDL.match(sql):
        return StatementType.DDL
    elif RE_IS_INSERT.match(sql):
        return StatementType.INSERT
    elif RE_NON_UPDATE.match(sql) or RE_WITH.match(sql):
        # As of 13-March-2020, Cloud Spanner only supports WITH for DQL
        # statements and doesn't yet support WITH for DML statements.
        return StatementType.QUERY

    statement.sql = ensure_where_clause(sql)
    return StatementType.UPDATE
248252

249253

250254
def sql_pyformat_args_to_spanner(sql, params):

google/cloud/spanner_dbapi/parsed_statement.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 20203 Google LLC All rights reserved.
1+
# Copyright 2023 Google LLC All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from dataclasses import dataclass
1515
from enum import Enum
16-
from typing import Any
16+
from typing import Any, List
1717

1818
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
1919

@@ -35,6 +35,8 @@ class ClientSideStatementType(Enum):
3535
START_BATCH_DML = 6
3636
RUN_BATCH = 7
3737
ABORT_BATCH = 8
38+
PARTITION_QUERY = 9
39+
RUN_PARTITION = 10
3840

3941

4042
@dataclass
@@ -53,3 +55,4 @@ class ParsedStatement:
5355
statement_type: StatementType
5456
statement: Statement
5557
client_side_statement_type: ClientSideStatementType = None
58+
client_side_statement_params: List[Any] = None
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2023 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import Any
17+
18+
import gzip
19+
import pickle
20+
import base64
21+
22+
23+
@dataclass
class BatchTransactionId:
    """Identifies the batch-snapshot transaction a partition belongs to."""

    transaction_id: str
    session_id: str
    # Read timestamp of the snapshot; presumably a Spanner timestamp
    # object — TODO confirm exact type against batch_snapshot().
    read_timestamp: Any


@dataclass
class PartitionId:
    """A single partition paired with its originating transaction."""

    batch_transaction_id: BatchTransactionId
    partition_result: Any


def encode_to_string(batch_transaction_id, partition_result):
    """Serialize a partition into a transport-friendly string.

    Wraps the arguments in a :class:`PartitionId`, pickles it, gzips the
    pickle, and base64-encodes the result.

    :returns: an ASCII-safe string decodable by :func:`decode_from_string`.
    """
    partition_id = PartitionId(batch_transaction_id, partition_result)
    partition_id_bytes = pickle.dumps(partition_id)
    gzip_bytes = gzip.compress(partition_id_bytes)
    return str(base64.b64encode(gzip_bytes), "utf-8")


def decode_from_string(encoded_partition_id):
    """Inverse of :func:`encode_to_string`.

    :param encoded_partition_id: base64 string wrapping a gzipped pickle.
    :returns: the unpickled :class:`PartitionId`.

    SECURITY NOTE: ``pickle.loads`` can execute arbitrary code if the
    input is attacker-controlled; only pass strings produced by
    :func:`encode_to_string` from a trusted source.
    """
    gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
    partition_id_bytes = gzip.decompress(gzip_bytes)
    return pickle.loads(partition_id_bytes)

0 commit comments

Comments
 (0)