Skip to content

Commit af2ca45

Browse files
chayimjamestiotiodvora-h
committed
Fixing cancelled async futures (#2666)
Co-authored-by: James R T <[email protected]> Co-authored-by: dvora-h <[email protected]>
1 parent b3c89ac commit af2ca45

File tree

7 files changed

+275
-71
lines changed

7 files changed

+275
-71
lines changed

.github/workflows/integration.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ jobs:
5151
timeout-minutes: 30
5252
strategy:
5353
max-parallel: 15
54+
fail-fast: false
5455
matrix:
5556
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8']
5657
test-type: ['standalone', 'cluster']
@@ -108,6 +109,7 @@ jobs:
108109
name: Install package from commit hash
109110
runs-on: ubuntu-latest
110111
strategy:
112+
fail-fast: false
111113
matrix:
112114
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7']
113115
steps:

redis/asyncio/client.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -493,24 +493,34 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
493493
):
494494
raise error
495495

496-
# COMMAND EXECUTION AND PROTOCOL PARSING
497-
async def execute_command(self, *args, **options):
498-
"""Execute a command and return a parsed response"""
499-
await self.initialize()
500-
pool = self.connection_pool
501-
command_name = args[0]
502-
conn = self.connection or await pool.get_connection(command_name, **options)
503-
496+
async def _try_send_command_parse_response(self, conn, *args, **options):
504497
try:
505498
return await conn.retry.call_with_retry(
506499
lambda: self._send_command_parse_response(
507-
conn, command_name, *args, **options
500+
conn, args[0], *args, **options
508501
),
509502
lambda error: self._disconnect_raise(conn, error),
510503
)
504+
except asyncio.CancelledError:
505+
await conn.disconnect(nowait=True)
506+
raise
511507
finally:
508+
if self.single_connection_client:
509+
self._single_conn_lock.release()
512510
if not self.connection:
513-
await pool.release(conn)
511+
await self.connection_pool.release(conn)
512+
513+
# COMMAND EXECUTION AND PROTOCOL PARSING
514+
async def execute_command(self, *args, **options):
515+
"""Execute a command and return a parsed response"""
516+
await self.initialize()
517+
pool = self.connection_pool
518+
command_name = args[0]
519+
conn = self.connection or await pool.get_connection(command_name, **options)
520+
521+
return await asyncio.shield(
522+
self._try_send_command_parse_response(conn, *args, **options)
523+
)
514524

515525
async def parse_response(
516526
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -749,10 +759,18 @@ async def _disconnect_raise_connect(self, conn, error):
749759
is not a TimeoutError. Otherwise, try to reconnect
750760
"""
751761
await conn.disconnect()
762+
752763
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
753764
raise error
754765
await conn.connect()
755766

767+
async def _try_execute(self, conn, command, *arg, **kwargs):
768+
try:
769+
return await command(*arg, **kwargs)
770+
except asyncio.CancelledError:
771+
await conn.disconnect()
772+
raise
773+
756774
async def _execute(self, conn, command, *args, **kwargs):
757775
"""
758776
Connect manually upon disconnection. If the Redis server is down,
@@ -761,9 +779,11 @@ async def _execute(self, conn, command, *args, **kwargs):
761779
called by the # connection to resubscribe us to any channels and
762780
patterns we were previously listening to
763781
"""
764-
return await conn.retry.call_with_retry(
765-
lambda: command(*args, **kwargs),
766-
lambda error: self._disconnect_raise_connect(conn, error),
782+
return await asyncio.shield(
783+
conn.retry.call_with_retry(
784+
lambda: self._try_execute(conn, command, *args, **kwargs),
785+
lambda error: self._disconnect_raise_connect(conn, error),
786+
)
767787
)
768788

769789
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1165,6 +1185,18 @@ async def _disconnect_reset_raise(self, conn, error):
11651185
await self.reset()
11661186
raise
11671187

1188+
async def _try_send_command_parse_response(self, conn, *args, **options):
1189+
try:
1190+
return await conn.retry.call_with_retry(
1191+
lambda: self._send_command_parse_response(
1192+
conn, args[0], *args, **options
1193+
),
1194+
lambda error: self._disconnect_reset_raise(conn, error),
1195+
)
1196+
except asyncio.CancelledError:
1197+
await conn.disconnect()
1198+
raise
1199+
11681200
async def immediate_execute_command(self, *args, **options):
11691201
"""
11701202
Execute a command immediately, but don't auto-retry on a
@@ -1180,13 +1212,13 @@ async def immediate_execute_command(self, *args, **options):
11801212
command_name, self.shard_hint
11811213
)
11821214
self.connection = conn
1183-
1184-
return await conn.retry.call_with_retry(
1185-
lambda: self._send_command_parse_response(
1186-
conn, command_name, *args, **options
1187-
),
1188-
lambda error: self._disconnect_reset_raise(conn, error),
1189-
)
1215+
try:
1216+
return await asyncio.shield(
1217+
self._try_send_command_parse_response(conn, *args, **options)
1218+
)
1219+
except asyncio.CancelledError:
1220+
await conn.disconnect()
1221+
raise
11901222

11911223
def pipeline_execute_command(self, *args, **options):
11921224
"""
@@ -1353,6 +1385,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13531385
await self.reset()
13541386
raise
13551387

1388+
async def _try_execute(self, conn, execute, stack, raise_on_error):
1389+
try:
1390+
return await conn.retry.call_with_retry(
1391+
lambda: execute(conn, stack, raise_on_error),
1392+
lambda error: self._disconnect_raise_reset(conn, error),
1393+
)
1394+
except asyncio.CancelledError:
1395+
# not supposed to be possible, yet here we are
1396+
await conn.disconnect(nowait=True)
1397+
raise
1398+
finally:
1399+
await self.reset()
1400+
13561401
async def execute(self, raise_on_error: bool = True):
13571402
"""Execute all the commands in the current pipeline"""
13581403
stack = self.command_stack
@@ -1375,15 +1420,10 @@ async def execute(self, raise_on_error: bool = True):
13751420

13761421
try:
13771422
return await asyncio.shield(
1378-
conn.retry.call_with_retry(
1379-
lambda: execute(conn, stack, raise_on_error),
1380-
lambda error: self._disconnect_raise_reset(conn, error),
1381-
)
1423+
self._try_execute(conn, execute, stack, raise_on_error)
13821424
)
1383-
except asyncio.CancelledError:
1384-
# not supposed to be possible, yet here we are
1385-
await conn.disconnect(nowait=True)
1386-
raise
1425+
except RuntimeError:
1426+
await self.reset()
13871427
finally:
13881428
await self.reset()
13891429

redis/asyncio/cluster.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
10161016
finally:
10171017
self._free.append(connection)
10181018

1019+
async def _try_parse_response(self, cmd, connection, ret):
1020+
try:
1021+
cmd.result = await asyncio.shield(
1022+
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
1023+
)
1024+
except asyncio.CancelledError:
1025+
await connection.disconnect(nowait=True)
1026+
raise
1027+
except Exception as e:
1028+
cmd.result = e
1029+
ret = True
1030+
return ret
1031+
10191032
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10201033
# Acquire connection
10211034
connection = self.acquire_connection()
@@ -1028,13 +1041,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10281041
# Read responses
10291042
ret = False
10301043
for cmd in commands:
1031-
try:
1032-
cmd.result = await self.parse_response(
1033-
connection, cmd.args[0], **cmd.kwargs
1034-
)
1035-
except Exception as e:
1036-
cmd.result = e
1037-
ret = True
1044+
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
10381045

10391046
# Release connection
10401047
self._free.append(connection)

tests/test_asyncio/test_cluster.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -340,23 +340,6 @@ async def test_from_url(self, request: FixtureRequest) -> None:
340340
rc = RedisCluster.from_url("rediss://localhost:16379")
341341
assert rc.connection_kwargs["connection_class"] is SSLConnection
342342

343-
async def test_asynckills(self, r) -> None:
344-
345-
await r.set("foo", "foo")
346-
await r.set("bar", "bar")
347-
348-
t = asyncio.create_task(r.get("foo"))
349-
await asyncio.sleep(1)
350-
t.cancel()
351-
try:
352-
await t
353-
except asyncio.CancelledError:
354-
pytest.fail("connection is left open with unread response")
355-
356-
assert await r.get("bar") == b"bar"
357-
assert await r.ping()
358-
assert await r.get("foo") == b"foo"
359-
360343
async def test_max_connections(
361344
self, create_redis: Callable[..., RedisCluster]
362345
) -> None:

tests/test_asyncio/test_connection.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
import redis
9+
from redis.asyncio import Redis
910
from redis.asyncio.connection import (
1011
BaseParser,
1112
Connection,
@@ -42,25 +43,47 @@ async def test_invalid_response(create_redis):
4243

4344

4445
@pytest.mark.onlynoncluster
45-
async def test_asynckills(create_redis):
46-
47-
for b in [True, False]:
48-
r = await create_redis(single_connection_client=b)
49-
50-
await r.set("foo", "foo")
51-
await r.set("bar", "bar")
52-
53-
t = asyncio.create_task(r.get("foo"))
54-
await asyncio.sleep(1)
55-
t.cancel()
56-
try:
57-
await t
58-
except asyncio.CancelledError:
59-
pytest.fail("connection left open with unread response")
60-
61-
assert await r.get("bar") == b"bar"
62-
assert await r.ping()
63-
assert await r.get("foo") == b"foo"
46+
async def test_single_connection():
47+
"""Test that concurrent requests on a single client are synchronised."""
48+
r = Redis(single_connection_client=True)
49+
50+
init_call_count = 0
51+
command_call_count = 0
52+
in_use = False
53+
54+
class Retry_:
55+
async def call_with_retry(self, _, __):
56+
# If we remove the single-client lock, this error gets raised as two
57+
# coroutines will be vying for the `in_use` flag due to the two
58+
# asymmetric sleep calls
59+
nonlocal command_call_count
60+
nonlocal in_use
61+
if in_use is True:
62+
raise ValueError("Commands should be executed one at a time.")
63+
in_use = True
64+
await asyncio.sleep(0.01)
65+
command_call_count += 1
66+
await asyncio.sleep(0.03)
67+
in_use = False
68+
return "foo"
69+
70+
mock_conn = mock.MagicMock()
71+
mock_conn.retry = Retry_()
72+
73+
async def get_conn(_):
74+
# Validate only one client is created in single-client mode when
75+
# concurrent requests are made
76+
nonlocal init_call_count
77+
await asyncio.sleep(0.01)
78+
init_call_count += 1
79+
return mock_conn
80+
81+
with mock.patch.object(r.connection_pool, "get_connection", get_conn):
82+
with mock.patch.object(r.connection_pool, "release"):
83+
await asyncio.gather(r.set("a", "b"), r.set("c", "d"))
84+
85+
assert init_call_count == 1
86+
assert command_call_count == 2
6487

6588

6689
@skip_if_server_version_lt("4.0.0")

0 commit comments

Comments
 (0)