diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index cea0da47bd..68994944ff 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -1004,18 +1004,23 @@ def handle_oss_maintenance_completed_notification( disconnect_startup_nodes_pools=False, additional_startup_nodes_info=[(new_node_host, int(new_node_port))], ) + # mark for reconnect all in use connections to the node - this will force them to + # disconnect after they complete their current commands + # Some of them might be used by pub/sub and we don't know which ones - so we disconnect + # all in flight connections after they are done with current command execution + for conn in ( + current_node.redis_connection.connection_pool._get_in_use_connections() + ): + conn.mark_for_reconnect() if ( current_node not in self.cluster_client.nodes_manager.nodes_cache.values() ): - # disconnect all free connections to the node + # disconnect all free connections to the node - this node will be dropped + # from the cluster, so we don't need to revert the timeouts for conn in current_node.redis_connection.connection_pool._get_free_connections(): conn.disconnect() - # mark for reconnect all in use connections to the node - this will force them to - # disconnect after they complete their current commands - for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): - conn.mark_for_reconnect() else: if self.config.is_relaxed_timeouts_enabled(): # reset the timeouts for the node to which the connection is connected @@ -1025,6 +1030,7 @@ def handle_oss_maintenance_completed_notification( *current_node.redis_connection.connection_pool._get_in_use_connections(), *current_node.redis_connection.connection_pool._get_free_connections(), ): + conn.reset_tmp_settings(reset_relaxed_timeout=True) conn.update_current_socket_timeout(relaxed_timeout=-1) conn.maintenance_state = MaintenanceState.NONE diff --git 
a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index 97e91d2a05..8e2cf55efb 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -1,4 +1,6 @@ +from asyncio import Queue from dataclasses import dataclass +from threading import Thread from typing import List, Optional, cast from redis import ConnectionPool, RedisCluster @@ -975,3 +977,169 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( ), ], ) + + def test_smigrating_smigrated_with_sharded_pubsub( + self, + ): + """ + Test handling of sharded pubsub connections when SMIGRATING and SMIGRATED + notifications are received. + """ + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=5) + + node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + + pubsub = self.cluster.pubsub() + + # subscribe to a channel on node1 + pubsub.ssubscribe("anyprefix:{7}:k") + + msg = pubsub.get_sharded_message( + ignore_subscribe_messages=False, timeout=10, target_node=node_1 + ) + # subscribe msg + assert msg is not None and msg["type"] == "ssubscribe" + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <5200-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + + # get message with node 1 connection to consume the notification + # NOTE(review): the call below passes timeout=5000, not 1 second - confirm the intended value + msg = pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=5000) + # smigrating handled + assert msg is None + + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + assert pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout == 30 + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + == 30 + ) + + 
self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), + SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15380 <5200-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + # execute command with node 1 connection + # this will first consume the SMIGRATING notification for the connection + # this should update the cluster topology and move the slot range to the new node + # and should set the pubsub connection for reconnect + res = self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + assert pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout is None + ) + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + is None + ) + + # first message will be SMIGRATED notification handling + # during this read connection will be reconnected and will resubscribe to channels + msg = pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=10) + assert msg is None + + assert not pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout is None + ) + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + is None + ) + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection.maintenance_state + == MaintenanceState.NONE + ) + # validate resubscribed + assert pubsub.node_pubsub_mapping[node_1.name].subscribed + + def test_smigrating_smigrated_with_std_pubsub( + self, + ): + """ + Test handling of standard pubsub 
connections when SMIGRATING and SMIGRATED + notifications are received. + """ + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=5) + + pubsub = self.cluster.pubsub() + + # subscribe to a channel on node1 + pubsub.subscribe("anyprefix:{7}:k") + + msg = pubsub.get_message(ignore_subscribe_messages=False, timeout=10) + # subscribe msg + assert msg is not None and msg["type"] == "subscribe" + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <5200-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + + # get message with node 1 connection to consume the notification + # NOTE(review): the call below passes timeout=5000, not 1 second - confirm the intended value + msg = pubsub.get_message(ignore_subscribe_messages=False, timeout=5000) + # smigrating handled + assert msg is None + + assert pubsub.connection._sock is not None + assert pubsub.connection._socket_timeout == 30 + assert pubsub.connection._socket_connect_timeout == 30 + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), + SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15380 <5200-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + # execute command with node 1 connection + # this will first consume the SMIGRATING notification for the connection + # this should update the cluster topology and move the slot range to the new node + # and should set the pubsub connection for reconnect + res = self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + assert res is True + + assert pubsub.connection._should_reconnect + assert pubsub.connection._sock is not None + assert pubsub.connection._socket_timeout is None + assert pubsub.connection._socket_connect_timeout is None + + # first 
message will be SMIGRATED notification handling + # during this read connection will be reconnected and will resubscribe to channels + msg = pubsub.get_message(ignore_subscribe_messages=True, timeout=10) + assert msg is None + + assert not pubsub.connection._should_reconnect + assert pubsub.connection._sock is not None + assert pubsub.connection._socket_timeout is None + assert pubsub.connection._socket_connect_timeout is None + assert pubsub.connection.maintenance_state == MaintenanceState.NONE + # validate resubscribed + assert pubsub.subscribed