Skip to content

Commit e533567

Browse files
committed
feat: Implementation for Begin and Rollback clientside statements
1 parent 5fb5610 commit e533567

File tree

8 files changed

+200
-39
lines changed

8 files changed

+200
-39
lines changed

google/cloud/spanner_dbapi/client_side_statement_executor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,15 @@ def execute(connection, parsed_statement: ParsedStatement):
2222
2323
It is an internal method that can make backwards-incompatible changes.
2424
25+
:type connection: Connection
26+
:param connection: Connection object of the dbApi
27+
2528
:type parsed_statement: ParsedStatement
2629
:param parsed_statement: parsed_statement based on the sql query
2730
"""
2831
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
2932
return connection.commit()
33+
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
34+
return connection.begin()
35+
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
36+
return connection.rollback()

google/cloud/spanner_dbapi/client_side_statement_parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
ClientSideStatementType,
2121
)
2222

23+
RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
2324
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
25+
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
2426

2527

2628
def parse_stmt(query):
@@ -39,4 +41,12 @@ def parse_stmt(query):
3941
return ParsedStatement(
4042
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
4143
)
44+
if RE_BEGIN.match(query):
45+
return ParsedStatement(
46+
StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN
47+
)
48+
if RE_ROLLBACK.match(query):
49+
return ParsedStatement(
50+
StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK
51+
)
4252
return None

google/cloud/spanner_dbapi/connection.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from google.rpc.code_pb2 import ABORTED
3535

3636

37-
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
37+
TRANSACTION_NOT_BEGUN_WARNING = (
38+
"This method is non-operational as transaction has not begun"
39+
)
3840
MAX_INTERNAL_RETRIES = 50
3941

4042

@@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False):
104106
self._read_only = read_only
105107
self._staleness = None
106108
self.request_priority = None
109+
self._transaction_begin_marked = False
107110

108111
@property
109112
def autocommit(self):
@@ -141,14 +144,23 @@ def inside_transaction(self):
141144
"""Flag: transaction is started.
142145
143146
Returns:
144-
bool: True if transaction begun, False otherwise.
147+
bool: True if transaction started, False otherwise.
145148
"""
146149
return (
147150
self._transaction
148151
and not self._transaction.committed
149152
and not self._transaction.rolled_back
150153
)
151154

155+
@property
156+
def transaction_begun(self):
157+
"""Flag: transaction has begun
158+
159+
Returns:
160+
bool: True if transaction begun, False otherwise.
161+
"""
162+
return (not self._autocommit) or self._transaction_begin_marked
163+
152164
@property
153165
def instance(self):
154166
"""Instance to which this connection relates.
@@ -333,12 +345,10 @@ def transaction_checkout(self):
333345
Begin a new transaction, if there is no transaction in
334346
this connection yet. Return the begun one otherwise.
335347
336-
The method is non operational in autocommit mode.
337-
338348
:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
339349
:returns: A Cloud Spanner transaction object, ready to use.
340350
"""
341-
if not self.autocommit:
351+
if self.transaction_begun:
342352
if not self.inside_transaction:
343353
self._transaction = self._session_checkout().transaction()
344354
self._transaction.begin()
@@ -354,7 +364,7 @@ def snapshot_checkout(self):
354364
:rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
355365
:returns: A Cloud Spanner snapshot object, ready to use.
356366
"""
357-
if self.read_only and not self.autocommit:
367+
if self.read_only and self.transaction_begun:
358368
if not self._snapshot:
359369
self._snapshot = Snapshot(
360370
self._session_checkout(), multi_use=True, **self.staleness
@@ -377,6 +387,22 @@ def close(self):
377387

378388
self.is_closed = True
379389

390+
@check_not_closed
391+
def begin(self):
392+
"""
393+
Marks the transaction as started.
394+
395+
:raises: :class:`InterfaceError`: if this connection is closed.
396+
:raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running
397+
"""
398+
if self._transaction_begin_marked:
399+
raise OperationalError("A transaction has already begun")
400+
if self.inside_transaction:
401+
raise OperationalError(
402+
"Beginning a new transaction is not allowed when a transaction is already running"
403+
)
404+
self._transaction_begin_marked = True
405+
380406
def commit(self):
381407
"""Commits any pending transaction to the database.
382408
@@ -386,8 +412,8 @@ def commit(self):
386412
raise ValueError("Database needs to be passed for this operation")
387413
self._snapshot = None
388414

389-
if self._autocommit:
390-
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
415+
if not self.transaction_begun:
416+
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
391417
return
392418

393419
self.run_prior_DDL_statements()
@@ -398,6 +424,7 @@ def commit(self):
398424

399425
self._release_session()
400426
self._statements = []
427+
self._transaction_begin_marked = False
401428
except Aborted:
402429
self.retry_transaction()
403430
self.commit()
@@ -410,14 +437,15 @@ def rollback(self):
410437
"""
411438
self._snapshot = None
412439

413-
if self._autocommit:
414-
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
440+
if not self.transaction_begun:
441+
warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2)
415442
elif self._transaction:
416443
if not self.read_only:
417444
self._transaction.rollback()
418445

419446
self._release_session()
420447
self._statements = []
448+
self._transaction_begin_marked = False
421449

422450
@check_not_closed
423451
def cursor(self):

google/cloud/spanner_dbapi/cursor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def execute(self, sql, args=None):
250250
)
251251
if parsed_statement.statement_type == StatementType.DDL:
252252
self._batch_DDLs(sql)
253-
if self.connection.autocommit:
253+
if not self.connection.transaction_begun:
254254
self.connection.run_prior_DDL_statements()
255255
return
256256

@@ -264,7 +264,7 @@ def execute(self, sql, args=None):
264264

265265
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
266266

267-
if not self.connection.autocommit:
267+
if self.connection.transaction_begun:
268268
statement = Statement(
269269
sql,
270270
args,
@@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params):
348348
)
349349
statements.append((sql, params, get_param_types(params)))
350350

351-
if self.connection.autocommit:
351+
if self.connection.transaction_begun:
352352
self.connection.database.run_in_transaction(
353353
self._do_batch_update, statements, many_result_set
354354
)
@@ -396,7 +396,7 @@ def fetchone(self):
396396
sequence, or None when no more data is available."""
397397
try:
398398
res = next(self)
399-
if not self.connection.autocommit and not self.connection.read_only:
399+
if self.connection.transaction_begun and not self.connection.read_only:
400400
self._checksum.consume_result(res)
401401
return res
402402
except StopIteration:
@@ -414,7 +414,7 @@ def fetchall(self):
414414
res = []
415415
try:
416416
for row in self:
417-
if not self.connection.autocommit and not self.connection.read_only:
417+
if self.connection.transaction_begun and not self.connection.read_only:
418418
self._checksum.consume_result(row)
419419
res.append(row)
420420
except Aborted:
@@ -443,7 +443,7 @@ def fetchmany(self, size=None):
443443
for _ in range(size):
444444
try:
445445
res = next(self)
446-
if not self.connection.autocommit and not self.connection.read_only:
446+
if self.connection.transaction_begun and not self.connection.read_only:
447447
self._checksum.consume_result(res)
448448
items.append(res)
449449
except StopIteration:
@@ -473,7 +473,7 @@ def _handle_DQL(self, sql, params):
473473
if self.connection.database is None:
474474
raise ValueError("Database needs to be passed for this operation")
475475
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
476-
if self.connection.read_only and not self.connection.autocommit:
476+
if self.connection.read_only and self.connection.transaction_begun:
477477
# initiate or use the existing multi-use snapshot
478478
self._handle_DQL_with_snapshot(
479479
self.connection.snapshot_checkout(), sql, params

google/cloud/spanner_dbapi/parsed_statement.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class StatementType(Enum):
2727
class ClientSideStatementType(Enum):
2828
COMMIT = 1
2929
BEGIN = 2
30+
ROLLBACK = 3
3031

3132

3233
@dataclass

tests/system/test_dbapi.py

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from google.cloud._helpers import UTC
2323

2424
from google.cloud.spanner_dbapi.connection import Connection, connect
25-
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
25+
from google.cloud.spanner_dbapi.exceptions import ProgrammingError, OperationalError
2626
from google.cloud.spanner_v1 import JsonObject
2727
from google.cloud.spanner_v1 import gapic_version as package_version
2828
from . import _helpers
@@ -80,42 +80,43 @@ def init_connection(self, request, shared_instance, dbapi_database):
8080
self._cursor.close()
8181
self._conn.close()
8282

83-
@pytest.fixture
84-
def execute_common_statements(self):
83+
def _execute_common_statements(self, cursor):
8584
# execute several DML statements within one transaction
86-
self._cursor.execute(
85+
cursor.execute(
8786
"""
8887
INSERT INTO contacts (contact_id, first_name, last_name, email)
8988
VALUES (1, 'first-name', 'last-name', '[email protected]')
9089
"""
9190
)
92-
self._cursor.execute(
91+
cursor.execute(
9392
"""
9493
UPDATE contacts
9594
SET first_name = 'updated-first-name'
9695
WHERE first_name = 'first-name'
9796
"""
9897
)
99-
self._cursor.execute(
98+
cursor.execute(
10099
"""
101100
UPDATE contacts
102101
SET email = '[email protected]'
103102
WHERE email = '[email protected]'
104103
"""
105104
)
106-
107-
@pytest.fixture
108-
def updated_row(self, execute_common_statements):
109105
return (
110106
1,
111107
"updated-first-name",
112108
"last-name",
113109
114110
)
115111

116-
def test_commit(self, updated_row):
112+
@pytest.mark.parametrize("client_side", [False, True])
113+
def test_commit(self, client_side):
117114
"""Test committing a transaction with several statements."""
118-
self._conn.commit()
115+
updated_row = self._execute_common_statements(self._cursor)
116+
if client_side:
117+
self._cursor.execute("""COMMIT""")
118+
else:
119+
self._conn.commit()
119120

120121
# read the resulting data from the database
121122
self._cursor.execute("SELECT * FROM contacts")
@@ -124,18 +125,80 @@ def test_commit(self, updated_row):
124125

125126
assert got_rows == [updated_row]
126127

127-
def test_commit_client_side(self, updated_row):
128-
"""Test committing a transaction with several statements."""
129-
self._cursor.execute("""COMMIT""")
128+
@pytest.mark.noautofixt
129+
def test_begin_client_side(self, shared_instance, dbapi_database):
130+
"""Test beginning a transaction using client side statement,
131+
where connection is in autocommit mode."""
132+
133+
conn1 = Connection(shared_instance, dbapi_database)
134+
conn1.autocommit = True
135+
cursor1 = conn1.cursor()
136+
cursor1.execute("begin transaction")
137+
updated_row = self._execute_common_statements(cursor1)
138+
139+
# As the connection conn1 is not committed a new connection wont see its results
140+
conn2 = Connection(shared_instance, dbapi_database)
141+
cursor2 = conn2.cursor()
142+
cursor2.execute("SELECT * FROM contacts")
143+
conn2.commit()
144+
got_rows = cursor2.fetchall()
145+
assert got_rows != [updated_row]
146+
147+
assert conn1._transaction_begin_marked is True
148+
conn1.commit()
149+
assert conn1._transaction_begin_marked is False
150+
151+
# As the connection conn1 is committed a new connection should see its results
152+
conn3 = Connection(shared_instance, dbapi_database)
153+
cursor3 = conn3.cursor()
154+
cursor3.execute("SELECT * FROM contacts")
155+
conn3.commit()
156+
got_rows = cursor3.fetchall()
157+
assert got_rows == [updated_row]
130158

131-
# read the resulting data from the database
159+
conn1.close()
160+
conn2.close()
161+
conn3.close()
162+
cursor1.close()
163+
cursor2.close()
164+
cursor3.close()
165+
166+
def test_begin_success_post_commit(self):
167+
"""Test beginning a new transaction post commiting an existing transaction
168+
is possible on a connection, when connection is in autocommit mode."""
169+
want_row = (2, "first-name", "last-name", "[email protected]")
170+
self._conn.autocommit = True
171+
self._cursor.execute("begin transaction")
172+
self._cursor.execute(
173+
"""
174+
INSERT INTO contacts (contact_id, first_name, last_name, email)
175+
VALUES (2, 'first-name', 'last-name', '[email protected]')
176+
"""
177+
)
178+
self._conn.commit()
179+
180+
self._cursor.execute("begin transaction")
132181
self._cursor.execute("SELECT * FROM contacts")
133182
got_rows = self._cursor.fetchall()
134183
self._conn.commit()
184+
assert got_rows == [want_row]
135185

136-
assert got_rows == [updated_row]
186+
def test_begin_error_before_commit(self):
187+
"""Test beginning a new transaction before commiting an existing transaction is not possible on a connection, when connection is in autocommit mode."""
188+
self._conn.autocommit = True
189+
self._cursor.execute("begin transaction")
190+
self._cursor.execute(
191+
"""
192+
INSERT INTO contacts (contact_id, first_name, last_name, email)
193+
VALUES (2, 'first-name', 'last-name', '[email protected]')
194+
"""
195+
)
196+
197+
with pytest.raises(OperationalError):
198+
self._cursor.execute("begin transaction")
137199

138-
def test_rollback(self):
200+
@pytest.mark.parametrize("client_side", [False, True])
201+
def test_rollback(self, client_side):
139202
"""Test rollbacking a transaction with several statements."""
140203
want_row = (2, "first-name", "last-name", "[email protected]")
141204

@@ -162,7 +225,11 @@ def test_rollback(self):
162225
WHERE email = '[email protected]'
163226
"""
164227
)
165-
self._conn.rollback()
228+
229+
if client_side:
230+
self._cursor.execute("ROLLBACK")
231+
else:
232+
self._conn.rollback()
166233

167234
# read the resulting data from the database
168235
self._cursor.execute("SELECT * FROM contacts")

0 commit comments

Comments
 (0)