Skip to content

Commit 2361dfb

Browse files
committed
Incorporated comments
1 parent 8b63b9c commit 2361dfb

File tree

10 files changed

+254
-193
lines changed

10 files changed

+254
-193
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,25 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
4646
:type parsed_statement: ParsedStatement
4747
:param parsed_statement: parsed_statement based on the sql query
4848
"""
49-
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
49+
if connection.is_closed:
50+
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
51+
statement_type = parsed_statement.client_side_statement_type
52+
if statement_type == ClientSideStatementType.COMMIT:
5053
connection.commit()
5154
return None
52-
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
55+
if statement_type == ClientSideStatementType.BEGIN:
5356
connection.begin()
5457
return None
55-
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
58+
if statement_type == ClientSideStatementType.ROLLBACK:
5659
connection.rollback()
5760
return None
58-
if (
59-
parsed_statement.client_side_statement_type
60-
== ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
61-
):
62-
if connection.is_closed:
63-
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
61+
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
6462
return _get_streamed_result_set(
6563
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
6664
TypeCode.TIMESTAMP,
6765
connection._transaction.committed,
6866
)
69-
if (
70-
parsed_statement.client_side_statement_type
71-
== ClientSideStatementType.SHOW_READ_TIMESTAMP
72-
):
73-
if connection.is_closed:
74-
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
67+
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
7568
return _get_streamed_result_set(
7669
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
7770
TypeCode.TIMESTAMP,
@@ -85,5 +78,6 @@ def _get_streamed_result_set(column_name, type_code, column_value):
8578
)
8679

8780
result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
88-
result_set.values.extend([_make_value_pb(column_value)])
81+
if column_value is not None:
82+
result_set.values.extend([_make_value_pb(column_value)])
8983
return StreamedResultSet(iter([result_set]))

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
2525
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
2626
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
27-
r"^\s*(SHOW VARIABLE COMMIT_TIMESTAMP)", re.IGNORECASE
27+
r"^\s*(SHOW)\s+(VARIABLE)\s+(COMMIT_TIMESTAMP)", re.IGNORECASE
2828
)
2929
RE_SHOW_READ_TIMESTAMP = re.compile(
30-
r"^\s*(SHOW VARIABLE READ_TIMESTAMP)", re.IGNORECASE
30+
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
3131
)
3232

3333

google/cloud/spanner_dbapi/connection.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.cloud.spanner_v1 import RequestOptions
2424
from google.cloud.spanner_v1.session import _get_retry_delay
2525
from google.cloud.spanner_v1.snapshot import Snapshot
26+
from deprecated import deprecated
2627

2728
from google.cloud.spanner_dbapi.checksum import _compare_checksums
2829
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
@@ -35,7 +36,10 @@
3536

3637

3738
CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
38-
"This method is non-operational as transaction has not started"
39+
"This method is non-operational as transaction has not been started at " "client."
40+
)
41+
SPANNER_TRANSACTION_NOT_STARTED_WARNING = (
42+
"This method is non-operational as transaction has not been started at " "spanner."
3943
)
4044
MAX_INTERNAL_RETRIES = 50
4145

@@ -143,9 +147,10 @@ def database(self):
143147
return self._database
144148

145149
@property
150+
@deprecated(
151+
reason="This method is deprecated. Use spanner_transaction_started method"
152+
)
146153
def inside_transaction(self):
147-
"""Deprecated property which won't be supported in future versions.
148-
Please use spanner_transaction_started property instead."""
149154
return (
150155
self._transaction
151156
and not self._transaction.committed
@@ -268,7 +273,8 @@ def _release_session(self):
268273
"""
269274
if self.database is None:
270275
raise ValueError("Database needs to be passed for this operation")
271-
self.database._pool.put(self._session)
276+
if self._session is not None:
277+
self.database._pool.put(self._session)
272278
self._session = None
273279

274280
def retry_transaction(self):
@@ -310,7 +316,6 @@ def _rerun_previous_statements(self):
310316
status, res = transaction.batch_update(statements)
311317

312318
if status.code == ABORTED:
313-
self._spanner_transaction_started = False
314319
raise Aborted(status.details)
315320

316321
retried_checksum = ResultsChecksum()
@@ -373,6 +378,7 @@ def snapshot_checkout(self):
373378
self._snapshot = Snapshot(
374379
self._session_checkout(), multi_use=True, **self.staleness
375380
)
381+
self._snapshot.begin()
376382
self._spanner_transaction_started = True
377383

378384
return self._snapshot
@@ -398,7 +404,7 @@ def begin(self):
398404
399405
:raises: :class:`InterfaceError`: if this connection is closed.
400406
:raises: :class:`OperationalError`: if there is an existing transaction
401-
that has begin or is running
407+
that has been started
402408
"""
403409
if self._transaction_begin_marked:
404410
raise OperationalError("A transaction has already started")
@@ -422,37 +428,45 @@ def commit(self):
422428
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
423429
)
424430
return
431+
if not self._spanner_transaction_started:
432+
warnings.warn(
433+
SPANNER_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
434+
)
435+
return
425436

426437
self.run_prior_DDL_statements()
427-
if self._spanner_transaction_started:
428-
try:
429-
if not self._read_only:
430-
self._transaction.commit()
431-
432-
self._release_session()
433-
self._statements = []
434-
self._transaction_begin_marked = False
435-
self._spanner_transaction_started = False
436-
except Aborted:
437-
self.retry_transaction()
438-
self.commit()
438+
try:
439+
if not self._read_only:
440+
self._transaction.commit()
441+
except Aborted:
442+
self.retry_transaction()
443+
self.commit()
444+
finally:
445+
self._release_session()
446+
self._statements = []
447+
self._transaction_begin_marked = False
448+
self._spanner_transaction_started = False
439449

440450
def rollback(self):
441451
"""Rolls back any pending transaction.
442452
443453
This is a no-op if there is no active client transaction.
444454
"""
445-
446455
if not self._client_transaction_started:
447456
warnings.warn(
448457
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
449458
)
450459
return
460+
if not self._spanner_transaction_started:
461+
warnings.warn(
462+
SPANNER_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
463+
)
464+
return
451465

452-
if self._spanner_transaction_started:
466+
try:
453467
if not self._read_only:
454468
self._transaction.rollback()
455-
469+
finally:
456470
self._release_session()
457471
self._statements = []
458472
self._transaction_begin_marked = False

google/cloud/spanner_dbapi/cursor.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def close(self):
178178
"""Closes this cursor."""
179179
self._is_closed = True
180180

181-
def _do_execute_update(self, transaction, sql, params):
181+
def _do_execute_update_in_autocommit(self, transaction, sql, params):
182+
"""This function should only be used autocommit mode."""
182183
self.connection._transaction = transaction
183184
self._result_set = transaction.execute_sql(
184185
sql, params=params, param_types=get_param_types(params)
@@ -260,54 +261,54 @@ def execute(self, sql, args=None):
260261

261262
# For every other operation, we've got to ensure that
262263
# any prior DDL statements were run.
263-
# self._run_prior_DDL_statements()
264264
self.connection.run_prior_DDL_statements()
265-
266265
if parsed_statement.statement_type == StatementType.UPDATE:
267266
sql = parse_utils.ensure_where_clause(sql)
268-
269267
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
270268

271269
if self.connection._client_transaction_started:
272-
statement = Statement(
273-
sql,
274-
args,
275-
get_param_types(args or None),
276-
ResultsChecksum(),
277-
)
278-
279-
(
280-
self._result_set,
281-
self._checksum,
282-
) = self.connection.run_statement(statement)
283-
while True:
284-
try:
285-
self._itr = PeekIterator(self._result_set)
286-
break
287-
except Aborted:
288-
self.connection.retry_transaction()
289-
return
290-
291-
if parsed_statement.statement_type == StatementType.QUERY:
292-
self._handle_DQL(sql, args or None)
270+
self._execute_statement_in_non_autocommit_mode(sql, args)
293271
else:
294-
self.connection.database.run_in_transaction(
295-
self._do_execute_update,
296-
sql,
297-
args or None,
298-
)
272+
self._execute_statement_in_autocommit_mode(sql, args, parsed_statement)
273+
299274
except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
300-
self.close()
301-
self.connection.close()
302275
raise IntegrityError(getattr(e, "details", e)) from e
303276
except InvalidArgument as e:
304-
self.close()
305-
self.connection.close()
306277
raise ProgrammingError(getattr(e, "details", e)) from e
307278
except InternalServerError as e:
308-
self.close()
309-
self.connection.close()
310279
raise OperationalError(getattr(e, "details", e)) from e
280+
finally:
281+
if self.connection._client_transaction_started is False:
282+
self.connection._spanner_transaction_started = False
283+
284+
def _execute_statement_in_non_autocommit_mode(self, sql, args):
285+
statement = Statement(
286+
sql,
287+
args,
288+
get_param_types(args or None),
289+
ResultsChecksum(),
290+
)
291+
292+
(
293+
self._result_set,
294+
self._checksum,
295+
) = self.connection.run_statement(statement)
296+
while True:
297+
try:
298+
self._itr = PeekIterator(self._result_set)
299+
break
300+
except Aborted:
301+
self.connection.retry_transaction()
302+
303+
def _execute_statement_in_autocommit_mode(self, sql, args, parsed_statement):
304+
if parsed_statement.statement_type == StatementType.QUERY:
305+
self._handle_DQL(sql, args or None)
306+
else:
307+
self.connection.database.run_in_transaction(
308+
self._do_execute_update_in_autocommit,
309+
sql,
310+
args or None,
311+
)
311312

312313
@check_not_closed
313314
def executemany(self, operation, seq_of_params):
@@ -487,6 +488,10 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
487488
# Unfortunately, Spanner doesn't seem to send back
488489
# information about the number of rows available.
489490
self._row_count = _UNSET_COUNT
491+
if self._result_set.metadata.transaction.read_timestamp is not None:
492+
snapshot._transaction_read_timestamp = (
493+
self._result_set.metadata.transaction.read_timestamp
494+
)
490495

491496
def _handle_DQL(self, sql, params):
492497
if self.connection.database is None:

google/cloud/spanner_v1/snapshot.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,9 @@
1515
"""Model a set of read-only queries to a database as a snapshot."""
1616

1717
import functools
18-
import itertools
1918
import threading
2019
from google.protobuf.struct_pb2 import Struct
21-
from google.cloud.spanner_v1 import (
22-
ExecuteSqlRequest,
23-
PartialResultSet,
24-
ResultSetMetadata,
25-
)
20+
from google.cloud.spanner_v1 import ExecuteSqlRequest
2621
from google.cloud.spanner_v1 import ReadRequest
2722
from google.cloud.spanner_v1 import TransactionOptions
2823
from google.cloud.spanner_v1 import TransactionSelector
@@ -452,17 +447,11 @@ def execute_sql(
452447
if self._transaction_id is None:
453448
# lock is added to handle the inline begin for first rpc
454449
with self._lock:
455-
return self._get_streamed_result_set(
456-
restart, request, trace_attributes, False
457-
)
450+
return self._get_streamed_result_set(restart, request, trace_attributes)
458451
else:
459-
return self._get_streamed_result_set(
460-
restart, request, trace_attributes, True
461-
)
452+
return self._get_streamed_result_set(restart, request, trace_attributes)
462453

463-
def _get_streamed_result_set(
464-
self, restart, request, trace_attributes, transaction_id_set
465-
):
454+
def _get_streamed_result_set(self, restart, request, trace_attributes):
466455
iterator = _restart_on_unavailable(
467456
restart,
468457
request,
@@ -474,16 +463,6 @@ def _get_streamed_result_set(
474463
self._read_request_count += 1
475464
self._execute_sql_count += 1
476465

477-
if self._read_only and not transaction_id_set:
478-
peek = next(iterator)
479-
response_pb = PartialResultSet.pb(peek)
480-
response_metadata = ResultSetMetadata.wrap(response_pb.metadata)
481-
if response_metadata.transaction.read_timestamp is not None:
482-
self._transaction_read_timestamp = (
483-
response_metadata.transaction.read_timestamp
484-
)
485-
iterator = itertools.chain([peek], iterator)
486-
487466
if self._multi_use:
488467
return StreamedResultSet(iterator, source=self)
489468
else:

0 commit comments

Comments
 (0)