Merged
118 changes: 106 additions & 12 deletions redis/asyncio/cluster.py
@@ -973,12 +973,18 @@ async def _execute_command(
except (ConnectionError, TimeoutError):
# Connection retries are being handled in the node's
# Retry object.
# Remove the failed node from the startup nodes before we try
# to reinitialize the cluster
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Hard force of reinitialize of the node/slots setup
# and try again with the new setup
await self.aclose()
# Mark active connections for reconnect and disconnect free ones
# This handles connection state (like READONLY) that may be stale
target_node.update_active_connections_for_reconnect()
await target_node.disconnect_free_connections()

# Move the failed node to the end of the cached nodes list
# so it's tried last during reinitialization
self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

# Signal that reinitialization is needed
# The retry loop will handle initialize() AND replace_default_node()
self._initialize = True
raise
except (ClusterDownError, SlotNotCoveredError):
# ClusterDownError can occur during a failover and to get
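Where the old handler tore everything down with aclose(), the new one defers the work to a per-connection flag. A minimal sketch of that deferred-disconnect pattern, using the mark_for_reconnect / should_reconnect / reset_should_reconnect methods added in redis/asyncio/connection.py below (the class itself is illustrative, not the redis-py implementation):

class SketchConnection:
    """Illustrative sketch only, not the redis-py Connection class."""

    def __init__(self) -> None:
        self._should_reconnect = False

    def mark_for_reconnect(self) -> None:
        # Called while the connection may still be serving a command;
        # the actual disconnect is deferred to avoid racing with it.
        self._should_reconnect = True

    def should_reconnect(self) -> bool:
        return self._should_reconnect

    async def disconnect(self) -> None:
        # The real disconnect() also closes the socket; clearing the flag
        # here makes the mark one-shot, so the connection is reused
        # normally after it reconnects.
        self._should_reconnect = False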
@@ -1263,12 +1269,48 @@ def acquire_connection(self) -> Connection:

raise MaxConnectionsError()

async def disconnect_if_needed(self, connection: Connection) -> None:
"""
Disconnect a connection if it's marked for reconnect.
This implements lazy disconnection to avoid race conditions.
The connection will auto-reconnect on next use.
"""
if connection.should_reconnect():
await connection.disconnect()

def release(self, connection: Connection) -> None:
"""
Release connection back to free queue.
If the connection is marked for reconnect, it will be disconnected
lazily when next acquired via disconnect_if_needed().
"""
self._free.append(connection)

def update_active_connections_for_reconnect(self) -> None:
"""
Mark all in-use (active) connections for reconnect.
In-use connections are those in _connections but not currently in _free.
They will be disconnected when released back to the pool.
"""
free_set = set(self._free)
for connection in self._connections:
if connection not in free_set:
connection.mark_for_reconnect()

async def disconnect_free_connections(self) -> None:
"""
Disconnect all free/idle connections in the pool.
This is useful after topology changes (e.g., failover) to clear
stale connection state like READONLY mode.
The connections remain in the pool and will reconnect on next use.
"""
if self._free:
# Take a snapshot to avoid issues if _free changes during await
await asyncio.gather(
*(connection.disconnect() for connection in tuple(self._free)),
return_exceptions=True,
)
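Taken together, a hedged usage sketch of the two new helpers after a failover (node stands for any ClusterNode; the failover detection itself is out of scope here):

# Clear stale per-connection state (e.g. READONLY) after a topology
# change without interrupting commands that are still in flight:
node.update_active_connections_for_reconnect()  # deferred: applied when each connection is next acquired
await node.disconnect_free_connections()        # immediate: touches idle connections only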

async def parse_response(
self, connection: Connection, command: str, **kwargs: Any
) -> Any:
@@ -1298,6 +1340,8 @@ async def parse_response(
async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = self.acquire_connection()
# Handle lazy disconnect for connections marked for reconnect
await self.disconnect_if_needed(connection)

# Execute command
await connection.send_packed_command(connection.pack_command(*args), False)
@@ -1306,12 +1350,15 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
try:
return await self.parse_response(connection, args[0], **kwargs)
finally:
await self.disconnect_if_needed(connection)
# Release connection
self._free.append(connection)

async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
connection = self.acquire_connection()
# Handle lazy disconnect for connections marked for reconnect
await self.disconnect_if_needed(connection)

# Execute command
await connection.send_packed_command(
@@ -1330,6 +1377,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
ret = True

# Release connection
await self.disconnect_if_needed(connection)
self._free.append(connection)

return ret
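The resulting connection lifecycle, common to execute_command and execute_pipeline above, condensed into one sketch:

connection = self.acquire_connection()
await self.disconnect_if_needed(connection)  # drop stale state before use
await connection.send_packed_command(connection.pack_command(*args), False)
try:
    return await self.parse_response(connection, args[0], **kwargs)
finally:
    await self.disconnect_if_needed(connection)  # honor marks set mid-command
    self._free.append(connection)                # back to the free queue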
@@ -1432,15 +1480,50 @@ def set_nodes(
if remove_old:
for name in list(old.keys()):
if name not in new:
task = asyncio.create_task(old.pop(name).disconnect()) # noqa
# The node is removed from the cache before the disconnects start,
# so it won't be found in lookups while they are in flight.
# Mark active connections for reconnect so they get disconnected
# after the current command completes, and disconnect free
# connections immediately; it is safe not to wait for the disconnects.
removed_node = old.pop(name)
removed_node.update_active_connections_for_reconnect()
asyncio.create_task(removed_node.disconnect_free_connections()) # noqa

for name, node in new.items():
if name in old:
if old[name] is node:
continue
task = asyncio.create_task(old[name].disconnect()) # noqa
# Preserve the existing node but mark connections for reconnect.
# This method is sync so we can't call disconnect_free_connections()
# which is async. Instead, we mark free connections for reconnect
# and they will be lazily disconnected when acquired via
# disconnect_if_needed() to avoid race conditions.
# TODO: Make this method async in the next major release to allow
# immediate disconnection of free connections.
existing_node = old[name]
existing_node.update_active_connections_for_reconnect()
for conn in existing_node._free:
conn.mark_for_reconnect()
continue
# New node is detected and should be added to the pool
old[name] = node
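Because set_nodes() is synchronous, a free connection marked here is only torn down later, on its next acquire. The deferred path, pieced together from the methods above:

# In set_nodes (sync): flag only, nothing can be awaited here.
for conn in existing_node._free:
    conn.mark_for_reconnect()

# Later, in execute_command (async): the flag is honored.
connection = self.acquire_connection()
await self.disconnect_if_needed(connection)  # disconnects iff marked;
# the connection then auto-reconnects on its next use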

def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
"""
Move a failing node to the end of startup_nodes and nodes_cache so it's
tried last during reinitialization and when selecting the default node.
If the node is not in the respective list, nothing is done.
"""
# Move in startup_nodes
if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
node = self.startup_nodes.pop(node_name)
self.startup_nodes[node_name] = node # Re-insert at end

# Move in nodes_cache - this affects get_nodes_by_server_type ordering
# which is used to select the default_node during initialize()
if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
node = self.nodes_cache.pop(node_name)
self.nodes_cache[node_name] = node # Re-insert at end
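The pop-and-reinsert trick relies on Python dicts preserving insertion order (guaranteed since 3.7), so re-adding a key places it last:

nodes = {"a": ..., "b": ..., "c": ...}
node = nodes.pop("a")
nodes["a"] = node  # re-inserted at the end
assert list(nodes) == ["b", "c", "a"]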

def move_slot(self, e: AskError | MovedError):
redirected_node = self.get_node(host=e.host, port=e.port)
if redirected_node:
@@ -2351,6 +2434,9 @@ async def _immediate_execute_command(self, *args, **options):

async def _get_connection_and_send_command(self, *args, **options):
redis_node, connection = self._get_client_and_connection_for_transaction()
# Only disconnect if not watching - disconnecting would lose WATCH state
if not self._watching:
await redis_node.disconnect_if_needed(connection)
return await self._send_command_parse_response(
connection, redis_node, args[0], *args, **options
)
@@ -2383,7 +2469,10 @@ async def _reinitialize_on_error(self, error):
type(error) in self.SLOT_REDIRECT_ERRORS
or type(error) in self.CONNECTION_ERRORS
):
if self._transaction_connection:
if self._transaction_connection and self._transaction_node:
# Disconnect and release back to pool
await self._transaction_connection.disconnect()
self._transaction_node.release(self._transaction_connection)
self._transaction_connection = None

self._pipe.cluster_client.reinitialize_counter += 1
@@ -2443,6 +2532,9 @@ async def _execute_transaction(
self._executing = True

redis_node, connection = self._get_client_and_connection_for_transaction()
# Only disconnect if not watching - disconnecting would lose WATCH state
if not self._watching:
await redis_node.disconnect_if_needed(connection)

stack = chain(
[PipelineCommand(0, "MULTI")],
@@ -2550,8 +2642,10 @@ async def reset(self):
self._transaction_connection = None
except self.CONNECTION_ERRORS:
# disconnect will also remove any previous WATCHes
if self._transaction_connection:
if self._transaction_connection and self._transaction_node:
await self._transaction_connection.disconnect()
self._transaction_node.release(self._transaction_connection)
self._transaction_connection = None

# clean up the other instance attributes
self._transaction_node = None
5 changes: 5 additions & 0 deletions redis/asyncio/connection.py
@@ -381,6 +381,9 @@ def mark_for_reconnect(self):
def should_reconnect(self):
return self._should_reconnect

def reset_should_reconnect(self):
self._should_reconnect = False

@abstractmethod
async def _connect(self):
pass
@@ -519,6 +522,8 @@ async def disconnect(self, nowait: bool = False) -> None:
try:
async with async_timeout(self.socket_connect_timeout):
self._parser.on_disconnect()
# Reset the reconnect flag
self.reset_should_reconnect()
if not self.is_connected:
return
try:
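Resetting the flag inside disconnect() is what keeps the mark one-shot; without it, disconnect_if_needed() would tear the same connection down on every subsequent acquire. The intended state machine, as a sketch:

conn.mark_for_reconnect()
assert conn.should_reconnect()
await conn.disconnect()             # clears the flag as a side effect
assert not conn.should_reconnect()  # the next acquire reuses it normally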
73 changes: 63 additions & 10 deletions redis/cluster.py
@@ -1429,11 +1429,23 @@ def _execute_command(self, target_node, *args, **kwargs):
if connection is not None:
connection.disconnect()

# Remove the failed node from the startup nodes before we try
# to reinitialize the cluster
self.nodes_manager.startup_nodes.pop(target_node.name, None)
# Reset the cluster node's connection
target_node.redis_connection = None
# Instead of setting to None, properly handle the pool
# Get the pool safely - redis_connection could be set to None
# by another thread between the check and access
redis_conn = target_node.redis_connection
if redis_conn is not None:
pool = redis_conn.connection_pool
if pool is not None:
with pool._lock:
# take care of the active connections in the pool
pool.update_active_connections_for_reconnect()
# disconnect all free connections
pool.disconnect_free_connections()

# Move the failed node to the end of the cached nodes list
self.nodes_manager.move_node_to_end_of_cached_nodes(target_node.name)

# DON'T set redis_connection = None - keep the pool for reuse
self.nodes_manager.initialize()
raise e
except MovedError as e:
@@ -1814,6 +1826,23 @@ def populate_startup_nodes(self, nodes):
for n in nodes:
self.startup_nodes[n.name] = n

def move_node_to_end_of_cached_nodes(self, node_name: str) -> None:
"""
Move a failing node to the end of startup_nodes and nodes_cache so it's
tried last during reinitialization and when selecting the default node.
If the node is not in the respective list, nothing is done.
"""
# Move in startup_nodes
if node_name in self.startup_nodes and len(self.startup_nodes) > 1:
node = self.startup_nodes.pop(node_name)
self.startup_nodes[node_name] = node # Re-insert at end

# Move in nodes_cache - this affects get_nodes_by_server_type ordering
# which is used to select the default_node during initialize()
if node_name in self.nodes_cache and len(self.nodes_cache) > 1:
node = self.nodes_cache.pop(node_name)
self.nodes_cache[node_name] = node # Re-insert at end

def check_slots_coverage(self, slots_cache):
# Validate if all slots are covered or if we should try next
# startup node
@@ -1941,10 +1970,16 @@ def initialize(self):
startup_node.host, startup_node.port, **kwargs
)
self.startup_nodes[startup_node.name].redis_connection = r
# Make sure cluster mode is enabled on this node
try:
# Make sure cluster mode is enabled on this node
cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS"))
r.connection_pool.disconnect()
with r.connection_pool._lock:
# take care to clear connections before we move on:
# mark all active connections for reconnect; they will be
# reconnected on next use, after in-flight commands complete
r.connection_pool.update_active_connections_for_reconnect()
# Needed to clear READONLY state when it is no longer applicable
r.connection_pool.disconnect_free_connections()
except ResponseError:
raise RedisClusterException(
"Cluster mode is not enabled on this node"
@@ -3448,6 +3483,15 @@ def _reinitialize_on_error(self, error):
or type(error) in self.CONNECTION_ERRORS
):
if self._transaction_connection:
# Disconnect and release back to pool
self._transaction_connection.disconnect()
node = self._nodes_manager.find_connection_owner(
self._transaction_connection
)
if node and node.redis_connection:
node.redis_connection.connection_pool.release(
self._transaction_connection
)
self._transaction_connection = None

self._pipe.reinitialize_counter += 1
@@ -3601,14 +3645,23 @@ def reset(self):
node = self._nodes_manager.find_connection_owner(
self._transaction_connection
)
node.redis_connection.connection_pool.release(
self._transaction_connection
)
if node and node.redis_connection:
node.redis_connection.connection_pool.release(
self._transaction_connection
)
self._transaction_connection = None
except self.CONNECTION_ERRORS:
# disconnect will also remove any previous WATCHes
if self._transaction_connection:
self._transaction_connection.disconnect()
node = self._nodes_manager.find_connection_owner(
self._transaction_connection
)
if node and node.redis_connection:
node.redis_connection.connection_pool.release(
self._transaction_connection
)
self._transaction_connection = None

# clean up the other instance attributes
self._watching = False
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -589,6 +589,9 @@ def mock_cache() -> CacheInterface:
@pytest.fixture()
def mock_connection() -> ConnectionInterface:
mock_connection = Mock(spec=ConnectionInterface)
# Add host and port attributes needed by find_connection_owner
mock_connection.host = "127.0.0.1"
mock_connection.port = 6379
return mock_connection
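find_connection_owner() evidently matches a connection to its owning node by address, which is why the mock now needs these attributes. A hedged sketch of that assumed lookup (the actual implementation may differ):

def find_connection_owner(self, connection):
    # Assumed shape: match by the connection's address, hence the
    # host and port attributes added to the mock above.
    for node in self.nodes_cache.values():
        if (node.host, node.port) == (connection.host, connection.port):
            return node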

