Skip to content

Commit 0d2ea02

Browse files
tseaverlandrito
authored andcommitted
Unbind transaction from session on commit/rollback. (googleapis#3669)
Closes googleapis#3014.
1 parent 5db3a82 commit 0d2ea02

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

spanner/google/cloud/spanner/session.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ def run_in_transaction(self, func, *args, **kw):
302302
continue
303303
except Exception:
304304
txn.rollback()
305-
del self._transaction
306305
raise
307306

308307
try:
@@ -312,7 +311,6 @@ def run_in_transaction(self, func, *args, **kw):
312311
del self._transaction
313312
else:
314313
committed = txn.committed
315-
del self._transaction
316314
return committed
317315

318316

spanner/google/cloud/spanner/transaction.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def rollback(self):
9393
options = _options_with_prefix(database.name)
9494
api.rollback(self._session.name, self._id, options=options)
9595
self._rolled_back = True
96+
del self._session._transaction
9697

9798
def commit(self):
9899
"""Commit mutations to the database.
@@ -114,6 +115,7 @@ def commit(self):
114115
transaction_id=self._id, options=options)
115116
self.committed = _pb_timestamp_to_datetime(
116117
response.commit_timestamp)
118+
del self._session._transaction
117119
return self.committed
118120

119121
def __enter__(self):

spanner/tests/unit/test_transaction.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ def _getTargetClass(self):
4242

4343
return Transaction
4444

45-
def _make_one(self, *args, **kwargs):
46-
return self._getTargetClass()(*args, **kwargs)
45+
def _make_one(self, session, *args, **kwargs):
46+
transaction = self._getTargetClass()(session, *args, **kwargs)
47+
session._transaction = transaction
48+
return transaction
4749

4850
def test_ctor_defaults(self):
4951
session = _Session()
@@ -208,6 +210,7 @@ def test_rollback_ok(self):
208210
transaction.rollback()
209211

210212
self.assertTrue(transaction._rolled_back)
213+
self.assertIsNone(session._transaction)
211214

212215
session_id, txn_id, options = api._rolled_back
213216
self.assertEqual(session_id, session.name)
@@ -290,6 +293,7 @@ def test_commit_ok(self):
290293
transaction.commit()
291294

292295
self.assertEqual(transaction.committed, now)
296+
self.assertIsNone(session._transaction)
293297

294298
session_id, mutations, txn_id, options = api._committed
295299
self.assertEqual(session_id, session.name)
@@ -368,6 +372,8 @@ class _Database(object):
368372

369373
class _Session(object):
370374

375+
_transaction = None
376+
371377
def __init__(self, database=None, name=TestTransaction.SESSION_NAME):
372378
self._database = database
373379
self.name = name

0 commit comments

Comments
 (0)