From 567ca6807e679efde936e186fadd4c8a90608525 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 21 Jan 2026 19:08:45 +0200 Subject: [PATCH 1/6] Removing the pool disconnections when cluster topology is refreshed - just marking in use connections for reconnect and disconnecting free ones instead. --- redis/asyncio/cluster.py | 116 +++++- redis/asyncio/connection.py | 5 + redis/cluster.py | 73 +++- tests/conftest.py | 3 + tests/test_asyncio/test_cluster.py | 378 +++++++++++++++++- .../test_asyncio/test_cluster_transaction.py | 48 ++- tests/test_cluster.py | 16 +- tests/test_cluster_transaction.py | 50 ++- 8 files changed, 623 insertions(+), 66 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index c9ac62ec4f..0b3890fbb1 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -973,12 +973,18 @@ async def _execute_command( except (ConnectionError, TimeoutError): # Connection retries are being handled in the node's # Retry object. - # Remove the failed node from the startup nodes before we try - # to reinitialize the cluster - self.nodes_manager.startup_nodes.pop(target_node.name, None) - # Hard force of reinitialize of the node/slots setup - # and try again with the new setup - await self.aclose() + # Mark active connections for reconnect and disconnect free ones + # This handles connection state (like READONLY) that may be stale + target_node.update_active_connections_for_reconnect() + await target_node.disconnect_free_connections() + + # Move the failed node to the end of the cached nodes list + # so it's tried last during reinitialization + self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name) + + # Signal that reinitialization is needed + # The retry loop will handle initialize() AND replace_default_node() + self._initialize = True raise except (ClusterDownError, SlotNotCoveredError): # ClusterDownError can occur during a failover and to get @@ -1263,12 +1269,48 @@ def acquire_connection(self) -> Connection: raise MaxConnectionsError() + async def disconnect_if_needed(self, connection: Connection) -> None: + """ + Disconnect a connection if it's marked for reconnect. + This implements lazy disconnection to avoid race conditions. + The connection will auto-reconnect on next use. + """ + if connection.should_reconnect(): + await connection.disconnect() + def release(self, connection: Connection) -> None: """ Release connection back to free queue. + If the connection is marked for reconnect, it will be disconnected + lazily when next acquired via disconnect_if_needed(). """ self._free.append(connection) + def update_active_connections_for_reconnect(self) -> None: + """ + Mark all in-use (active) connections for reconnect. + In-use connections are those in _connections but not currently in _free. + They will be disconnected when released back to the pool. + """ + free_set = set(self._free) + for connection in self._connections: + if connection not in free_set: + connection.mark_for_reconnect() + + async def disconnect_free_connections(self) -> None: + """ + Disconnect all free/idle connections in the pool. + This is useful after topology changes (e.g., failover) to clear + stale connection state like READONLY mode. + The connections remain in the pool and will reconnect on next use. 
+ """ + if self._free: + # Take a snapshot to avoid issues if _free changes during await + await asyncio.gather( + *(connection.disconnect() for connection in tuple(self._free)), + return_exceptions=True, + ) + async def parse_response( self, connection: Connection, command: str, **kwargs: Any ) -> Any: @@ -1298,6 +1340,8 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = self.acquire_connection() + # Handle lazy disconnect for connections marked for reconnect + await self.disconnect_if_needed(connection) # Execute command await connection.send_packed_command(connection.pack_command(*args), False) @@ -1312,6 +1356,8 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection connection = self.acquire_connection() + # Handle lazy disconnect for connections marked for reconnect + await self.disconnect_if_needed(connection) # Execute command await connection.send_packed_command( @@ -1432,15 +1478,50 @@ def set_nodes( if remove_old: for name in list(old.keys()): if name not in new: - task = asyncio.create_task(old.pop(name).disconnect()) # noqa + # Node is removed from cache before disconnect starts, + # so it won't be found in lookups during disconnect + # Mark active connections for reconnect so they get disconnected after current command completes + # and disconnect free connections immediately + # the node is removed from the cache before the connections changes so it won't be used and should be safe + # not to wait for the disconnects + removed_node = old.pop(name) + removed_node.update_active_connections_for_reconnect() + asyncio.create_task(removed_node.disconnect_free_connections()) # noqa for name, node in new.items(): if name in old: - if old[name] is node: - continue - task = asyncio.create_task(old[name].disconnect()) # noqa + # Preserve the existing node but mark connections for reconnect. + # This method is sync so we can't call disconnect_free_connections() + # which is async. Instead, we mark free connections for reconnect + # and they will be lazily disconnected when acquired via + # disconnect_if_needed() to avoid race conditions. + # TODO: Make this method async in the next major release to allow + # immediate disconnection of free connections. + existing_node = old[name] + existing_node.update_active_connections_for_reconnect() + for conn in existing_node._free: + conn.mark_for_reconnect() + continue + # New node is detected and should be added to the pool old[name] = node + def move_node_to_end_of_cached_nodes(self, node_name: str) -> None: + """ + Move a failing node to the end of startup_nodes and nodes_cache so it's + tried last during reinitialization and when selecting the default node. + If the node is not in the respective list, nothing is done. 
+ """ + # Move in startup_nodes + if node_name in self.startup_nodes and len(self.startup_nodes) > 1: + node = self.startup_nodes.pop(node_name) + self.startup_nodes[node_name] = node # Re-insert at end + + # Move in nodes_cache - this affects get_nodes_by_server_type ordering + # which is used to select the default_node during initialize() + if node_name in self.nodes_cache and len(self.nodes_cache) > 1: + node = self.nodes_cache.pop(node_name) + self.nodes_cache[node_name] = node # Re-insert at end + def move_slot(self, e: AskError | MovedError): redirected_node = self.get_node(host=e.host, port=e.port) if redirected_node: @@ -2351,6 +2432,9 @@ async def _immediate_execute_command(self, *args, **options): async def _get_connection_and_send_command(self, *args, **options): redis_node, connection = self._get_client_and_connection_for_transaction() + # Only disconnect if not watching - disconnecting would lose WATCH state + if not self._watching: + await redis_node.disconnect_if_needed(connection) return await self._send_command_parse_response( connection, redis_node, args[0], *args, **options ) @@ -2383,7 +2467,10 @@ async def _reinitialize_on_error(self, error): type(error) in self.SLOT_REDIRECT_ERRORS or type(error) in self.CONNECTION_ERRORS ): - if self._transaction_connection: + if self._transaction_connection and self._transaction_node: + # Disconnect and release back to pool + await self._transaction_connection.disconnect() + self._transaction_node.release(self._transaction_connection) self._transaction_connection = None self._pipe.cluster_client.reinitialize_counter += 1 @@ -2443,6 +2530,9 @@ async def _execute_transaction( self._executing = True redis_node, connection = self._get_client_and_connection_for_transaction() + # Only disconnect if not watching - disconnecting would lose WATCH state + if not self._watching: + await redis_node.disconnect_if_needed(connection) stack = chain( [PipelineCommand(0, "MULTI")], @@ -2550,8 +2640,10 @@ async def reset(self): self._transaction_connection = None except self.CONNECTION_ERRORS: # disconnect will also remove any previous WATCHes - if self._transaction_connection: + if self._transaction_connection and self._transaction_node: await self._transaction_connection.disconnect() + self._transaction_node.release(self._transaction_connection) + self._transaction_connection = None # clean up the other instance attributes self._transaction_node = None diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2d1ff96bae..3920a1eda1 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -381,6 +381,9 @@ def mark_for_reconnect(self): def should_reconnect(self): return self._should_reconnect + def reset_should_reconnect(self): + self._should_reconnect = False + @abstractmethod async def _connect(self): pass @@ -519,6 +522,8 @@ async def disconnect(self, nowait: bool = False) -> None: try: async with async_timeout(self.socket_connect_timeout): self._parser.on_disconnect() + # Reset the reconnect flag + self.reset_should_reconnect() if not self.is_connected: return try: diff --git a/redis/cluster.py b/redis/cluster.py index 75448d285e..03f5bdb839 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1429,11 +1429,23 @@ def _execute_command(self, target_node, *args, **kwargs): if connection is not None: connection.disconnect() - # Remove the failed node from the startup nodes before we try - # to reinitialize the cluster - self.nodes_manager.startup_nodes.pop(target_node.name, None) - # Reset the cluster node's 
connection - target_node.redis_connection = None + # Instead of setting to None, properly handle the pool + # Get the pool safely - redis_connection could be set to None + # by another thread between the check and access + redis_conn = target_node.redis_connection + if redis_conn is not None: + pool = redis_conn.connection_pool + if pool is not None: + with pool._lock: + # take care for the active connections in the pool + pool.update_active_connections_for_reconnect() + # disconnect all free connections + pool.disconnect_free_connections() + + # Move the failed node to the end of the cached nodes list + self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name) + + # DON'T set redis_connection = None - keep the pool for reuse self.nodes_manager.initialize() raise e except MovedError as e: @@ -1814,6 +1826,23 @@ def populate_startup_nodes(self, nodes): for n in nodes: self.startup_nodes[n.name] = n + def move_node_to_end_of_cached_nodes(self, node_name: str) -> None: + """ + Move a failing node to the end of startup_nodes and nodes_cache so it's + tried last during reinitialization and when selecting the default node. + If the node is not in the respective list, nothing is done. + """ + # Move in startup_nodes + if node_name in self.startup_nodes and len(self.startup_nodes) > 1: + node = self.startup_nodes.pop(node_name) + self.startup_nodes[node_name] = node # Re-insert at end + + # Move in nodes_cache - this affects get_nodes_by_server_type ordering + # which is used to select the default_node during initialize() + if node_name in self.nodes_cache and len(self.nodes_cache) > 1: + node = self.nodes_cache.pop(node_name) + self.nodes_cache[node_name] = node # Re-insert at end + def check_slots_coverage(self, slots_cache): # Validate if all slots are covered or if we should try next # startup node @@ -1941,10 +1970,16 @@ def initialize(self): startup_node.host, startup_node.port, **kwargs ) self.startup_nodes[startup_node.name].redis_connection = r - # Make sure cluster mode is enabled on this node try: + # Make sure cluster mode is enabled on this node cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - r.connection_pool.disconnect() + with r.connection_pool._lock: + # take care to clear connections before we move on + # mark all active connections for reconnect - they will be + # reconnected on next use, but will allow current in flight commands to complete first + r.connection_pool.update_active_connections_for_reconnect() + # Needed to clear READONLY state when it is no longer applicable + r.connection_pool.disconnect_free_connections() except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" @@ -3448,6 +3483,15 @@ def _reinitialize_on_error(self, error): or type(error) in self.CONNECTION_ERRORS ): if self._transaction_connection: + # Disconnect and release back to pool + self._transaction_connection.disconnect() + node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + if node: + node.redis_connection.connection_pool.release( + self._transaction_connection + ) self._transaction_connection = None self._pipe.reinitialize_counter += 1 @@ -3601,14 +3645,23 @@ def reset(self): node = self._nodes_manager.find_connection_owner( self._transaction_connection ) - node.redis_connection.connection_pool.release( - self._transaction_connection - ) + if node: + node.redis_connection.connection_pool.release( + self._transaction_connection + ) self._transaction_connection = None except self.CONNECTION_ERRORS: # 
disconnect will also remove any previous WATCHes if self._transaction_connection: self._transaction_connection.disconnect() + node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + if node: + node.redis_connection.connection_pool.release( + self._transaction_connection + ) + self._transaction_connection = None # clean up the other instance attributes self._watching = False diff --git a/tests/conftest.py b/tests/conftest.py index 9d2f51795a..91bc32d600 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -589,6 +589,9 @@ def mock_cache() -> CacheInterface: @pytest.fixture() def mock_connection() -> ConnectionInterface: mock_connection = Mock(spec=ConnectionInterface) + # Add host and port attributes needed by find_connection_owner + mock_connection.host = "127.0.0.1" + mock_connection.port = 6379 return mock_connection diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 8ca0bb8541..19e221d5af 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -973,14 +973,23 @@ async def test_default_node_is_replaced_after_exception(self, r): # CLUSTER NODES command is being executed on the default node nodes = await r.cluster_nodes() assert "myself" in nodes.get(curr_default_node.name).get("flags") - # Mock connection error for the default node - mock_node_resp_exc(curr_default_node, ConnectionError("error")) - # Test that the command succeed from a different node - nodes = await r.cluster_nodes() - assert "myself" not in nodes.get(curr_default_node.name).get("flags") - assert r.get_default_node() != curr_default_node - # Rollback to the old default node - r.replace_default_node(curr_default_node) + # Save original free connections to restore later + original_free = list(curr_default_node._free) + try: + # Mock connection error for the default node + mock_node_resp_exc(curr_default_node, ConnectionError("error")) + # Test that the command succeed from a different node + nodes = await r.cluster_nodes() + assert "myself" not in nodes.get(curr_default_node.name).get("flags") + assert r.get_default_node() != curr_default_node + finally: + # Restore original connections so teardown can work + while curr_default_node._free: + curr_default_node._free.pop() + for conn in original_free: + curr_default_node._free.append(conn) + # Rollback to the old default node + r.replace_default_node(curr_default_node) async def test_address_remap(self, create_redis, master_host): """Test that we can create a rediscluster object with @@ -2768,6 +2777,359 @@ async def test_init_slots_dynamic_startup_nodes(self, dynamic_startup_nodes): else: assert startup_nodes == ["my@DNS.com:7000"] + async def test_move_node_to_end_of_cached_nodes(self) -> None: + """ + Test that move_node_to_end_of_cached_nodes moves a node to the end of + startup_nodes and nodes_cache. 
+ """ + node1 = ClusterNode(default_host, 7000) + node2 = ClusterNode(default_host, 7001) + node3 = ClusterNode(default_host, 7002) + + nodes_manager = NodesManager( + startup_nodes=[node1, node2, node3], + require_full_coverage=False, + connection_kwargs={}, + ) + # Also populate nodes_cache with the same nodes + nodes_manager.nodes_cache = { + node1.name: node1, + node2.name: node2, + node3.name: node3, + } + + # Verify initial order + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name, node2.name, node3.name] + assert nodes_cache_names == [node1.name, node2.name, node3.name] + + # Move first node to end + nodes_manager.move_node_to_end_of_cached_nodes(node1.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node3.name, node1.name] + assert nodes_cache_names == [node2.name, node3.name, node1.name] + + # Move middle node to end + nodes_manager.move_node_to_end_of_cached_nodes(node3.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node1.name, node3.name] + assert nodes_cache_names == [node2.name, node1.name, node3.name] + + # Moving last node should keep it at the end + nodes_manager.move_node_to_end_of_cached_nodes(node3.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node1.name, node3.name] + assert nodes_cache_names == [node2.name, node1.name, node3.name] + + async def test_move_node_to_end_of_cached_nodes_nonexistent(self) -> None: + """ + Test that move_node_to_end_of_cached_nodes does nothing for a + nonexistent node. + """ + node1 = ClusterNode(default_host, 7000) + node2 = ClusterNode(default_host, 7001) + + nodes_manager = NodesManager( + startup_nodes=[node1, node2], + require_full_coverage=False, + connection_kwargs={}, + ) + # Also populate nodes_cache + nodes_manager.nodes_cache = {node1.name: node1, node2.name: node2} + + # Try to move a non-existent node - should not raise + nodes_manager.move_node_to_end_of_cached_nodes("nonexistent:9999") + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name, node2.name] + assert nodes_cache_names == [node1.name, node2.name] + + async def test_move_node_to_end_of_cached_nodes_single_node(self) -> None: + """ + Test that move_node_to_end_of_cached_nodes does nothing when there's + only one node. 
+ """ + node1 = ClusterNode(default_host, 7000) + + nodes_manager = NodesManager( + startup_nodes=[node1], + require_full_coverage=False, + connection_kwargs={}, + ) + # Also populate nodes_cache + nodes_manager.nodes_cache = {node1.name: node1} + + # Should not raise or change anything with single node + nodes_manager.move_node_to_end_of_cached_nodes(node1.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name] + assert nodes_cache_names == [node1.name] + + +class TestClusterNodeConnectionHandling: + """Tests for ClusterNode connection handling methods.""" + + async def test_update_active_connections_for_reconnect(self) -> None: + """ + Test that update_active_connections_for_reconnect marks in-use connections. + """ + node = ClusterNode(default_host, 7000) + + # Create mock connections + conn1 = mock.AsyncMock(spec=Connection) + conn2 = mock.AsyncMock(spec=Connection) + conn3 = mock.AsyncMock(spec=Connection) + + # Add all connections to _connections + node._connections = [conn1, conn2, conn3] + # Only conn1 is free, conn2 and conn3 are "in-use" + node._free.append(conn1) + + # Mark active connections for reconnect + node.update_active_connections_for_reconnect() + + # conn1 is free, should NOT be marked + conn1.mark_for_reconnect.assert_not_called() + # conn2 and conn3 are in-use, should be marked + conn2.mark_for_reconnect.assert_called_once() + conn3.mark_for_reconnect.assert_called_once() + + async def test_disconnect_free_connections(self) -> None: + """ + Test that disconnect_free_connections disconnects all free connections. + """ + node = ClusterNode(default_host, 7000) + + # Create mock connections + conn1 = mock.AsyncMock(spec=Connection) + conn2 = mock.AsyncMock(spec=Connection) + conn3 = mock.AsyncMock(spec=Connection) + + # Add all connections to _connections + node._connections = [conn1, conn2, conn3] + # conn1 and conn2 are free, conn3 is "in-use" + node._free.append(conn1) + node._free.append(conn2) + + # Disconnect free connections + await node.disconnect_free_connections() + + # conn1 and conn2 should be disconnected + conn1.disconnect.assert_called_once() + conn2.disconnect.assert_called_once() + # conn3 is in-use, should NOT be disconnected + conn3.disconnect.assert_not_called() + + # Connections should still be in _free (not removed) + assert conn1 in node._free + assert conn2 in node._free + + async def test_disconnect_free_connections_empty(self) -> None: + """ + Test that disconnect_free_connections handles empty _free gracefully. + """ + node = ClusterNode(default_host, 7000) + + # No free connections + assert len(node._free) == 0 + + # Should not raise + await node.disconnect_free_connections() + + async def test_release_with_reconnect_flag(self) -> None: + """ + Test that release() adds connection to _free even if marked for reconnect. + Disconnect happens lazily via disconnect_if_needed() when next acquired. 
+ """ + node = ClusterNode(default_host, 7000) + + # Create a mock connection marked for reconnect + conn = mock.AsyncMock(spec=Connection) + conn.should_reconnect.return_value = True + + node._connections = [conn] + + # Release the connection - sync, just adds to _free + node.release(conn) + + # Connection should be in _free, disconnect happens lazily on acquire + assert conn in node._free + conn.disconnect.assert_not_called() + + async def test_release_without_reconnect_flag(self) -> None: + """ + Test that release() adds connection to _free without disconnect. + """ + node = ClusterNode(default_host, 7000) + + # Create a mock connection NOT marked for reconnect + conn = mock.AsyncMock(spec=Connection) + conn.should_reconnect.return_value = False + + node._connections = [conn] + + # Release the connection + node.release(conn) + + # Connection should NOT be disconnected but added to _free + conn.disconnect.assert_not_called() + assert conn in node._free + + async def test_disconnect_if_needed_disconnects_when_reconnect_needed( + self, + ) -> None: + """ + Test that disconnect_if_needed() disconnects a connection marked for reconnect. + This implements lazy disconnect to avoid race conditions. + """ + node = ClusterNode(default_host, 7000) + + # Create a mock connection marked for reconnect + conn = mock.AsyncMock(spec=Connection) + conn.should_reconnect.return_value = True + + # disconnect_if_needed should disconnect the connection + await node.disconnect_if_needed(conn) + + conn.disconnect.assert_called_once() + + async def test_disconnect_if_needed_skips_when_no_reconnect_needed(self) -> None: + """ + Test that disconnect_if_needed() does not disconnect if no reconnect needed. + """ + node = ClusterNode(default_host, 7000) + + # Create a mock connection NOT marked for reconnect + conn = mock.AsyncMock(spec=Connection) + conn.should_reconnect.return_value = False + + # disconnect_if_needed should not disconnect + await node.disconnect_if_needed(conn) + + conn.disconnect.assert_not_called() + + +class TestClusterConnectionErrorHandling: + """Tests for cluster connection error handling behavior.""" + + async def test_connection_error_calls_move_node_to_end_of_cached_nodes( + self, + ) -> None: + """ + Test that ConnectionError triggers move_node_to_end_of_cached_nodes + instead of pop. 
+ """ + with mock.patch.object( + NodesManager, "move_node_to_end_of_cached_nodes", autospec=True + ) as move_node_to_end_of_cached_nodes: + with mock.patch.object(ClusterNode, "execute_command") as execute_command: + + async def execute_command_mock(*args, **kwargs): + if args[0] == "CLUSTER SLOTS": + return default_cluster_slots + elif args[0] == "COMMAND": + return {"get": [], "set": []} + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif len(args) > 1 and args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} + elif args[0] == "GET": + raise ConnectionError("Connection failed") + return None + + execute_command.side_effect = execute_command_mock + + with mock.patch.object( + AsyncCommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, node: Optional[ClusterNode] = None) -> None: + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = await RedisCluster(host=default_host, port=7000) + with pytest.raises(ConnectionError): + await rc.get("foo") + + # Verify move_node_to_end_of_cached_nodes was called + move_node_to_end_of_cached_nodes.assert_called() + + async def test_connection_error_handles_node_connections(self) -> None: + """ + Test that ConnectionError triggers proper connection handling on the node. + """ + with mock.patch.object( + ClusterNode, + "update_active_connections_for_reconnect", + autospec=True, + ) as update_active: + with mock.patch.object( + ClusterNode, "disconnect_free_connections", autospec=True + ) as disconnect_free: + with mock.patch.object( + ClusterNode, "execute_command" + ) as execute_command: + + async def execute_command_mock(*args, **kwargs): + if args[0] == "CLUSTER SLOTS": + return default_cluster_slots + elif args[0] == "COMMAND": + return {"get": [], "set": []} + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif ( + len(args) > 1 and args[1] == "cluster-require-full-coverage" + ): + return {"cluster-require-full-coverage": "yes"} + elif args[0] == "GET": + raise ConnectionError("Connection failed") + return None + + execute_command.side_effect = execute_command_mock + + with mock.patch.object( + AsyncCommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock( + self, node: Optional[ClusterNode] = None + ) -> None: + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = await RedisCluster(host=default_host, port=7000) + with pytest.raises(ConnectionError): + await rc.get("foo") + + # Verify connection handling methods were called + update_active.assert_called() + disconnect_free.assert_called() + class TestClusterPipeline: """Tests for the ClusterPipeline class.""" diff --git a/tests/test_asyncio/test_cluster_transaction.py b/tests/test_asyncio/test_cluster_transaction.py index e39d4aaab9..b1ed5f4bdb 100644 --- a/tests/test_asyncio/test_cluster_transaction.py +++ b/tests/test_asyncio/test_cluster_transaction.py @@ -268,47 +268,69 @@ async def test_retry_transaction_on_connection_error(self, r): key = "book" slot = r.keyslot(key) + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + original_slots_cache = 
r.nodes_manager.slots_cache[slot] + mock_connection = Mock(spec=Connection) mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( "Conn error" ) mock_connection.retry = Retry(NoBackoff(), 0) + # Set host/port to match the node for find_connection_owner + mock_connection.host = node_importing.host + mock_connection.port = node_importing.port - _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) node_importing._free.append(mock_connection) r.nodes_manager.slots_cache[slot] = [node_importing] r.reinitialize_steps = 1 - async with r.pipeline(transaction=True) as pipe: - pipe.set(key, "val") - assert await pipe.execute() == [True] + try: + async with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + assert await pipe.execute() == [True] - assert mock_connection.read_response.call_count == 1 + assert mock_connection.read_response.call_count == 1 + finally: + # Clean up mock connection from node so teardown can work + if mock_connection in node_importing._free: + node_importing._free.remove(mock_connection) + r.nodes_manager.slots_cache[slot] = original_slots_cache @pytest.mark.onlycluster async def test_retry_transaction_on_connection_error_with_watched_keys(self, r): key = "book" slot = r.keyslot(key) + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + original_slots_cache = r.nodes_manager.slots_cache[slot] + mock_connection = Mock(spec=Connection) mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( "Conn error" ) mock_connection.retry = Retry(NoBackoff(), 0) + # Set host/port to match the node for find_connection_owner + mock_connection.host = node_importing.host + mock_connection.port = node_importing.port - _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) node_importing._free.append(mock_connection) r.nodes_manager.slots_cache[slot] = [node_importing] r.reinitialize_steps = 1 - async with r.pipeline(transaction=True) as pipe: - await pipe.watch(key) - - pipe.multi() - pipe.set(key, "val") - assert await pipe.execute() == [True] + try: + async with r.pipeline(transaction=True) as pipe: + await pipe.watch(key) - assert mock_connection.read_response.call_count == 1 + pipe.multi() + pipe.set(key, "val") + assert await pipe.execute() == [True] + + assert mock_connection.read_response.call_count == 1 + finally: + # Clean up mock connection from node so teardown can work + if mock_connection in node_importing._free: + node_importing._free.remove(mock_connection) + r.nodes_manager.slots_cache[slot] = original_slots_cache @pytest.mark.onlycluster async def test_exec_error_raised(self, r): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index f70d650861..7583392d04 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -933,15 +933,13 @@ def moved_redirect_effect(connection, *args, **options): parse_response.side_effect = moved_redirect_effect assert r.get("key") == b"value" for node_name, conn in node_conn_origin.items(): - if node_name == node.name: - # The old redis connection of the timed out node should have been - # deleted and replaced - assert conn != r.get_redis_connection(node) - else: - # other nodes' redis connection should have been reused during the - # topology refresh - cur_node = r.get_node(node_name=node_name) - assert conn == r.get_redis_connection(cur_node) + # all nodes' redis connection should have been reused during the + # topology refresh + # even the failing node doesn't need to establish a + # new 
Redis connection (which is actually a new Redis Client instance) + # but the connection pool is reused and all connections are reset and reconnected + cur_node = r.get_node(node_name=node_name) + assert conn == r.get_redis_connection(cur_node) def test_cluster_get_set_retry_object(self, request): retry = Retry(NoBackoff(), 2) diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index d5b21abc9d..0290fb2f7d 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -272,13 +272,24 @@ def test_retry_transaction_on_connection_error(self, r, mock_connection): mock_pool._lock = threading.RLock() _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) - node_importing.redis_connection.connection_pool = mock_pool - r.nodes_manager.slots_cache[slot] = [node_importing] - r.reinitialize_steps = 1 + # Set mock connection's host/port to match the node for find_connection_owner + mock_connection.host = node_importing.host + mock_connection.port = node_importing.port + # Save original pool to restore later + original_pool = node_importing.redis_connection.connection_pool + original_slots_cache = r.nodes_manager.slots_cache[slot] + try: + node_importing.redis_connection.connection_pool = mock_pool + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 - with r.pipeline(transaction=True) as pipe: - pipe.set(key, "val") - assert pipe.execute() == [b"OK"] + with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + assert pipe.execute() == [b"OK"] + finally: + # Restore original pool so teardown can work + node_importing.redis_connection.connection_pool = original_pool + r.nodes_manager.slots_cache[slot] = original_slots_cache @pytest.mark.onlycluster def test_retry_transaction_on_connection_error_with_watched_keys( @@ -298,15 +309,26 @@ def test_retry_transaction_on_connection_error_with_watched_keys( mock_pool.connection_kwargs = {} _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) - node_importing.redis_connection.connection_pool = mock_pool - r.nodes_manager.slots_cache[slot] = [node_importing] - r.reinitialize_steps = 1 + # Set mock connection's host/port to match the node for find_connection_owner + mock_connection.host = node_importing.host + mock_connection.port = node_importing.port + # Save original pool to restore later + original_pool = node_importing.redis_connection.connection_pool + original_slots_cache = r.nodes_manager.slots_cache[slot] + try: + node_importing.redis_connection.connection_pool = mock_pool + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 - with r.pipeline(transaction=True) as pipe: - pipe.watch(key) - pipe.multi() - pipe.set(key, "val") - assert pipe.execute() == [b"OK"] + with r.pipeline(transaction=True) as pipe: + pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + assert pipe.execute() == [b"OK"] + finally: + # Restore original pool so teardown can work + node_importing.redis_connection.connection_pool = original_pool + r.nodes_manager.slots_cache[slot] = original_slots_cache @pytest.mark.onlycluster def test_exec_error_raised(self, r): From 5766626919165e58d7af510528778078e9b5e3d0 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 23 Jan 2026 10:06:41 +0200 Subject: [PATCH 2/6] Adding defensive check when accessing connection_pool - applying review comments --- redis/cluster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/cluster.py 
b/redis/cluster.py index 03f5bdb839..bd371626ad 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3488,7 +3488,7 @@ def _reinitialize_on_error(self, error): node = self._nodes_manager.find_connection_owner( self._transaction_connection ) - if node: + if node and node.redis_connection: node.redis_connection.connection_pool.release( self._transaction_connection ) @@ -3645,7 +3645,7 @@ def reset(self): node = self._nodes_manager.find_connection_owner( self._transaction_connection ) - if node: + if node and node.redis_connection: node.redis_connection.connection_pool.release( self._transaction_connection ) @@ -3657,7 +3657,7 @@ def reset(self): node = self._nodes_manager.find_connection_owner( self._transaction_connection ) - if node: + if node and node.redis_connection: node.redis_connection.connection_pool.release( self._transaction_connection ) From 8cc3b6fc1c9c4483cc21744b01d41201179c9d4b Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 23 Jan 2026 10:52:26 +0200 Subject: [PATCH 3/6] Adding unit tests for newly introduced method move_node_to_end_of_cached_nodes(sync client version) --- tests/test_cluster.py | 93 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 7583392d04..36b8ef65a6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3217,6 +3217,99 @@ def move_slots_worker(): # primary should be first assert slot_nodes[0].server_type == PRIMARY + def test_move_node_to_end_of_cached_nodes(self): + """ + Test that move_node_to_end_of_cached_nodes moves a node to the end of + startup_nodes and nodes_cache. + """ + node1 = ClusterNode(default_host, 7000) + node2 = ClusterNode(default_host, 7001) + node3 = ClusterNode(default_host, 7002) + + with patch.object(NodesManager, "initialize"): + nodes_manager = NodesManager( + startup_nodes=[node1, node2, node3], + require_full_coverage=False, + ) + # Also populate nodes_cache with the same nodes + nodes_manager.nodes_cache = { + node1.name: node1, + node2.name: node2, + node3.name: node3, + } + + # Verify initial order + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name, node2.name, node3.name] + assert nodes_cache_names == [node1.name, node2.name, node3.name] + + # Move first node to end + nodes_manager.move_node_to_end_of_cached_nodes(node1.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node3.name, node1.name] + assert nodes_cache_names == [node2.name, node3.name, node1.name] + + # Move middle node to end + nodes_manager.move_node_to_end_of_cached_nodes(node3.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node1.name, node3.name] + assert nodes_cache_names == [node2.name, node1.name, node3.name] + + # Moving last node should keep it at the end + nodes_manager.move_node_to_end_of_cached_nodes(node3.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node2.name, node1.name, node3.name] + assert nodes_cache_names == [node2.name, node1.name, node3.name] + + def test_move_node_to_end_of_cached_nodes_nonexistent(self): + """ + Test that 
move_node_to_end_of_cached_nodes does nothing for a + nonexistent node. + """ + node1 = ClusterNode(default_host, 7000) + node2 = ClusterNode(default_host, 7001) + + with patch.object(NodesManager, "initialize"): + nodes_manager = NodesManager( + startup_nodes=[node1, node2], + require_full_coverage=False, + ) + # Also populate nodes_cache + nodes_manager.nodes_cache = {node1.name: node1, node2.name: node2} + + # Try to move a non-existent node - should not raise + nodes_manager.move_node_to_end_of_cached_nodes("nonexistent:9999") + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name, node2.name] + assert nodes_cache_names == [node1.name, node2.name] + + def test_move_node_to_end_of_cached_nodes_single_node(self): + """ + Test that move_node_to_end_of_cached_nodes does nothing when there's + only one node. + """ + node1 = ClusterNode(default_host, 7000) + + with patch.object(NodesManager, "initialize"): + nodes_manager = NodesManager( + startup_nodes=[node1], + require_full_coverage=False, + ) + # Also populate nodes_cache + nodes_manager.nodes_cache = {node1.name: node1} + + # Should not raise or change anything with single node + nodes_manager.move_node_to_end_of_cached_nodes(node1.name) + startup_node_names = list(nodes_manager.startup_nodes.keys()) + nodes_cache_names = list(nodes_manager.nodes_cache.keys()) + assert startup_node_names == [node1.name] + assert nodes_cache_names == [node1.name] + @pytest.mark.onlycluster class TestClusterPubSubObject: From ea096418ea5b6c5aff00ae1137bd1db55c17e41d Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 26 Jan 2026 11:14:14 +0200 Subject: [PATCH 4/6] Adding connection disconnect if needed for async after command execution. Removing unneeded connection mock attributes. 
--- redis/asyncio/cluster.py | 2 ++ tests/test_asyncio/test_cluster_transaction.py | 6 ------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 0b3890fbb1..0c5fd6e65a 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1350,6 +1350,7 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: try: return await self.parse_response(connection, args[0], **kwargs) finally: + await self.disconnect_if_needed(connection) # Release connection self._free.append(connection) @@ -1376,6 +1377,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: ret = True # Release connection + await self.disconnect_if_needed(connection) self._free.append(connection) return ret diff --git a/tests/test_asyncio/test_cluster_transaction.py b/tests/test_asyncio/test_cluster_transaction.py index b1ed5f4bdb..e1394c468f 100644 --- a/tests/test_asyncio/test_cluster_transaction.py +++ b/tests/test_asyncio/test_cluster_transaction.py @@ -276,9 +276,6 @@ async def test_retry_transaction_on_connection_error(self, r): "Conn error" ) mock_connection.retry = Retry(NoBackoff(), 0) - # Set host/port to match the node for find_connection_owner - mock_connection.host = node_importing.host - mock_connection.port = node_importing.port node_importing._free.append(mock_connection) r.nodes_manager.slots_cache[slot] = [node_importing] @@ -309,9 +306,6 @@ async def test_retry_transaction_on_connection_error_with_watched_keys(self, r): "Conn error" ) mock_connection.retry = Retry(NoBackoff(), 0) - # Set host/port to match the node for find_connection_owner - mock_connection.host = node_importing.host - mock_connection.port = node_importing.port node_importing._free.append(mock_connection) r.nodes_manager.slots_cache[slot] = [node_importing] From 0e29f3a74b71d1b5dbbbc41de0eaf83eff4189b4 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 27 Jan 2026 16:39:19 +0200 Subject: [PATCH 5/6] Fixing failing flaky test --- tests/test_asyncio/test_cluster.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 19e221d5af..4a13fb079c 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -501,6 +501,15 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: ) ) + # Explicitly disconnect all nodes to release connections that are still + # in use by the background tasks. When asyncio.gather() raises + # MaxConnectionsError, the other 10 tasks continue running in the + # background (blocked in the mocked read_response). Without this cleanup, + # the test teardown will fail with MaxConnectionsError when trying to + # call flushdb() because all connections are still in use. 
+ for node in rc.get_nodes(): + await node.disconnect() + await rc.aclose() async def test_execute_command_errors(self, r: RedisCluster) -> None: From 854dfa925ad767a0f34cfca84640e7f573855db9 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 27 Jan 2026 17:25:11 +0200 Subject: [PATCH 6/6] Updating test --- tests/test_asyncio/test_cluster.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 4a13fb079c..a0c5fb5fc1 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -489,7 +489,7 @@ async def test_max_connections( with mock.patch.object(Connection, "read_response") as read_response: async def read_response_mocked(*args: Any, **kwargs: Any) -> None: - await asyncio.sleep(10) + await asyncio.sleep(0.1) read_response.side_effect = read_response_mocked @@ -501,14 +501,13 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: ) ) - # Explicitly disconnect all nodes to release connections that are still - # in use by the background tasks. When asyncio.gather() raises - # MaxConnectionsError, the other 10 tasks continue running in the - # background (blocked in the mocked read_response). Without this cleanup, - # the test teardown will fail with MaxConnectionsError when trying to - # call flushdb() because all connections are still in use. - for node in rc.get_nodes(): - await node.disconnect() + # Wait for background tasks to complete and release their connections. + # When asyncio.gather() raises MaxConnectionsError, the other 10 tasks + # continue running in the background. Since commit f6bbfb45 added + # 'await disconnect_if_needed()' to the finally block, we must wait + # for tasks to complete naturally before teardown, otherwise we get + # race conditions with connections being disconnected while still in use. + await asyncio.sleep(0.2) await rc.aclose()
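
Illustrative note (not part of the patches above): the series replaces pool-wide
disconnects during topology refresh with a lazy per-connection flow. A minimal
asyncio sketch of that flow, using the methods introduced in PATCH 1/6 and an
assumed locally running cluster on port 7000:

    import asyncio

    from redis.asyncio.cluster import RedisCluster

    async def demo() -> None:
        # assumption: a local test cluster is reachable on localhost:7000
        rc = await RedisCluster(host="localhost", port=7000)
        node = rc.get_default_node()

        # What _execute_command() now does on ConnectionError: flag in-use
        # connections for reconnect and drop idle ones; the pool stays alive.
        node.update_active_connections_for_reconnect()
        await node.disconnect_free_connections()

        # On the next command the node acquires a connection, sees the flag,
        # disconnects it via disconnect_if_needed(), and the connection
        # re-establishes itself on send; no pool teardown is required.
        await rc.set("key", "value")

        await rc.aclose()

    asyncio.run(demo())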