Skip to content

Commit 0d19d96

Browse files
committed
Rollback any active transactions in Connection.reset().
In detail: 1. If there is an active transaction, `Connection.reset()` will rollback it, along with issuing a warning. Both transactions started with `Connection.transaction()` and manually started with `Connection.execute('BEGIN;')` are supported. 2. It's no longer possible to start a transaction using `Connection.transaction()` API if the connection is in a manually started transaction. 3. New `assertLoopErrorHandlerCalled` helper method for asyncpg TestCase.
1 parent be55d5d commit 0d19d96

File tree

7 files changed

+104
-6
lines changed

7 files changed

+104
-6
lines changed

asyncpg/_testbase.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import inspect
1313
import logging
1414
import os
15+
import re
1516
import time
1617
import unittest
1718

@@ -95,6 +96,30 @@ def assertRunUnder(self, delta):
9596
raise AssertionError(
9697
'running block took longer than {}'.format(delta))
9798

99+
@contextlib.contextmanager
100+
def assertLoopErrorHandlerCalled(self, msg_re: str):
101+
contexts = []
102+
103+
def handler(loop, ctx):
104+
contexts.append(ctx)
105+
106+
old_handler = self.loop.get_exception_handler()
107+
self.loop.set_exception_handler(handler)
108+
try:
109+
yield
110+
111+
for ctx in contexts:
112+
msg = ctx.get('message')
113+
if msg and re.search(msg_re, msg):
114+
return
115+
116+
raise AssertionError(
117+
'no message matching {!r} was logged with '
118+
'loop.call_exception_handler()'.format(msg_re))
119+
120+
finally:
121+
self.loop.set_exception_handler(old_handler)
122+
98123

99124
_default_cluster = None
100125

asyncpg/connection.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,20 @@ def _get_reset_query(self):
512512

513513
caps = self._server_caps
514514

515-
_reset_query = ''
515+
_reset_query = []
516+
if self._protocol.is_in_transaction() or self._top_xact is not None:
517+
self._loop.call_exception_handler({
518+
'message': 'Resetting connection with an '
519+
'active transaction {!r}'.format(self)
520+
})
521+
self._top_xact = None
522+
_reset_query.append('ROLLBACK;')
516523
if caps.advisory_locks:
517-
_reset_query += 'SELECT pg_advisory_unlock_all();\n'
524+
_reset_query.append('SELECT pg_advisory_unlock_all();')
518525
if caps.cursors:
519-
_reset_query += 'CLOSE ALL;\n'
526+
_reset_query.append('CLOSE ALL;')
520527
if caps.notifications and caps.plpgsql:
521-
_reset_query += '''
528+
_reset_query.append('''
522529
DO $$
523530
BEGIN
524531
PERFORM * FROM pg_listening_channels() LIMIT 1;
@@ -527,10 +534,11 @@ def _get_reset_query(self):
527534
END IF;
528535
END;
529536
$$;
530-
'''
537+
''')
531538
if caps.sql_reset:
532-
_reset_query += 'RESET ALL;\n'
539+
_reset_query.append('RESET ALL;')
533540

541+
_reset_query = '\n'.join(_reset_query)
534542
self._reset_query = _reset_query
535543

536544
return _reset_query

asyncpg/protocol/protocol.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ cdef class BaseProtocol(CoreProtocol):
120120
def get_settings(self):
121121
return self.settings
122122

123+
def is_in_transaction(self):
124+
return self.xact_status == PQTRANS_INTRANS
125+
123126
async def prepare(self, stmt_name, query, timeout):
124127
if self.cancel_waiter is not None:
125128
await self.cancel_waiter

asyncpg/transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ async def start(self):
8484
con = self._connection
8585

8686
if con._top_xact is None:
87+
if con._protocol.is_in_transaction():
88+
raise apg_errors.InterfaceError(
89+
'cannot use Connection.transaction() in '
90+
'a manually started transaction')
8791
con._top_xact = self
8892
else:
8993
# Nested transaction block

tests/test_pool.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,35 @@ async def sleep_and_release():
349349
async with pool.acquire() as con:
350350
await con.fetchval('SELECT 1')
351351

352+
async def test_pool_release_in_xact(self):
353+
"""Test that Connection.reset() closes any open transaction."""
354+
async with self.create_pool(database='postgres',
355+
min_size=1, max_size=1) as pool:
356+
async def get_xact_id(con):
357+
return await con.fetchval('select txid_current()')
358+
359+
with self.assertLoopErrorHandlerCalled('an active transaction'):
360+
async with pool.acquire() as con:
361+
real_con = con._con # unwrap PoolConnectionProxy
362+
363+
id1 = await get_xact_id(con)
364+
365+
tr = con.transaction()
366+
self.assertIsNone(con._con._top_xact)
367+
await tr.start()
368+
self.assertIs(real_con._top_xact, tr)
369+
370+
id2 = await get_xact_id(con)
371+
self.assertNotEqual(id1, id2)
372+
373+
self.assertIsNone(real_con._top_xact)
374+
375+
async with pool.acquire() as con:
376+
self.assertIs(con._con, real_con)
377+
self.assertIsNone(con._con._top_xact)
378+
id3 = await get_xact_id(con)
379+
self.assertNotEqual(id2, id3)
380+
352381

353382
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
354383
class TestHostStandby(tb.ConnectedTestCase):

tests/test_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,14 @@ def test_tests_fail_1(self):
3333
suite.run(result)
3434

3535
self.assertIn('ZeroDivisionError', result.errors[0][1])
36+
37+
38+
class TestHelpers(tb.TestCase):
39+
40+
async def test_tests_assertLoopErrorHandlerCalled_01(self):
41+
with self.assertRaisesRegex(AssertionError, r'no message.*was logged'):
42+
with self.assertLoopErrorHandlerCalled('aa'):
43+
self.loop.call_exception_handler({'message': 'bb a bb'})
44+
45+
with self.assertLoopErrorHandlerCalled('aa'):
46+
self.loop.call_exception_handler({'message': 'bbaabb'})

tests/test_transaction.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,21 @@ async def test_transaction_interface_errors(self):
139139
async with tr:
140140
async with tr:
141141
pass
142+
143+
async def test_transaction_within_manual_transaction(self):
144+
self.assertIsNone(self.con._top_xact)
145+
146+
await self.con.execute('BEGIN')
147+
148+
tr = self.con.transaction()
149+
self.assertIsNone(self.con._top_xact)
150+
151+
with self.assertRaisesRegex(asyncpg.InterfaceError,
152+
'cannot use Connection.transaction'):
153+
await tr.start()
154+
155+
with self.assertLoopErrorHandlerCalled(
156+
'Resetting connection with an active transaction'):
157+
await self.con.reset()
158+
159+
self.assertIsNone(self.con._top_xact)

0 commit comments

Comments
 (0)