Skip to content

Commit 799f351

Browse files
authored
Merge branch 'master' into ps_fix_hybrid_tests_for_8_6
2 parents 1899edf + c40ec52 commit 799f351

File tree

8 files changed

+721
-67
lines changed

8 files changed

+721
-67
lines changed

redis/asyncio/cluster.py

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -973,12 +973,18 @@ async def _execute_command(
973973
except (ConnectionError, TimeoutError):
974974
# Connection retries are being handled in the node's
975975
# Retry object.
976-
# Remove the failed node from the startup nodes before we try
977-
# to reinitialize the cluster
978-
self.nodes_manager.startup_nodes.pop(target_node.name, None)
979-
# Hard force of reinitialize of the node/slots setup
980-
# and try again with the new setup
981-
await self.aclose()
976+
# Mark active connections for reconnect and disconnect free ones
977+
# This handles connection state (like READONLY) that may be stale
978+
target_node.update_active_connections_for_reconnect()
979+
await target_node.disconnect_free_connections()
980+
981+
# Move the failed node to the end of the cached nodes list
982+
# so it's tried last during reinitialization
983+
self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)
984+
985+
# Signal that reinitialization is needed
986+
# The retry loop will handle initialize() AND replace_default_node()
987+
self._initialize = True
982988
raise
983989
except (ClusterDownError, SlotNotCoveredError):
984990
# ClusterDownError can occur during a failover and to get
@@ -1263,12 +1269,48 @@ def acquire_connection(self) -> Connection:
12631269

12641270
raise MaxConnectionsError()
12651271

1272+
async def disconnect_if_needed(self, connection: Connection) -> None:
1273+
"""
1274+
Disconnect a connection if it's marked for reconnect.
1275+
This implements lazy disconnection to avoid race conditions.
1276+
The connection will auto-reconnect on next use.
1277+
"""
1278+
if connection.should_reconnect():
1279+
await connection.disconnect()
1280+
12661281
def release(self, connection: Connection) -> None:
12671282
"""
12681283
Release connection back to free queue.
1284+
If the connection is marked for reconnect, it will be disconnected
1285+
lazily when next acquired via disconnect_if_needed().
12691286
"""
12701287
self._free.append(connection)
12711288

1289+
def update_active_connections_for_reconnect(self) -> None:
1290+
"""
1291+
Mark all in-use (active) connections for reconnect.
1292+
In-use connections are those in _connections but not currently in _free.
1293+
They will be disconnected when released back to the pool.
1294+
"""
1295+
free_set = set(self._free)
1296+
for connection in self._connections:
1297+
if connection not in free_set:
1298+
connection.mark_for_reconnect()
1299+
1300+
async def disconnect_free_connections(self) -> None:
1301+
"""
1302+
Disconnect all free/idle connections in the pool.
1303+
This is useful after topology changes (e.g., failover) to clear
1304+
stale connection state like READONLY mode.
1305+
The connections remain in the pool and will reconnect on next use.
1306+
"""
1307+
if self._free:
1308+
# Take a snapshot to avoid issues if _free changes during await
1309+
await asyncio.gather(
1310+
*(connection.disconnect() for connection in tuple(self._free)),
1311+
return_exceptions=True,
1312+
)
1313+
12721314
async def parse_response(
12731315
self, connection: Connection, command: str, **kwargs: Any
12741316
) -> Any:
@@ -1298,6 +1340,8 @@ async def parse_response(
12981340
async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
12991341
# Acquire connection
13001342
connection = self.acquire_connection()
1343+
# Handle lazy disconnect for connections marked for reconnect
1344+
await self.disconnect_if_needed(connection)
13011345

13021346
# Execute command
13031347
await connection.send_packed_command(connection.pack_command(*args), False)
@@ -1306,12 +1350,15 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
13061350
try:
13071351
return await self.parse_response(connection, args[0], **kwargs)
13081352
finally:
1353+
await self.disconnect_if_needed(connection)
13091354
# Release connection
13101355
self._free.append(connection)
13111356

13121357
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
13131358
# Acquire connection
13141359
connection = self.acquire_connection()
1360+
# Handle lazy disconnect for connections marked for reconnect
1361+
await self.disconnect_if_needed(connection)
13151362

13161363
# Execute command
13171364
await connection.send_packed_command(
@@ -1330,6 +1377,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
13301377
ret = True
13311378

13321379
# Release connection
1380+
await self.disconnect_if_needed(connection)
13331381
self._free.append(connection)
13341382

13351383
return ret
@@ -1432,15 +1480,50 @@ def set_nodes(
14321480
if remove_old:
14331481
for name in list(old.keys()):
14341482
if name not in new:
1435-
task = asyncio.create_task(old.pop(name).disconnect()) # noqa
1483+
# Node is removed from cache before disconnect starts,
1484+
# so it won't be found in lookups during disconnect
1485+
# Mark active connections for reconnect so they get disconnected after current command completes
1486+
# and disconnect free connections immediately
1487+
# the node is removed from the cache before the connections changes so it won't be used and should be safe
1488+
# not to wait for the disconnects
1489+
removed_node = old.pop(name)
1490+
removed_node.update_active_connections_for_reconnect()
1491+
asyncio.create_task(removed_node.disconnect_free_connections()) # noqa
14361492

14371493
for name, node in new.items():
14381494
if name in old:
1439-
if old[name] is node:
1440-
continue
1441-
task = asyncio.create_task(old[name].disconnect()) # noqa
1495+
# Preserve the existing node but mark connections for reconnect.
1496+
# This method is sync so we can't call disconnect_free_connections()
1497+
# which is async. Instead, we mark free connections for reconnect
1498+
# and they will be lazily disconnected when acquired via
1499+
# disconnect_if_needed() to avoid race conditions.
1500+
# TODO: Make this method async in the next major release to allow
1501+
# immediate disconnection of free connections.
1502+
existing_node = old[name]
1503+
existing_node.update_active_connections_for_reconnect()
1504+
for conn in existing_node._free:
1505+
conn.mark_for_reconnect()
1506+
continue
1507+
# New node is detected and should be added to the pool
14421508
old[name] = node
14431509

1510+
def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1511+
"""
1512+
Move a failing node to the end of startup_nodes and nodes_cache so it's
1513+
tried last during reinitialization and when selecting the default node.
1514+
If the node is not in the respective list, nothing is done.
1515+
"""
1516+
# Move in startup_nodes
1517+
if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1518+
node = self.startup_nodes.pop(node_name)
1519+
self.startup_nodes[node_name] = node # Re-insert at end
1520+
1521+
# Move in nodes_cache - this affects get_nodes_by_server_type ordering
1522+
# which is used to select the default_node during initialize()
1523+
if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1524+
node = self.nodes_cache.pop(node_name)
1525+
self.nodes_cache[node_name] = node # Re-insert at end
1526+
14441527
def move_slot(self, e: AskError | MovedError):
14451528
redirected_node = self.get_node(host=e.host, port=e.port)
14461529
if redirected_node:
@@ -2351,6 +2434,9 @@ async def _immediate_execute_command(self, *args, **options):
23512434

23522435
async def _get_connection_and_send_command(self, *args, **options):
23532436
redis_node, connection = self._get_client_and_connection_for_transaction()
2437+
# Only disconnect if not watching - disconnecting would lose WATCH state
2438+
if not self._watching:
2439+
await redis_node.disconnect_if_needed(connection)
23542440
return await self._send_command_parse_response(
23552441
connection, redis_node, args[0], *args, **options
23562442
)
@@ -2383,7 +2469,10 @@ async def _reinitialize_on_error(self, error):
23832469
type(error) in self.SLOT_REDIRECT_ERRORS
23842470
or type(error) in self.CONNECTION_ERRORS
23852471
):
2386-
if self._transaction_connection:
2472+
if self._transaction_connection and self._transaction_node:
2473+
# Disconnect and release back to pool
2474+
await self._transaction_connection.disconnect()
2475+
self._transaction_node.release(self._transaction_connection)
23872476
self._transaction_connection = None
23882477

23892478
self._pipe.cluster_client.reinitialize_counter += 1
@@ -2443,6 +2532,9 @@ async def _execute_transaction(
24432532
self._executing = True
24442533

24452534
redis_node, connection = self._get_client_and_connection_for_transaction()
2535+
# Only disconnect if not watching - disconnecting would lose WATCH state
2536+
if not self._watching:
2537+
await redis_node.disconnect_if_needed(connection)
24462538

24472539
stack = chain(
24482540
[PipelineCommand(0, "MULTI")],
@@ -2550,8 +2642,10 @@ async def reset(self):
25502642
self._transaction_connection = None
25512643
except self.CONNECTION_ERRORS:
25522644
# disconnect will also remove any previous WATCHes
2553-
if self._transaction_connection:
2645+
if self._transaction_connection and self._transaction_node:
25542646
await self._transaction_connection.disconnect()
2647+
self._transaction_node.release(self._transaction_connection)
2648+
self._transaction_connection = None
25552649

25562650
# clean up the other instance attributes
25572651
self._transaction_node = None

redis/asyncio/connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ def mark_for_reconnect(self):
381381
def should_reconnect(self):
382382
return self._should_reconnect
383383

384+
def reset_should_reconnect(self):
385+
self._should_reconnect = False
386+
384387
@abstractmethod
385388
async def _connect(self):
386389
pass
@@ -519,6 +522,8 @@ async def disconnect(self, nowait: bool = False) -> None:
519522
try:
520523
async with async_timeout(self.socket_connect_timeout):
521524
self._parser.on_disconnect()
525+
# Reset the reconnect flag
526+
self.reset_should_reconnect()
522527
if not self.is_connected:
523528
return
524529
try:

redis/cluster.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,11 +1429,23 @@ def _execute_command(self, target_node, *args, **kwargs):
14291429
if connection is not None:
14301430
connection.disconnect()
14311431

1432-
# Remove the failed node from the startup nodes before we try
1433-
# to reinitialize the cluster
1434-
self.nodes_manager.startup_nodes.pop(target_node.name, None)
1435-
# Reset the cluster node's connection
1436-
target_node.redis_connection = None
1432+
# Instead of setting to None, properly handle the pool
1433+
# Get the pool safely - redis_connection could be set to None
1434+
# by another thread between the check and access
1435+
redis_conn = target_node.redis_connection
1436+
if redis_conn is not None:
1437+
pool = redis_conn.connection_pool
1438+
if pool is not None:
1439+
with pool._lock:
1440+
# take care for the active connections in the pool
1441+
pool.update_active_connections_for_reconnect()
1442+
# disconnect all free connections
1443+
pool.disconnect_free_connections()
1444+
1445+
# Move the failed node to the end of the cached nodes list
1446+
self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)
1447+
1448+
# DON'T set redis_connection = None - keep the pool for reuse
14371449
self.nodes_manager.initialize()
14381450
raise e
14391451
except MovedError as e:
@@ -1814,6 +1826,23 @@ def populate_startup_nodes(self, nodes):
18141826
for n in nodes:
18151827
self.startup_nodes[n.name] = n
18161828

1829+
def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
1830+
"""
1831+
Move a failing node to the end of startup_nodes and nodes_cache so it's
1832+
tried last during reinitialization and when selecting the default node.
1833+
If the node is not in the respective list, nothing is done.
1834+
"""
1835+
# Move in startup_nodes
1836+
if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
1837+
node = self.startup_nodes.pop(node_name)
1838+
self.startup_nodes[node_name] = node # Re-insert at end
1839+
1840+
# Move in nodes_cache - this affects get_nodes_by_server_type ordering
1841+
# which is used to select the default_node during initialize()
1842+
if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
1843+
node = self.nodes_cache.pop(node_name)
1844+
self.nodes_cache[node_name] = node # Re-insert at end
1845+
18171846
def check_slots_coverage(self, slots_cache):
18181847
# Validate if all slots are covered or if we should try next
18191848
# startup node
@@ -1941,10 +1970,16 @@ def initialize(self):
19411970
startup_node.host, startup_node.port, **kwargs
19421971
)
19431972
self.startup_nodes[startup_node.name].redis_connection = r
1944-
# Make sure cluster mode is enabled on this node
19451973
try:
1974+
# Make sure cluster mode is enabled on this node
19461975
cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS"))
1947-
r.connection_pool.disconnect()
1976+
with r.connection_pool._lock:
1977+
# take care to clear connections before we move on
1978+
# mark all active connections for reconnect - they will be
1979+
# reconnected on next use, but will allow current in flight commands to complete first
1980+
r.connection_pool.update_active_connections_for_reconnect()
1981+
# Needed to clear READONLY state when it is no longer applicable
1982+
r.connection_pool.disconnect_free_connections()
19481983
except ResponseError:
19491984
raise RedisClusterException(
19501985
"Cluster mode is not enabled on this node"
@@ -3448,6 +3483,15 @@ def _reinitialize_on_error(self, error):
34483483
or type(error) in self.CONNECTION_ERRORS
34493484
):
34503485
if self._transaction_connection:
3486+
# Disconnect and release back to pool
3487+
self._transaction_connection.disconnect()
3488+
node = self._nodes_manager.find_connection_owner(
3489+
self._transaction_connection
3490+
)
3491+
if node and node.redis_connection:
3492+
node.redis_connection.connection_pool.release(
3493+
self._transaction_connection
3494+
)
34513495
self._transaction_connection = None
34523496

34533497
self._pipe.reinitialize_counter += 1
@@ -3601,14 +3645,23 @@ def reset(self):
36013645
node = self._nodes_manager.find_connection_owner(
36023646
self._transaction_connection
36033647
)
3604-
node.redis_connection.connection_pool.release(
3605-
self._transaction_connection
3606-
)
3648+
if node and node.redis_connection:
3649+
node.redis_connection.connection_pool.release(
3650+
self._transaction_connection
3651+
)
36073652
self._transaction_connection = None
36083653
except self.CONNECTION_ERRORS:
36093654
# disconnect will also remove any previous WATCHes
36103655
if self._transaction_connection:
36113656
self._transaction_connection.disconnect()
3657+
node = self._nodes_manager.find_connection_owner(
3658+
self._transaction_connection
3659+
)
3660+
if node and node.redis_connection:
3661+
node.redis_connection.connection_pool.release(
3662+
self._transaction_connection
3663+
)
3664+
self._transaction_connection = None
36123665

36133666
# clean up the other instance attributes
36143667
self._watching = False

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,9 @@ def mock_cache() -> CacheInterface:
589589
@pytest.fixture()
590590
def mock_connection() -> ConnectionInterface:
591591
mock_connection = Mock(spec=ConnectionInterface)
592+
# Add host and port attributes needed by find_connection_owner
593+
mock_connection.host = "127.0.0.1"
594+
mock_connection.port = 6379
592595
return mock_connection
593596

594597

0 commit comments

Comments
 (0)