From 2ab57ee2c5c2a708383ad41c1971dbb5f9d01bc3 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 29 Jan 2026 15:57:43 +0200 Subject: [PATCH 1/5] Refactoring the SMIGRATED flow - the notification is changed to contain the src node address for each slot range movement. --- redis/_parsers/base.py | 29 +- redis/asyncio/cluster.py | 3 + redis/cluster.py | 3 + redis/connection.py | 41 +- redis/maint_notifications.py | 147 +- .../proxy_server_helpers.py | 43 +- ...st_cluster_maint_notifications_handling.py | 383 ++++-- .../test_maint_notifications.py | 163 ++- .../test_maint_notifications_handling.py | 12 +- tests/test_scenario/conftest.py | 16 +- tests/test_scenario/fault_injector_client.py | 44 +- .../maint_notifications_helpers.py | 8 + .../test_scenario/test_maint_notifications.py | 1223 +++++++---------- 13 files changed, 1087 insertions(+), 1028 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index ab3e81653b..f4f91549b2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -11,6 +11,7 @@ NodeMigratedNotification, NodeMigratingNotification, NodeMovingNotification, + NodesToSlotsMapping, OSSNodeMigratedNotification, OSSNodeMigratingNotification, ) @@ -195,9 +196,20 @@ def parse_oss_maintenance_completed_msg(response): # SMIGRATED [ , ...] id = response[1] nodes_to_slots_mapping_data = response[2] - nodes_to_slots_mapping = {} - for node, slots in nodes_to_slots_mapping_data: - nodes_to_slots_mapping[safe_str(node)] = safe_str(slots) + nodes_to_slots_mapping = [] + for src_node, node, slots in nodes_to_slots_mapping_data: + # Parse the node address to extract host and port + src_node_str = safe_str(src_node) + node_str = safe_str(node) + slots_str = safe_str(slots) + # The src_node_address is not provided in the SMIGRATED message, + # so we use an empty string as a placeholder + mapping = NodesToSlotsMapping( + src_node_address=src_node_str, + dest_node_address=node_str, + slots=slots_str, + ) + nodes_to_slots_mapping.append(mapping) return OSSNodeMigratedNotification(id, nodes_to_slots_mapping) @@ -341,9 +353,9 @@ def handle_push_response(self, response, **kwargs): if notification is not None: return self.maintenance_push_handler_func(notification) - if ( - msg_type == _SMIGRATED_MESSAGE - and self.oss_cluster_maint_push_handler_func + if msg_type == _SMIGRATED_MESSAGE and ( + self.oss_cluster_maint_push_handler_func + or self.maintenance_push_handler_func ): parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[ msg_type @@ -351,7 +363,10 @@ def handle_push_response(self, response, **kwargs): notification = parser_function(response) if notification is not None: - return self.oss_cluster_maint_push_handler_func(notification) + if self.maintenance_push_handler_func: + self.maintenance_push_handler_func(notification) + if self.oss_cluster_maint_push_handler_func: + self.oss_cluster_maint_push_handler_func(notification) except Exception as e: logger.error( "Error handling {} message ({}): {}".format(msg_type, response, e) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 74b16fbabc..55a9c02a6d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1212,6 +1212,9 @@ def __repr__(self) -> str: def __eq__(self, obj: Any) -> bool: return isinstance(obj, ClusterNode) and obj.name == self.name + def __hash__(self) -> int: + return hash(self.name) + _DEL_MESSAGE = "Unclosed ClusterNode object" def __del__( diff --git a/redis/cluster.py b/redis/cluster.py index c43a4693ab..a6250869c5 100644 --- 
a/redis/cluster.py +++ b/redis/cluster.py @@ -1676,6 +1676,9 @@ def __repr__(self): def __eq__(self, obj): return isinstance(obj, ClusterNode) and obj.name == self.name + def __hash__(self): + return hash(self.name) + class LoadBalancingStrategy(Enum): ROUND_ROBIN = "round_robin" diff --git a/redis/connection.py b/redis/connection.py index eb1f935f14..06f03af5f8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -440,23 +440,17 @@ def _configure_maintenance_notifications( else: self._maint_notifications_pool_handler = None + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler(self, self.maint_notifications_config) + ) + if oss_cluster_maint_notifications_handler: - # Extract a reference to a new handler that copies all properties - # of the original one and has a different connection reference - # This is needed because when we attach the handler to the parser - # we need to make sure that the handler has a reference to the - # connection that the parser is attached to. self._oss_cluster_maint_notifications_handler = ( - oss_cluster_maint_notifications_handler.get_handler_for_connection() + oss_cluster_maint_notifications_handler ) - self._oss_cluster_maint_notifications_handler.set_connection(self) else: self._oss_cluster_maint_notifications_handler = None - self._maint_notifications_connection_handler = ( - MaintNotificationsConnectionHandler(self, self.maint_notifications_config) - ) - # Set up OSS cluster handler to parser if available if self._oss_cluster_maint_notifications_handler: parser.set_oss_cluster_maint_push_handler( @@ -521,21 +515,12 @@ def set_maint_notifications_pool_handler_for_connection( def set_maint_notifications_cluster_handler_for_connection( self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler ): - # Deep copy the cluster handler to avoid sharing the same handler - # between multiple connections, because otherwise each connection will override - # the connection reference and the handler will only hold a reference - # to the last connection that was set. - maint_notifications_cluster_handler_copy = ( - oss_cluster_maint_notifications_handler.get_handler_for_connection() - ) - - maint_notifications_cluster_handler_copy.set_connection(self) self._get_parser().set_oss_cluster_maint_push_handler( - maint_notifications_cluster_handler_copy.handle_notification + oss_cluster_maint_notifications_handler.handle_notification ) self._oss_cluster_maint_notifications_handler = ( - maint_notifications_cluster_handler_copy + oss_cluster_maint_notifications_handler ) # Update maintenance notification connection handler if it doesn't exist @@ -1142,6 +1127,7 @@ def disconnect(self, *args): self._sock = None # reset the reconnect flag self.reset_should_reconnect() + if conn_sock is None: return @@ -1156,6 +1142,17 @@ def disconnect(self, *args): except OSError: pass + if self.maintenance_state == MaintenanceState.MAINTENANCE: + # this block will be executed only if the connection was in maintenance state + # and the connection was closed. + # The state change won't be applied on connections that are in Moving state + # because their state and configurations will be handled when the moving ttl expires. 
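For reference, the redis/_parsers/base.py hunk earlier in this patch changes each SMIGRATED entry from a (node, slots) pair to a (src_node, dest_node, slots) triple. The sketch below is illustrative only (the parse_smigrated_entries helper is not part of the library); it shows how such a decoded payload maps onto the NodesToSlotsMapping dataclass introduced later in this patch, assuming the entries arrive as bytes or str after RESP3 decoding.

from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


@dataclass
class NodesToSlotsMapping:
    src_node_address: str
    dest_node_address: str
    slots: str


def _text(value: Union[bytes, str]) -> str:
    # RESP3 simple strings may arrive as bytes depending on parser settings; normalize to str.
    return value.decode() if isinstance(value, bytes) else value


def parse_smigrated_entries(
    entries: Sequence[Tuple[Union[bytes, str], Union[bytes, str], Union[bytes, str]]],
) -> List[NodesToSlotsMapping]:
    # Each entry is (src_node, dest_node, slots), e.g.
    # (b"10.0.0.1:6379", b"10.0.0.2:6379", b"123,456,5000-7000").
    return [
        NodesToSlotsMapping(
            src_node_address=_text(src),
            dest_node_address=_text(dest),
            slots=_text(slots),
        )
        for src, dest, slots in entries
    ]


if __name__ == "__main__":
    print(parse_smigrated_entries(
        [(b"10.0.0.1:6379", b"10.0.0.2:6379", b"123,456,5000-7000")]
    ))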
+ self.reset_tmp_settings(reset_relaxed_timeout=True) + self.maintenance_state = MaintenanceState.NONE + # reset the sets that keep track of received start maint + # notifications and skipped end maint notifications + self.reset_received_notifications() + def mark_for_reconnect(self): self._should_reconnect = True diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index c4a4d7a3fb..d7642fd9d5 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -5,7 +5,8 @@ import threading import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Literal, Optional, Union from redis.typing import Number @@ -454,6 +455,13 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, int(self.id))) +@dataclass +class NodesToSlotsMapping: + src_node_address: str + dest_node_address: str + slots: str + + class OSSNodeMigratedNotification(MaintenanceNotification): """ Notification for when a Redis OSS API client is used and a node has completed migrating slots. @@ -463,15 +471,15 @@ class OSSNodeMigratedNotification(MaintenanceNotification): Args: id (int): Unique identifier for this notification - nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots + nodes_to_slots_mapping (List[NodesToSlotsMapping]): List of node-to-slots mappings """ - DEFAULT_TTL = 30 + DEFAULT_TTL = 120 def __init__( self, id: int, - nodes_to_slots_mapping: Dict[str, str], + nodes_to_slots_mapping: List[NodesToSlotsMapping], ): super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL) self.nodes_to_slots_mapping = nodes_to_slots_mapping @@ -551,6 +559,25 @@ def _is_private_fqdn(host: str) -> bool: return False +def add_debug_log_for_notification( + connection: "MaintNotificationsAbstractConnection", + notification: Union[str, MaintenanceNotification], +): + if logging.getLogger().isEnabledFor(logging.DEBUG): + try: + socket_address = ( + connection._sock.getsockname() if connection._sock else None + ) + except (AttributeError, OSError): + socket_address = None + + logging.debug( + f"Handling maintenance notification: {notification}, " + f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, " + f"socket_address: {socket_address}", + ) + + class MaintNotificationsConfig: """ Configuration class for maintenance notifications handling behaviour. 
Notifications are received through @@ -885,6 +912,7 @@ class MaintNotificationsConnectionHandler: OSSNodeMigratingNotification: 1, NodeMigratedNotification: 0, NodeFailedOverNotification: 0, + OSSNodeMigratedNotification: 0, } def __init__( @@ -913,10 +941,8 @@ def handle_notification(self, notification: MaintenanceNotification): def handle_maintenance_start_notification( self, maintenance_state: MaintenanceState, notification: MaintenanceNotification ): - logging.debug( - f"Handling start maintenance notification: {notification}, " - f"with connection: {self.connection}, connected to ip {self.connection.get_resolved_ip()}" - ) + add_debug_log_for_notification(self.connection, notification) + if ( self.connection.maintenance_state == MaintenanceState.MOVING or not self.config.is_relaxed_timeouts_enabled() @@ -942,10 +968,7 @@ def handle_maintenance_completed_notification(self): or not self.config.is_relaxed_timeouts_enabled() ): return - logging.debug( - f"Handling end maintenance notification with connection: {self.connection}, " - f"connected to ip {self.connection.get_resolved_ip()}" - ) + add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED") self.connection.reset_tmp_settings(reset_relaxed_timeout=True) # Maintenance completed - reset the connection # timeouts by providing -1 as the relaxed timeout @@ -967,10 +990,6 @@ def __init__( self._processed_notifications = set() self._in_progress = set() self._lock = threading.RLock() - self.connection = None - - def set_connection(self, connection: "MaintNotificationsAbstractConnection"): - self.connection = connection def get_handler_for_connection(self): # Copy all data that should be shared between connections @@ -980,7 +999,6 @@ def get_handler_for_connection(self): copy._processed_notifications = self._processed_notifications copy._in_progress = self._in_progress copy._lock = self._lock - copy.connection = None return copy def remove_expired_notifications(self): @@ -1011,77 +1029,58 @@ def handle_oss_maintenance_completed_notification( # that has also has the notification and we don't want to # process the same notification twice return - if self.connection is None: - logging.error( - "Connection is not set for OSSMaintNotificationsHandler. 
" - f"Failed to handle notification: {notification}" - ) - return - logging.debug( - f"Handling SMIGRATED notification: {notification} with connection: {self.connection}, connected to ip {self.connection.get_resolved_ip()}" - ) + logging.debug(f"Handling SMIGRATED notification: {notification}") self._in_progress.add(notification) - # get the node to which the connection is connected - # before refreshing the cluster topology - current_node = self.cluster_client.nodes_manager.get_node( - host=self.connection.host, port=self.connection.port - ) - - # Updates the cluster slots cache with the new slots mapping - # This will also update the nodes cache with the new nodes mapping + # Extract the information about the src and destination nodes that are affected by the maintenance additional_startup_nodes_info = [] - for node_address, _ in notification.nodes_to_slots_mapping.items(): - new_node_host, new_node_port = node_address.split(":") + affected_nodes = set() + for mapping in notification.nodes_to_slots_mapping: + new_node_host, new_node_port = mapping.dest_node_address.split(":") + src_host, src_port = mapping.src_node_address.split(":") + src_node = self.cluster_client.nodes_manager.get_node( + host=src_host, port=src_port + ) + if src_node is not None: + affected_nodes.add(src_node) + additional_startup_nodes_info.append( (new_node_host, int(new_node_port)) ) + # Updates the cluster slots cache with the new slots mapping + # This will also update the nodes cache with the new nodes mapping self.cluster_client.nodes_manager.initialize( disconnect_startup_nodes_pools=False, additional_startup_nodes_info=additional_startup_nodes_info, ) - with current_node.redis_connection.connection_pool._lock: - # mark for reconnect all in use connections to the node - this will force them to - # disconnect after they complete their current commands - # Some of them might be used by sub sub and we don't know which ones - so we disconnect - # all in flight connections after they are done with current command execution - for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): - conn.mark_for_reconnect() + all_nodes = set(affected_nodes) + all_nodes = all_nodes.union( + self.cluster_client.nodes_manager.nodes_cache.values() + ) - if ( - current_node - not in self.cluster_client.nodes_manager.nodes_cache.values() - ): - # disconnect all free connections to the node - this node will be dropped - # from the cluster, so we don't need to revert the timeouts - for conn in current_node.redis_connection.connection_pool._get_free_connections(): - conn.disconnect() - else: - if self.config.is_relaxed_timeouts_enabled(): - # reset the timeouts for the node to which the connection is connected - # Perform check if other maintenance ops are in progress for the same node - # and if so, don't reset the timeouts and wait for the last maintenance - # to complete - for conn in ( - *current_node.redis_connection.connection_pool._get_in_use_connections(), - *current_node.redis_connection.connection_pool._get_free_connections(), - ): - if ( - len(conn.get_processed_start_notifications()) - > len(conn.get_skipped_end_notifications()) + 1 - ): - # we have received more start notifications than end notifications - # for this connection - we should not reset the timeouts - # and add the notification id to the set of skipped end notifications - conn.add_skipped_end_notification(notification.id) - else: - conn.reset_tmp_settings(reset_relaxed_timeout=True) - 
conn.update_current_socket_timeout(relaxed_timeout=-1) - conn.maintenance_state = MaintenanceState.NONE - conn.reset_received_notifications() + for current_node in all_nodes: + if current_node.redis_connection is None: + continue + with current_node.redis_connection.connection_pool._lock: + if current_node in affected_nodes: + # mark for reconnect all in use connections to the node - this will force them to + # disconnect after they complete their current commands + # Some of them might be used by sub sub and we don't know which ones - so we disconnect + # all in flight connections after they are done with current command execution + for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): + conn.mark_for_reconnect() + + if ( + current_node + not in self.cluster_client.nodes_manager.nodes_cache.values() + ): + # disconnect all free connections to the node - this node will be dropped + # from the cluster, so we don't need to revert the timeouts + for conn in current_node.redis_connection.connection_pool._get_free_connections(): + conn.disconnect() # mark the notification as processed self._processed_notifications.add(notification) diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index 1824c3d108..fa117510cb 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -48,12 +48,13 @@ def oss_maint_notification_to_resp(txt: str) -> str: ">3\r\n" # Push message with 3 elements f"+{notification}\r\n" # Element 1: Command f":{seq_id}\r\n" # Element 2: SeqID - f"*{len(hosts_and_slots) // 2}\r\n" # Element 3: Array of host:port, slots pairs + f"*{len(hosts_and_slots) // 3}\r\n" # Element 3: Array of src_host:src_port, dest_host:dest_port, slots pairs ) - for i in range(0, len(hosts_and_slots), 2): - resp += "*2\r\n" + for i in range(0, len(hosts_and_slots), 3): + resp += "*3\r\n" resp += f"+{hosts_and_slots[i]}\r\n" resp += f"+{hosts_and_slots[i + 1]}\r\n" + resp += f"+{hosts_and_slots[i + 2]}\r\n" else: # SMIGRATING # Format: SMIGRATING SeqID slot,range1-range2 @@ -211,20 +212,12 @@ def get_connections(self) -> dict: def send_notification( self, - connected_to_port: Union[int, str], notification: str, ) -> dict: """ - Send a notification to all connections connected to - a specific node(identified by port number). - - This method: - 1. Fetches stats from the interceptor server - 2. Finds all connection IDs connected to the specified node - 3. Sends the notification to each connection + Send a notification to all connections. Args: - connected_to_port: Port number of the node to send the notification to notification: The notification message to send (in RESP format) Returns: @@ -233,32 +226,12 @@ def send_notification( Example: interceptor = ProxyInterceptorHelper(None, "http://localhost:4000") result = interceptor.send_notification( - "6379", "KjENCiQ0DQpQSU5HDQo=" # PING command in base64 ) """ - # Get stats to find connection IDs for the node - stats = self.get_stats() - - # Extract connection IDs for the specified node - conn_ids = [] - for node_key, node_info in stats.items(): - node_port = node_key.split("@")[1] - if int(node_port) == int(connected_to_port): - for conn in node_info.get("connections", []): - conn_ids.append(conn["id"]) - - if not conn_ids: - raise RuntimeError( - f"No connections found for node {node_port}. 
" - f"Available nodes: {list(set(c.get('node') for c in stats.get('connections', {}).values()))}" - ) - - # Send notification to each connection + # Send notification to all connections results = {} - logging.info(f"Sending notification to {len(conn_ids)} connections: {conn_ids}") - connections_query = f"connectionIds={','.join(conn_ids)}" - url = f"{self.server_url}/send-to-clients?{connections_query}&encoding=base64" + url = f"{self.server_url}/send-to-all-clients?encoding=base64" # Encode notification to base64 data = base64.b64encode(notification.encode("utf-8")) @@ -271,8 +244,6 @@ def send_notification( results = {"error": str(e)} return { - "node_address": node_port, - "connection_ids": conn_ids, "results": results, } diff --git a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index e49f5c6131..26885fc903 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import logging from typing import List, Optional, cast from redis import ConnectionPool, RedisCluster @@ -20,6 +21,10 @@ NODE_PORT_NEW = 15382 +# IP addresses used in tests +NODE_IP_LOCALHOST = "127.0.0.1" +NODE_IP_PROXY = "0.0.0.0" + # Initial cluster node configuration for proxy-based tests PROXY_CLUSTER_NODES = [ ClusterNode("127.0.0.1", NODE_PORT_1), @@ -38,19 +43,19 @@ def test_oss_maint_notification_to_resp(self): assert resp == ">3\r\n+SMIGRATING\r\n:12\r\n+123,456,5000-7000\r\n" resp = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000" + f"SMIGRATED 12 {NODE_IP_LOCALHOST}:{NODE_PORT_1} {NODE_IP_LOCALHOST}:{NODE_PORT_2} 123,456,5000-7000" ) assert ( resp - == ">3\r\n+SMIGRATED\r\n:12\r\n*1\r\n*2\r\n+127.0.0.1:15380\r\n+123,456,5000-7000\r\n" + == f">3\r\n+SMIGRATED\r\n:12\r\n*1\r\n*3\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_1}\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_2}\r\n+123,456,5000-7000\r\n" ) resp = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000 127.0.0.1:15381 7000-8000 127.0.0.1:15382 8000-9000" + f"SMIGRATED 12 {NODE_IP_LOCALHOST}:{NODE_PORT_1} {NODE_IP_LOCALHOST}:{NODE_PORT_2} 123,456,5000-7000 {NODE_IP_LOCALHOST}:{NODE_PORT_1} {NODE_IP_LOCALHOST}:{NODE_PORT_3} 7000-8000 {NODE_IP_LOCALHOST}:{NODE_PORT_1} {NODE_IP_LOCALHOST}:{NODE_PORT_NEW} 8000-9000" ) assert ( resp - == ">3\r\n+SMIGRATED\r\n:12\r\n*3\r\n*2\r\n+127.0.0.1:15380\r\n+123,456,5000-7000\r\n*2\r\n+127.0.0.1:15381\r\n+7000-8000\r\n*2\r\n+127.0.0.1:15382\r\n+8000-9000\r\n" + == f">3\r\n+SMIGRATED\r\n:12\r\n*3\r\n*3\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_1}\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_2}\r\n+123,456,5000-7000\r\n*3\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_1}\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_3}\r\n+7000-8000\r\n*3\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_1}\r\n+{NODE_IP_LOCALHOST}:{NODE_PORT_NEW}\r\n+8000-9000\r\n" ) @@ -352,7 +357,6 @@ def _validate_connection_handlers( assert oss_cluster_parser_handler_set_for_con is not None assert hasattr(oss_cluster_parser_handler_set_for_con, "__self__") assert hasattr(oss_cluster_parser_handler_set_for_con, "__func__") - assert oss_cluster_parser_handler_set_for_con.__self__.connection is conn assert ( oss_cluster_parser_handler_set_for_con.__self__.cluster_client is cluster_client @@ -401,7 +405,6 @@ def test_oss_maint_handler_propagation(self): 
*node.redis_connection.connection_pool._get_free_connections(), ): assert conn._oss_cluster_maint_notifications_handler is not None - assert conn._oss_cluster_maint_notifications_handler.connection is conn self._validate_connection_handlers( conn, cluster, cluster.maint_notifications_config ) @@ -418,10 +421,6 @@ def test_oss_maint_handler_propagation_cache_enabled(self): *node.redis_connection.connection_pool._get_free_connections(), ): assert conn._conn._oss_cluster_maint_notifications_handler is not None - assert ( - conn._conn._oss_cluster_maint_notifications_handler.connection - is conn._conn - ) self._validate_connection_handlers( conn._conn, cluster, cluster.maint_notifications_config ) @@ -541,7 +540,7 @@ def test_receive_smigrating_notification(self): notification = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 123,456,5000-7000" ) - self.proxy_helper.send_notification(NODE_PORT_1, notification) + self.proxy_helper.send_notification(notification) # validate no timeout is relaxed on any connection self._validate_connections_states( @@ -598,7 +597,7 @@ def test_receive_smigrating_with_disabled_relaxed_timeout(self): notification = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 123,456,5000-7000" ) - self.proxy_helper.send_notification(NODE_PORT_1, notification) + self.proxy_helper.send_notification(notification) # validate no timeout is relaxed on any connection self._validate_connections_states( @@ -624,16 +623,16 @@ def test_receive_smigrated_notification(self): self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), - SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_NEW, 0, 5460), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5461, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) # send a notification to node 1 notification = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000" + f"SMIGRATED 12 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 123,456,5000-7000" ) - self.proxy_helper.send_notification(NODE_PORT_1, notification) + self.proxy_helper.send_notification(notification) # execute a command that will receive the notification res = self.cluster.set("anyprefix:{3}:k", "VAL") @@ -641,7 +640,7 @@ def test_receive_smigrated_notification(self): # validate the cluster topology was updated new_node = self.cluster.nodes_manager.get_node( - host="0.0.0.0", port=NODE_PORT_NEW + host=NODE_IP_PROXY, port=NODE_PORT_NEW ) assert new_node is not None @@ -653,16 +652,16 @@ def test_receive_smigrated_notification_with_two_nodes(self): self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), - SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_NEW, 0, 5460), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5461, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) # send a notification to node 1 notification = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 12 127.0.0.1:15380 123,456,5000-7000 127.0.0.1:15382 110-120" + f"SMIGRATED 12 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 123,456,5000-7000 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_NEW} 110-120" ) - self.proxy_helper.send_notification(NODE_PORT_1, notification) + 
self.proxy_helper.send_notification(notification) # execute a command that will receive the notification res = self.cluster.set("anyprefix:{3}:k", "VAL") @@ -670,7 +669,7 @@ def test_receive_smigrated_notification_with_two_nodes(self): # validate the cluster topology was updated new_node = self.cluster.nodes_manager.get_node( - host="0.0.0.0", port=NODE_PORT_NEW + host=NODE_IP_PROXY, port=NODE_PORT_NEW ) assert new_node is not None @@ -679,13 +678,17 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): # warm up connection pools - create several connections in each pool self._warm_up_connection_pools(self.cluster, created_connections_count=3) - node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) - node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) + node_1 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_1 + ) + node_2 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_2 + ) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 123,2000-3000" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + self.proxy_helper.send_notification(smigrating_node_1) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") self._validate_connections_states( @@ -698,7 +701,8 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): relaxed_timeout=self.config.relaxed_timeout, ), ConnectionStateExpectation( - node_port=NODE_PORT_2, changed_connections_count=0 + node_port=NODE_PORT_2, + changed_connections_count=0, ), ], ) @@ -706,7 +710,7 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 13 8000-9000" ) - self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) + self.proxy_helper.send_notification(smigrating_node_2) # execute command with node 2 connection self.cluster.set("anyprefix:{1}:k", "VAL") @@ -729,25 +733,20 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): ], ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15381 123,2000-3000" - ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) - - smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15381 8000-9000" + f"SMIGRATED 14 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_3} 123,2000-3000" ) - self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) + self.proxy_helper.send_notification(smigrated_node_1) self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), - SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), - SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), - SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), - SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), - SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 0, 122), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 123, 123), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 124, 1999), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 2000, 3000), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 3001, 5460), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5461, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) @@ -761,7 +760,7 @@ def 
test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): # validate changed slot is assigned to node 3 assert self.cluster.nodes_manager.get_node_from_slot( 123 - ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + ) == self.cluster.nodes_manager.get_node(host=NODE_IP_PROXY, port=NODE_PORT_3) # validate the connections are in the correct state self._validate_connections_states( self.cluster, @@ -779,18 +778,23 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): ], ) + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED 15 {NODE_IP_PROXY}:{NODE_PORT_2} {NODE_IP_PROXY}:{NODE_PORT_3} 7000-7999" + ) + self.proxy_helper.send_notification(smigrated_node_2) + self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), - SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), - SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), - SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), - SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), - SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 7000), - SlotsRange("0.0.0.0", NODE_PORT_3, 7001, 8000), - SlotsRange("0.0.0.0", NODE_PORT_2, 8001, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 0, 122), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 123, 123), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 124, 2000), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 2001, 3000), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 3001, 5460), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5461, 6999), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 7000, 7999), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 8000, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) # execute command with node 2 connection @@ -801,8 +805,8 @@ def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): assert node_2 in self.cluster.nodes_manager.nodes_cache.values() # validate slot changes are reflected assert self.cluster.nodes_manager.get_node_from_slot( - 8000 - ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + 7000 + ) == self.cluster.nodes_manager.get_node(host=NODE_IP_PROXY, port=NODE_PORT_3) # validate the connections are in the correct state self._validate_connections_states( @@ -824,14 +828,20 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): # warm up connection pools - create several connections in each pool self._warm_up_connection_pools(self.cluster, created_connections_count=3) - node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) - node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) - node_3 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + node_1 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_1 + ) + node_2 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_2 + ) + node_3 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_3 + ) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 0-5460" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + self.proxy_helper.send_notification(smigrating_node_1) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") self._validate_connections_states( @@ -852,7 +862,7 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): smigrating_node_2 = 
RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 13 5461-10922" ) - self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) + self.proxy_helper.send_notification(smigrating_node_2) # execute command with node 2 connection self.cluster.set("anyprefix:{1}:k", "VAL") @@ -876,20 +886,16 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15382 0-5460" + f"SMIGRATED 14 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_NEW} 0-5460" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + self.proxy_helper.send_notification(smigrated_node_1) - smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15382 5461-10922" - ) - self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", 15382, 0, 5460), - SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_NEW, 0, 5460), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5461, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) @@ -901,13 +907,15 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): # validate node 2 is still there assert node_2 in self.cluster.nodes_manager.nodes_cache.values() # validate new node is added - new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + new_node = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_NEW + ) assert new_node is not None assert new_node.redis_connection is not None # validate a slot from the changed range is assigned to the new node assert self.cluster.nodes_manager.get_node_from_slot( 123 - ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + ) == self.cluster.nodes_manager.get_node(host=NODE_IP_PROXY, port=NODE_PORT_NEW) # validate the connections are in the correct state self._validate_removed_node_connections(node_1) @@ -925,12 +933,17 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): ], ) + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED 15 {NODE_IP_PROXY}:{NODE_PORT_2} {NODE_IP_PROXY}:15383 5461-10922" + ) + self.proxy_helper.send_notification(smigrated_node_2) + self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", 15382, 0, 5460), - SlotsRange("0.0.0.0", 15383, 5461, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_NEW, 0, 5460), + SlotsRange(NODE_IP_PROXY, 15383, 5461, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) # execute command with node 2 connection @@ -941,13 +954,13 @@ def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): # validate node 3 is still there assert node_3 in self.cluster.nodes_manager.nodes_cache.values() # validate new node is added - new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + new_node = self.cluster.nodes_manager.get_node(host=NODE_IP_PROXY, port=15383) assert new_node is not None assert new_node.redis_connection is not None # validate a slot from the changed range is assigned to the new node assert self.cluster.nodes_manager.get_node_from_slot( 8000 - ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + ) == 
self.cluster.nodes_manager.get_node(host=NODE_IP_PROXY, port=15383) # validate the connections in removed node are in the correct state self._validate_removed_node_connections(node_2) @@ -964,9 +977,9 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( self._warm_up_connection_pools(self.cluster, created_connections_count=1) smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 12 1000-2000" + "SMIGRATING 12 1000-2000,2500-3000" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + self.proxy_helper.send_notification(smigrating_node_1) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") self._validate_connections_states( @@ -981,46 +994,26 @@ def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( ], ) - smigrating_node_1_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATING 13 3000-4000" - ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1_2) - # execute command with node 1 connection - self.cluster.set("anyprefix:{3}:k", "VAL") - self._validate_connections_states( - self.cluster, - [ - ConnectionStateExpectation( - node_port=NODE_PORT_1, - changed_connections_count=1, - state=MaintenanceState.MAINTENANCE, - relaxed_timeout=self.config.relaxed_timeout, - ), - ], - ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 1000-2000" + f"SMIGRATED 14 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 1000-2000 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_3} 2500-3000" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + self.proxy_helper.send_notification(smigrated_node_1) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") - # this functionality is part of CAE-1038 and will be added later + # validate the timeout is still relaxed self._validate_connections_states( self.cluster, [ ConnectionStateExpectation( node_port=NODE_PORT_1, - changed_connections_count=1, - state=MaintenanceState.MAINTENANCE, - relaxed_timeout=self.config.relaxed_timeout, ), ], ) smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 15 0.0.0.0:15381 3000-4000" + f"SMIGRATED 15 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_3} 3000-4000" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1_2) + self.proxy_helper.send_notification(smigrated_node_1_2) # execute command with node 1 connection self.cluster.set("anyprefix:{3}:k", "VAL") self._validate_connections_states( @@ -1043,7 +1036,9 @@ def test_smigrating_smigrated_with_sharded_pubsub( # warm up connection pools - create several connections in each pool self._warm_up_connection_pools(self.cluster, created_connections_count=5) - node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + node_1 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_1 + ) pubsub = self.cluster.pubsub() @@ -1059,9 +1054,9 @@ def test_smigrating_smigrated_with_sharded_pubsub( smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 5200-5460" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + self.proxy_helper.send_notification(smigrating_node_1) - # get message with node 1 connection to consume the notification + # get message with node 1 connection to consume the SMIGRATING notification # timeout is 1 second msg = pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=5000) # smigrating 
handled @@ -1074,28 +1069,63 @@ def test_smigrating_smigrated_with_sharded_pubsub( == 30 ) + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED 14 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 123" + ) + self.proxy_helper.send_notification(smigrated_node_1) + self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), - SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 0, 122), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 123, 123), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 124, 5200), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5201, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) - smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 5200-5460" - ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) # execute command with node 1 connection # this will first consume the SMIGRATING notification for the connection - # this should update the cluster topology and move the slot range to the new node + # then should process the SMIGRATED notification and update the cluster + # topology and move the slot range to the new node # and should set the pubsub connection for reconnect res = self.cluster.set("anyprefix:{3}:k", "VAL") assert res is True assert pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + # validate timeout is not relaxed - it will be relaxed + # when this concrete connections reads the notification + assert pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout == 30 + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + == 30 + ) + + # during this read the connection will detect that it needs to reconnect + # and the waiting on the socket SMIGRATED won't be processed + # it will directly reconnect and receive again the SMIGRATED notification + logging.info( + "Waiting for message with pubsub connection that will reconnect..." 
+ ) + msg = None + while msg is None or msg["type"] != "ssubscribe": + logging.info("Waiting for ssubscribe message...") + msg = pubsub.get_sharded_message( + ignore_subscribe_messages=False, timeout=10 + ) + assert msg is not None and msg["type"] == "ssubscribe" + logging.info("Reconnect ended.") + + logging.info("Consuming SMIGRATED notification with pubsub connection...") + # simulating server's behavior that send the last notification to the new connection + self.proxy_helper.send_notification(smigrated_node_1) + msg = pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=10) + assert msg is None + + assert not pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None assert ( pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout is None ) @@ -1103,12 +1133,105 @@ def test_smigrating_smigrated_with_sharded_pubsub( pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout is None ) + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection.maintenance_state + == MaintenanceState.NONE + ) + # validate resubscribed + assert pubsub.node_pubsub_mapping[node_1.name].subscribed - # first message will be SMIGRATED notification handling - # during this read connection will be reconnected and will resubscribe to channels - msg = pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=10) + def test_smigrating_smigrated_with_sharded_pubsub_and_reconnect_after_smigrated_expires( + self, + ): + """ + Test handling of sharded pubsub connections when SMIGRATING and SMIGRATED + notifications are received. + """ + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=5) + + node_1 = self.cluster.nodes_manager.get_node( + host=NODE_IP_PROXY, port=NODE_PORT_1 + ) + + pubsub = self.cluster.pubsub() + + # subscribe to a channel on node1 + pubsub.ssubscribe("anyprefix:{7}:k") + + msg = pubsub.get_sharded_message( + ignore_subscribe_messages=False, timeout=10, target_node=node_1 + ) + # subscribe msg + assert msg is not None and msg["type"] == "ssubscribe" + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 5200-5460" + ) + self.proxy_helper.send_notification(smigrating_node_1) + + # get message with node 1 connection to consume the SMIGRATING notification + # timeout is 1 second + msg = pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=5000) + # smigrating handled assert msg is None + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + assert pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout == 30 + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + == 30 + ) + + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + f"SMIGRATED 14 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 123" + ) + self.proxy_helper.send_notification(smigrated_node_1) + + self.proxy_helper.set_cluster_slots( + CLUSTER_SLOTS_INTERCEPTOR_NAME, + [ + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 0, 122), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 123, 123), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 124, 5200), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5201, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), + ], + ) + + # execute command with node 1 connection + # this will first consume the SMIGRATING notification for the connection + # then 
should process the SMIGRATED notification and update the cluster + # topology and move the slot range to the new node + # and should set the pubsub connection for reconnect + res = self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + assert pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect + assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None + # validate timeout is not relaxed - it will be relaxed + # when this concrete connections reads the notification + assert pubsub.node_pubsub_mapping[node_1.name].connection._socket_timeout == 30 + assert ( + pubsub.node_pubsub_mapping[node_1.name].connection._socket_connect_timeout + == 30 + ) + + # during this read the connection will detect that it needs to reconnect + # and the waiting on the socket SMIGRATED won't be processed + # it will directly reconnect and receive again the SMIGRATED notification + logging.info( + "Waiting for message with pubsub connection that will reconnect..." + ) + msg = None + while msg is None or msg["type"] != "ssubscribe": + logging.info("Waiting for ssubscribe message...") + msg = pubsub.get_sharded_message( + ignore_subscribe_messages=False, timeout=10 + ) + assert msg is not None and msg["type"] == "ssubscribe" + logging.info("Reconnect ended.") + assert not pubsub.node_pubsub_mapping[node_1.name].connection._should_reconnect assert pubsub.node_pubsub_mapping[node_1.name].connection._sock is not None assert ( @@ -1147,9 +1270,9 @@ def test_smigrating_smigrated_with_std_pubsub( smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( "SMIGRATING 12 5200-5460" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + self.proxy_helper.send_notification(smigrating_node_1) - # get message with node 1 connection to consume the notification + # get message with to consume the SMIGRATING notification # timeout is 1 second msg = pubsub.get_message(ignore_subscribe_messages=False, timeout=5000) # smigrating handled @@ -1162,31 +1285,31 @@ def test_smigrating_smigrated_with_std_pubsub( self.proxy_helper.set_cluster_slots( CLUSTER_SLOTS_INTERCEPTOR_NAME, [ - SlotsRange("0.0.0.0", NODE_PORT_1, 0, 5200), - SlotsRange("0.0.0.0", NODE_PORT_2, 5201, 10922), - SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + SlotsRange(NODE_IP_PROXY, NODE_PORT_1, 0, 5199), + SlotsRange(NODE_IP_PROXY, NODE_PORT_2, 5200, 10922), + SlotsRange(NODE_IP_PROXY, NODE_PORT_3, 10923, 16383), ], ) smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( - "SMIGRATED 14 0.0.0.0:15380 5200-5460" + f"SMIGRATED 13 {NODE_IP_PROXY}:{NODE_PORT_1} {NODE_IP_PROXY}:{NODE_PORT_2} 5200-5460" ) - self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + self.proxy_helper.send_notification(smigrated_node_1) # execute command with node 1 connection - # this will first consume the SMIGRATING notification for the connection + # this will first consume the SMIGRATING and SMIGRATED notifications for the connection # this should update the cluster topology and move the slot range to the new node # and should set the pubsub connection for reconnect res = self.cluster.set("anyprefix:{3}:k", "VAL") assert res is True - assert res is True - assert pubsub.connection._should_reconnect assert pubsub.connection._sock is not None - assert pubsub.connection._socket_timeout is None - assert pubsub.connection._socket_connect_timeout is None + # validate timeout is still relaxed - it will be unrelaxed when this concrete connection + # will read the notification + assert 
pubsub.connection._socket_timeout == 30 + assert pubsub.connection._socket_connect_timeout == 30 - # first message will be SMIGRATED notification handling + # next message will be SMIGRATED notification handling # during this read connection will be reconnected and will resubscribe to channels msg = pubsub.get_message(ignore_subscribe_messages=True, timeout=10) assert msg is None diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index 47a27a48cf..f01637b757 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -11,6 +11,7 @@ NodeMigratedNotification, NodeFailingOverNotification, NodeFailedOverNotification, + NodesToSlotsMapping, OSSNodeMigratingNotification, OSSNodeMigratedNotification, MaintNotificationsConfig, @@ -493,7 +494,13 @@ class TestOSSNodeMigratedNotification: def test_init_with_defaults(self): """Test OSSNodeMigratedNotification initialization with default values.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = {"127.0.0.1:6380": "1-100"} + nodes_to_slots_mapping = [ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ] notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping ) @@ -505,10 +512,18 @@ def test_init_with_defaults(self): def test_init_with_all_parameters(self): """Test OSSNodeMigratedNotification initialization with all parameters.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = { - "127.0.0.1:6380": "1-100", - "127.0.0.1:6381": "101-200", - } + nodes_to_slots_mapping = [ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ), + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6381", + slots="101-200", + ), + ] notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping, @@ -520,16 +535,29 @@ def test_init_with_all_parameters(self): def test_default_ttl(self): """Test that DEFAULT_TTL is used correctly.""" - assert OSSNodeMigratedNotification.DEFAULT_TTL == 30 + assert OSSNodeMigratedNotification.DEFAULT_TTL == 120 notification = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) - assert notification.ttl == 30 + assert notification.ttl == 120 def test_repr(self): """Test OSSNodeMigratedNotification string representation.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = {"127.0.0.1:6380": "1-100"} + nodes_to_slots_mapping = [ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ] notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping, @@ -539,19 +567,31 @@ def test_repr(self): repr_str = repr(notification) assert "OSSNodeMigratedNotification" in repr_str assert "id=1" in repr_str - assert "ttl=30" in repr_str - assert "remaining=20.0s" in repr_str + assert "ttl=120" in repr_str + assert "remaining=110.0s" in repr_str assert "expired=False" in repr_str def test_equality_same_id_and_type(self): """Test equality for notifications with same id and type.""" notification1 = 
OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"}, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification2 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"}, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6381", + slots="101-200", + ) + ], ) # Should be equal because id and type are the same assert notification1 == notification2 @@ -559,17 +599,38 @@ def test_equality_same_id_and_type(self): def test_equality_different_id(self): """Test inequality for notifications with different id.""" notification1 = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification2 = OSSNodeMigratedNotification( - id=2, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=2, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) assert notification1 != notification2 def test_equality_different_type(self): """Test inequality for notifications of different types.""" notification1 = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification2 = NodeMigratedNotification(id=1) assert notification1 != notification2 @@ -578,11 +639,23 @@ def test_hash_same_id_and_type(self): """Test hash for notifications with same id and type.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"}, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification2 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"}, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6381", + slots="101-200", + ) + ], ) # Should have same hash because id and type are the same assert hash(notification1) == hash(notification2) @@ -590,26 +663,68 @@ def test_hash_same_id_and_type(self): def test_hash_different_id(self): """Test hash for notifications with different id.""" notification1 = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification2 = OSSNodeMigratedNotification( - id=2, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=2, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) assert hash(notification1) != hash(notification2) def test_in_set(self): """Test that notifications can be used in sets.""" notification1 = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", 
+ slots="1-100", + ) + ], ) notification2 = OSSNodeMigratedNotification( - id=1, nodes_to_slots_mapping={"127.0.0.1:6380": "1-100"} + id=1, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6380", + slots="1-100", + ) + ], ) notification3 = OSSNodeMigratedNotification( - id=2, nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"} + id=2, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6381", + slots="101-200", + ) + ], ) notification4 = OSSNodeMigratedNotification( - id=2, nodes_to_slots_mapping={"127.0.0.1:6381": "101-200"} + id=2, + nodes_to_slots_mapping=[ + NodesToSlotsMapping( + src_node_address="127.0.0.1:6379", + dest_node_address="127.0.0.1:6381", + slots="101-200", + ) + ], ) notification_set = {notification1, notification2, notification3, notification4} diff --git a/tests/maint_notifications/test_maint_notifications_handling.py b/tests/maint_notifications/test_maint_notifications_handling.py index a61106dae1..19f0d378b7 100644 --- a/tests/maint_notifications/test_maint_notifications_handling.py +++ b/tests/maint_notifications/test_maint_notifications_handling.py @@ -2076,10 +2076,8 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): ) # validate free connections for ip1 changed_free_connections = 0 - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections + free_connections = pool._get_free_connections() + for conn in free_connections: if conn.host == new_ip: changed_free_connections += 1 @@ -2126,10 +2124,8 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): ) # validate free connections for ip2 changed_free_connections = 0 - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections + free_connections = pool._get_free_connections() + for conn in free_connections: if conn.host == new_ip_2: changed_free_connections += 1 diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 08f3e69dbd..6549d989ca 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -47,6 +47,15 @@ def use_mock_proxy(): return os.getenv("REDIS_ENTERPRISE_TESTS", "true").lower() == "false" +# Module-level singleton for fault injector client used in parametrize +# This ensures we create only ONE instance that's shared between parametrize and fixture +_FAULT_INJECTOR_CLIENT_OSS_API = ( + ProxyServerFaultInjector(oss_cluster=True) + if use_mock_proxy() + else REFaultInjector(os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")) +) + + @pytest.fixture() def endpoint_name(request): return request.config.getoption("--endpoint-name") or os.getenv( @@ -122,11 +131,8 @@ def fault_injector_client(): @pytest.fixture() def fault_injector_client_oss_api(): - if use_mock_proxy(): - return ProxyServerFaultInjector(oss_cluster=True) - else: - url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") - return REFaultInjector(url) + """Return the singleton instance to ensure parametrize and tests use the same client.""" + return _FAULT_INJECTOR_CLIENT_OSS_API @pytest.fixture() diff --git a/tests/test_scenario/fault_injector_client.py 
b/tests/test_scenario/fault_injector_client.py index c74556db21..de76a7b512 100644 --- a/tests/test_scenario/fault_injector_client.py +++ b/tests/test_scenario/fault_injector_client.py @@ -164,6 +164,13 @@ def execute_rebind( def get_moving_ttl(self) -> int: pass + @abstractmethod + def get_slot_migrate_triggers( + self, + effect_name: SlotMigrateEffects, + ) -> Dict[str, Any]: + pass + @abstractmethod def trigger_effect( self, @@ -313,17 +320,19 @@ def delete_database( ) -> Dict[str, Any]: logging.debug(f"Deleting database with id: {bdb_id}") params = {"bdb_id": bdb_id} - create_db_action = ActionRequest( + delete_db_action = ActionRequest( action_type=ActionType.DELETE_DATABASE, parameters=params, ) - result = self.trigger_action(create_db_action) + result = self.trigger_action(delete_db_action) action_id = result.get("action_id") if not action_id: raise Exception(f"Failed to trigger delete database action: {result}") action_status_check_response = self.get_operation_result(action_id) + self._current_db_id = None + if action_status_check_response.get("status") != TaskStatuses.SUCCESS: raise Exception( f"Delete database action failed: {action_status_check_response}" @@ -675,6 +684,15 @@ def execute_rebind( except Exception as e: raise Exception(f"Failed to execute rladmin bind endpoint: {e}") + def get_slot_migrate_triggers( + self, + effect_name: SlotMigrateEffects, + ) -> Dict[str, Any]: + """Get available triggers(trigger name + db example config) for a slot migration effect.""" + return self._make_request( + "GET", f"/slot-migrate?effect={effect_name.value}&cluster_index=0" + ) + def trigger_effect( self, endpoint_config: Dict[str, Any], @@ -696,7 +714,7 @@ def trigger_effect( "bdb_id": bdb_id, "cluster_index": cluster_index, "effect": effect_name, - "variant": trigger_name, # will be renamed to trigger + "trigger": trigger_name, } if source_node: parameters["source_node"] = source_node @@ -890,7 +908,7 @@ def execute_failover( f"FAILING_OVER {self._get_seq_id()} 2 [1]" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, start_maint_notif) + self.proxy_helper.send_notification(start_maint_notif) # sleep to allow the client to receive the notification time.sleep(self.SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS) @@ -913,7 +931,7 @@ def execute_failover( end_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( f"FAILED_OVER {self._get_seq_id()} [1]" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, end_maint_notif) + self.proxy_helper.send_notification(end_maint_notif) return {"status": "done"} @@ -944,7 +962,7 @@ def execute_migrate( f"MIGRATING {self._get_seq_id()} 2 [1]" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, start_maint_notif) + self.proxy_helper.send_notification(start_maint_notif) # sleep to allow the client to receive the notification time.sleep(self.SLEEP_TIME_BETWEEN_START_END_NOTIFICATIONS) @@ -964,13 +982,13 @@ def execute_migrate( end_maint_notif = RespTranslator.oss_maint_notification_to_resp( f"SMIGRATED {self._get_seq_id()} 127.0.0.1:{self.NODE_PORT_2} 0-200" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, end_maint_notif) + self.proxy_helper.send_notification(end_maint_notif) else: # send migrated end_maint_notif = RespTranslator.re_cluster_maint_notification_to_resp( f"MIGRATED {self._get_seq_id()} [1]" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, end_maint_notif) + self.proxy_helper.send_notification(end_maint_notif) return "done" @@ -994,7 +1012,7 @@ def execute_rebind(self, endpoint_config: 
Dict[str, Any], endpoint_id: str) -> s maint_start_notif = RespTranslator.re_cluster_maint_notification_to_resp( f"MOVING {self._get_seq_id()} {sleep_time} 127.0.0.1:{self.NODE_PORT_3}" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, maint_start_notif) + self.proxy_helper.send_notification(maint_start_notif) # sleep to allow the client to receive the notification time.sleep(sleep_time) @@ -1012,7 +1030,7 @@ def execute_rebind(self, endpoint_config: Dict[str, Any], endpoint_id: str) -> s smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( f"SMIGRATED {self._get_seq_id()} 127.0.0.1:{self.NODE_PORT_3} 0-8191" ) - self.proxy_helper.send_notification(self.NODE_PORT_1, smigrated_node_1) + self.proxy_helper.send_notification(smigrated_node_1) else: # TODO drop connections to node 1 to simulate that the node is removed pass @@ -1022,6 +1040,12 @@ def execute_rebind(self, endpoint_config: Dict[str, Any], endpoint_id: str) -> s def get_moving_ttl(self) -> int: return self.MOVING_TTL + def get_slot_migrate_triggers( + self, + effect_name: SlotMigrateEffects, + ) -> Dict[str, Any]: + raise NotImplementedError("Not implemented for proxy server") + def trigger_effect( self, endpoint_config: Dict[str, Any], diff --git a/tests/test_scenario/maint_notifications_helpers.py b/tests/test_scenario/maint_notifications_helpers.py index 68588beadf..eb23f7cc7b 100644 --- a/tests/test_scenario/maint_notifications_helpers.py +++ b/tests/test_scenario/maint_notifications_helpers.py @@ -162,6 +162,14 @@ def execute_rebind( """Execute rladmin bind endpoint command and wait for completion.""" return fault_injector.execute_rebind(endpoint_config, endpoint_id) + @staticmethod + def get_slot_migrate_triggers( + fault_injector: FaultInjectorClient, + effect_name: SlotMigrateEffects, + ) -> Dict[str, Any]: + """Get available triggers(trigger name + db example config) for a slot migration effect.""" + return fault_injector.get_slot_migrate_triggers(effect_name) + @staticmethod def trigger_effect( fault_injector: FaultInjectorClient, diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index 75575b10a3..e51eb55d1c 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -1,12 +1,14 @@ """Tests for Redis Enterprise moving push notifications with real cluster operations.""" from concurrent.futures import ThreadPoolExecutor +import json import logging +import random from queue import Queue from threading import Thread import threading import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Union import pytest @@ -20,8 +22,8 @@ from tests.test_scenario.conftest import ( CLIENT_TIMEOUT, RELAXED_TIMEOUT, + _FAULT_INJECTOR_CLIENT_OSS_API, _get_client_maint_notifications, - get_bdbs_config, get_cluster_client_maint_notifications, use_mock_proxy, ) @@ -52,6 +54,8 @@ SLOT_SHUFFLE_TIMEOUT = 120 DEFAULT_BIND_TTL = 15 +DEFAULT_STANDALONE_CLIENT_SOCKET_TIMEOUT = 1 +DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT = 1 class TestPushNotificationsBase: @@ -194,42 +198,45 @@ def _validate_moving_state( """Validate the client connections are in the expected state after migration.""" matching_connected_conns_count = 0 matching_disconnected_conns_count = 0 - connections = self._get_all_connections_in_pool(client) - for conn in connections: - endpoint_configured_correctly = bool( - ( - configured_endpoint_type == EndpointType.NONE - and conn.host == 
conn.orig_host_address - ) - or ( - configured_endpoint_type != EndpointType.NONE - and conn.host != conn.orig_host_address - and ( - configured_endpoint_type - == MaintNotificationsConfig().get_endpoint_type(conn.host, conn) + with client.connection_pool._lock: + connections = self._get_all_connections_in_pool(client) + for conn in connections: + endpoint_configured_correctly = bool( + ( + configured_endpoint_type == EndpointType.NONE + and conn.host == conn.orig_host_address + ) + or ( + configured_endpoint_type != EndpointType.NONE + and conn.host != conn.orig_host_address + and ( + configured_endpoint_type + == MaintNotificationsConfig().get_endpoint_type( + conn.host, conn + ) + ) ) + or isinstance( + fault_injector_client, ProxyServerFaultInjector + ) # we should not validate the endpoint type when using proxy server ) - or isinstance( - fault_injector_client, ProxyServerFaultInjector - ) # we should not validate the endpoint type when using proxy server - ) - if ( - conn._sock is not None - and conn._sock.gettimeout() == RELAXED_TIMEOUT - and conn.maintenance_state == MaintenanceState.MOVING - and endpoint_configured_correctly - ): - matching_connected_conns_count += 1 - elif ( - conn._sock is None - and conn.maintenance_state == MaintenanceState.MOVING - and conn.socket_timeout == RELAXED_TIMEOUT - and endpoint_configured_correctly - ): - matching_disconnected_conns_count += 1 - else: - pass + if ( + conn._sock is not None + and conn._sock.gettimeout() == RELAXED_TIMEOUT + and conn.maintenance_state == MaintenanceState.MOVING + and endpoint_configured_correctly + ): + matching_connected_conns_count += 1 + elif ( + conn._sock is None + and conn.maintenance_state == MaintenanceState.MOVING + and conn.socket_timeout == RELAXED_TIMEOUT + and endpoint_configured_correctly + ): + matching_disconnected_conns_count += 1 + else: + pass assert matching_connected_conns_count == expected_matching_connected_conns_count assert ( matching_disconnected_conns_count @@ -239,7 +246,7 @@ def _validate_moving_state( def _validate_default_state( self, client: Redis, - expected_matching_conns_count: int, + expected_matching_conns_count: Union[int, Literal["all"]], configured_timeout: float = CLIENT_TIMEOUT, ): """Validate the client connections are in the expected state after migration.""" @@ -284,6 +291,9 @@ def _validate_default_state( client_host = conn_kwargs.get("host", "unknown") client_port = conn_kwargs.get("port", "unknown") + if expected_matching_conns_count == "all": + expected_matching_conns_count = len(connections) + assert matching_conns_count == expected_matching_conns_count, ( f"Default state validation failed. 
" f"Client: host={client_host}, port={client_port}, " @@ -1158,7 +1168,7 @@ def test_command_execution_during_migrating_and_moving( else: execution_duration = 180 - socket_timeout = 0.5 + socket_timeout = DEFAULT_STANDALONE_CLIENT_SOCKET_TIMEOUT client = _get_client_maint_notifications( endpoints_config=endpoints_config, @@ -1196,6 +1206,9 @@ def execute_commands(duration: int, errors: Queue): thread.start() threads.append(thread) + logging.info("Waiting for threads to start and have a few cycles executed ...") + time.sleep(3) + migrate_and_bind_thread = Thread( target=self._execute_migrate_bind_flow, name="migrate_and_bind_thread", @@ -1289,217 +1302,16 @@ def setup_and_cleanup( logging.info("Cleanup finished") @pytest.mark.timeout(300) # 5 minutes timeout for this test - def test_notification_handling_during_node_fail_over( - self, - fault_injector_client_oss_api: FaultInjectorClient, - ): - """ - Test the push notifications are received when executing re cluster operations. - - """ - logging.info("Creating one connection in the pool.") - # get the node covering first shard - it is the node we will failover - target_node = ( - self.cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) - ) - logging.info(f"Target node for slot 0: {target_node.name}") - conn = target_node.redis_connection.connection_pool.get_connection() - cluster_nodes = ( - self.cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() - ) - - logging.info("Executing failover command...") - failover_thread = Thread( - target=self._execute_failover, - name="failover_thread", - args=(fault_injector_client_oss_api, self.cluster_endpoint_config), - ) - self.maintenance_ops_threads.append(failover_thread) - failover_thread.start() - - logging.info("Waiting for SMIGRATING push notifications...") - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATING_TIMEOUT, - connection=conn, - ) - - logging.info("Validating connection maintenance state...") - assert conn.maintenance_state == MaintenanceState.MAINTENANCE - assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.should_reconnect() is False - - assert len(cluster_nodes) == len( - self.cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for node_key in cluster_nodes.keys(): - assert ( - node_key - in self.cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - - logging.info("Waiting for SMIGRATED push notifications...") - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATED_TIMEOUT, - connection=conn, - ) - - logging.info("Validating connection state after SMIGRATED ...") - # connection will be dropped, but it is marked - # to be disconnected before released to the pool - # we don't waste time to update the timeouts and state - # so it is pointless to check those configs - assert conn.should_reconnect() is True - - # validate that the node was removed from the cluster - # for re clusters we don't receive the replica nodes, - # so after failover the node is removed from the cluster - # and the previous replica that is promoted to primary is added as a new node - - # the overall number of nodes should be the same - one removed and one added - # when I have a db with two shards only and perform a failover on one of them, I can edn up with just one node that holds both shards - # assert len(cluster_nodes) == len( - # cluster_client_maint_notifications.nodes_manager.nodes_cache - # ) - assert ( - target_node.name - not 
in self.cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - - logging.info("Releasing connection back to the pool...") - target_node.redis_connection.connection_pool.release(conn) - - failover_thread.join() - self.maintenance_ops_threads.remove(failover_thread) - - @pytest.mark.timeout(300) # 5 minutes timeout for this test - def test_command_execution_during_node_fail_over( - self, - fault_injector_client_oss_api: FaultInjectorClient, - ): - """ - Test the push notifications are received when executing re cluster operations. - - """ - - errors = Queue() - if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): - execution_duration = 20 - else: - execution_duration = 180 - - socket_timeout = 3 - - logging.info("Creating client with disabled retries.") - cluster_client_maint_notifications = get_cluster_client_maint_notifications( - endpoints_config=self.cluster_endpoint_config, - disable_retries=False, - socket_timeout=socket_timeout, - enable_maintenance_notifications=True, - ) - # executing initial commands to consume old notifications - cluster_client_maint_notifications.set("key:{3}", "value") - cluster_client_maint_notifications.set("key:{0}", "value") - logging.info("Cluster client created and initialized.") - - def execute_commands(duration: int, errors: Queue): - start = time.time() - while time.time() - start < duration: - try: - # the slot is covered by the first shard - this one will failover - cluster_client_maint_notifications.set("key:{3}", "value") - cluster_client_maint_notifications.get("key:{3}") - except Exception as e: - logging.error( - f"Error in thread {threading.current_thread().name} for key on first shard: {e}" - ) - errors.put( - f"Command failed in thread {threading.current_thread().name} for key on first shard: {e}" - ) - try: - # execute also commands that will run on the second shard - cluster_client_maint_notifications.set("key:{0}", "value") - cluster_client_maint_notifications.get("key:{0}") - except Exception as e: - logging.error( - f"Error in thread {threading.current_thread().name} for key on second shard: {e}" - ) - errors.put( - f"Command failed in thread {threading.current_thread().name} for key on second shard: {e}" - ) - logging.debug(f"{threading.current_thread().name}: Thread ended") - - # get the node covering first shard - it is the node we will failover - target_node = ( - cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) - ) - cluster_nodes = ( - cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() - ) - - threads = [] - for i in range(10): - thread = Thread( - target=execute_commands, - name=f"command_execution_thread_{i}", - args=( - execution_duration, - errors, - ), - ) - thread.start() - threads.append(thread) - - logging.info("Executing failover command...") - failover_thread = Thread( - target=self._execute_failover, - name="failover_thread", - args=(fault_injector_client_oss_api, self.cluster_endpoint_config), - ) - self.maintenance_ops_threads.append(failover_thread) - failover_thread.start() - - for thread in threads: - thread.join() - - failover_thread.join() - self.maintenance_ops_threads.remove(failover_thread) - - # validate that the failed_over primary node was removed from the cluster - # for re clusters we don't receive the replica nodes, - # so after failover the node is removed from the cluster - # and the previous replica that is promoted to primary is added as a new node - - # the overall number of nodes should be the same - one removed and one added - 
assert len(cluster_nodes) == len( - cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - assert ( - target_node.name - not in cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - - for ( - node - ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): - # validate connections settings - self._validate_default_state( - node.redis_connection, - expected_matching_conns_count=10, - configured_timeout=socket_timeout, - ) - - # validate no errors were raised in the command execution threads - assert errors.empty(), f"Errors occurred in threads: {errors.queue}" - - @pytest.mark.timeout(300) # 5 minutes timeout for this test - def test_notification_handling_during_migration_with_node_replacement( + @pytest.mark.skipif( + use_mock_proxy(), + reason="Mock proxy doesn't support sending notifications to new connections.", + ) + def test_new_connections_receive_last_notification_with_migrating( self, fault_injector_client_oss_api: FaultInjectorClient, ): """ - Test the push notifications are received when executing re cluster operations. + Test the push notifications are sent to the newly created connections. """ cluster_op_target_node, cluster_op_empty_node = ( @@ -1507,7 +1319,6 @@ def test_notification_handling_during_migration_with_node_replacement( fault_injector_client_oss_api, self.cluster_endpoint_config ) ) - db_port = ( self.cluster_endpoint_config["raw_endpoints"][0]["port"] if self.cluster_endpoint_config @@ -1522,14 +1333,11 @@ def test_notification_handling_during_migration_with_node_replacement( f"Creating one connection in the pool using node {target_node.name}." ) conn = target_node.redis_connection.connection_pool.get_connection() - cluster_nodes = ( - self.cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() - ) - logging.info("Executing migrate flow ...") + logging.info("Executing migrate all data from one node to another ...") migrate_thread = Thread( target=self._execute_migration, - name="migrate", + name="migrate_thread", args=( fault_injector_client_oss_api, self.cluster_endpoint_config, @@ -1537,32 +1345,40 @@ def test_notification_handling_during_migration_with_node_replacement( cluster_op_empty_node.node_id, ), ) + self.maintenance_ops_threads.append(migrate_thread) migrate_thread.start() - time.sleep(20) - logging.info("Waiting for SMIGRATING push notifications...") + logging.info( + f"Waiting for SMIGRATING push notifications with the existing connection: {conn}..." + ) ClientValidations.wait_push_notification( self.cluster_client_maint_notifications, timeout=SMIGRATING_TIMEOUT, connection=conn, ) - logging.info("Validating connection maintenance state...") + new_conn = target_node.redis_connection.connection_pool.get_connection() + logging.info( + f"Validating newly created connection will also receive the notification: {new_conn}..." 
+ ) + ClientValidations.wait_push_notification( + self.cluster_client_maint_notifications, + timeout=1, # the notification should have already been sent once, so new conn should receive it almost immediately + connection=new_conn, + fail_on_timeout=False, + ) + + logging.info("Validating connections maintenance state...") assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn._sock.gettimeout() == RELAXED_TIMEOUT assert conn.should_reconnect() is False - assert len(cluster_nodes) == len( - self.cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for node_key in cluster_nodes.keys(): - assert ( - node_key - in self.cluster_client_maint_notifications.nodes_manager.nodes_cache - ) + assert new_conn.maintenance_state == MaintenanceState.MAINTENANCE + assert new_conn._sock.gettimeout() == RELAXED_TIMEOUT + assert new_conn.should_reconnect() is False - logging.info("Waiting for SMIGRATED push notifications...") + logging.info(f"Waiting for SMIGRATED push notifications with {conn}...") ClientValidations.wait_push_notification( self.cluster_client_maint_notifications, timeout=SMIGRATED_TIMEOUT, @@ -1572,467 +1388,570 @@ def test_notification_handling_during_migration_with_node_replacement( logging.info("Validating connection state after SMIGRATED ...") assert conn.should_reconnect() is True + assert new_conn.should_reconnect() is True - # the overall number of nodes should be the same - one removed and one added - assert len(cluster_nodes) == len( - self.cluster_client_maint_notifications.nodes_manager.nodes_cache + new_conn_after_smigrated = ( + target_node.redis_connection.connection_pool.get_connection() ) - assert ( - target_node.name - not in self.cluster_client_maint_notifications.nodes_manager.nodes_cache + assert new_conn_after_smigrated.maintenance_state == MaintenanceState.NONE + assert new_conn_after_smigrated._sock.gettimeout() == CLIENT_TIMEOUT + assert not new_conn_after_smigrated.should_reconnect() + + logging.info( + f"Waiting for SMIGRATED push notifications with another new connection: {new_conn_after_smigrated}..." 
+ ) + ClientValidations.wait_push_notification( + self.cluster_client_maint_notifications, + timeout=1, + connection=new_conn_after_smigrated, + fail_on_timeout=False, ) - logging.info("Releasing connection back to the pool...") + logging.info("Releasing connections back to the pool...") target_node.redis_connection.connection_pool.release(conn) + target_node.redis_connection.connection_pool.release(new_conn) + target_node.redis_connection.connection_pool.release(new_conn_after_smigrated) migrate_thread.join() self.maintenance_ops_threads.remove(migrate_thread) - @pytest.mark.timeout(300) # 5 minutes timeout for this test - def test_command_execution_during_migration_with_node_replacement( + +def generate_params( + fault_injector_client: FaultInjectorClient, + effect_names: list[SlotMigrateEffects], +): + # params should produce list of tuples: (effect_name, trigger_name, bdb_config, bdb_name) + params = [] + logging.info(f"Extracting params for test with effect_names: {effect_names}") + for effect_name in effect_names: + triggers_data = ClusterOperations.get_slot_migrate_triggers( + fault_injector_client, effect_name + ) + + for trigger_info in triggers_data["triggers"]: + trigger = trigger_info["name"] + if trigger == "maintenance_mode": + continue + trigger_requirements = trigger_info["requirements"] + for requirement in trigger_requirements: + dbconfig = requirement["dbconfig"] + ip_type = requirement["oss_cluster_api"]["ip_type"] + if ip_type == "internal": + continue + db_name_pattern = dbconfig.get("name").rsplit("-", 1)[0] + dbconfig["name"] = ( + db_name_pattern # this will ensure dbs will be deleted + ) + + params.append((effect_name, trigger, dbconfig, db_name_pattern)) + + return params + + +class TestClusterClientPushNotificationsWithEffectTriggerBase( + TestPushNotificationsBase +): + def delete_prev_db( self, fault_injector_client_oss_api: FaultInjectorClient, + db_name: str, ): - """ - Test the push notifications are received when executing re cluster operations. 
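generate_params only reads a handful of keys from the fault injector's response, and the exact payload is defined by the service, so the shape below is an assumption inferred from those reads, with made-up values:

# Assumed response of get_slot_migrate_triggers(); only the keys that
# generate_params() touches are shown.
triggers_data = {
    "triggers": [
        {
            "name": "migrate",
            "requirements": [
                {
                    "dbconfig": {
                        "name": "maint-oss-api-slot-shuffle-1",
                        "shards_count": 3,
                    },
                    "oss_cluster_api": {"ip_type": "external"},
                },
                {
                    "dbconfig": {
                        "name": "maint-oss-api-slot-shuffle-2",
                        "shards_count": 3,
                    },
                    "oss_cluster_api": {"ip_type": "internal"},  # skipped: internal IP
                },
            ],
        },
        {"name": "maintenance_mode", "requirements": []},  # skipped by name
    ]
}

# For the first requirement, generate_params() would yield roughly:
# (SlotMigrateEffects.SLOT_SHUFFLE, "migrate",
#  {"name": "maint-oss-api-slot-shuffle", "shards_count": 3},
#  "maint-oss-api-slot-shuffle")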
- """ - - cluster_op_target_node, cluster_op_empty_node = ( - self.extract_target_node_and_empty_node( - fault_injector_client_oss_api, self.cluster_endpoint_config + try: + logging.info(f"Deleting database if exists: {db_name}") + existing_db_id = None + existing_db_id = ClusterOperations.find_database_id_by_name( + fault_injector_client_oss_api, db_name ) - ) - errors = Queue() - if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): - execution_duration = 20 - else: - execution_duration = 180 - socket_timeout = 3 + if existing_db_id: + fault_injector_client_oss_api.delete_database(existing_db_id) + logging.info(f"Deleted database: {db_name}") + else: + logging.info(f"Database {db_name} does not exist.") + except Exception as e: + logging.error(f"Failed to delete database {db_name}: {e}") - cluster_client_maint_notifications = get_cluster_client_maint_notifications( - endpoints_config=self.cluster_endpoint_config, - disable_retries=True, - socket_timeout=socket_timeout, - enable_maintenance_notifications=True, + def create_db( + self, + fault_injector_client_oss_api: FaultInjectorClient, + bdb_config: Dict[str, Any], + ): + try: + logging.info(f"Creating database: \n{json.dumps(bdb_config, indent=2)}") + cluster_endpoint_config = fault_injector_client_oss_api.create_database( + bdb_config + ) + return cluster_endpoint_config + except Exception as e: + pytest.fail(f"Failed to create database: {e}") + + def setup_env( + self, + fault_injector_client_oss_api: FaultInjectorClient, + db_config: Dict[str, Any], + ): + self.delete_prev_db(fault_injector_client_oss_api, db_config["name"]) + + cluster_endpoint_config = self.create_db( + fault_injector_client_oss_api, db_config ) - def execute_commands(duration: int, errors: Queue): - start = time.time() - while time.time() - start < duration: - try: - # the slot is covered by the first shard - this one will have slots migrated - cluster_client_maint_notifications.set("key:{3}", "value") - cluster_client_maint_notifications.get("key:{3}") - # execute also commands that will run on the second shard - cluster_client_maint_notifications.set("key:{0}", "value") - cluster_client_maint_notifications.get("key:{0}") - except Exception as e: - logging.error( - f"Error in thread {threading.current_thread().name}: {e}" - ) - errors.put( - f"Command failed in thread {threading.current_thread().name}: {e}" - ) - logging.debug(f"{threading.current_thread().name}: Thread ended") + self._bdb_name = db_config["name"] + socket_timeout = DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT - db_port = ( - self.cluster_endpoint_config["raw_endpoints"][0]["port"] - if self.cluster_endpoint_config - else None + cluster_client_maint_notifications = get_cluster_client_maint_notifications( + endpoints_config=cluster_endpoint_config, + disable_retries=True, + socket_timeout=socket_timeout, + enable_maintenance_notifications=True, ) - # get the node that will be migrated - target_node = cluster_client_maint_notifications.nodes_manager.get_node( - host=cluster_op_target_node.external_address, - port=db_port, + return cluster_client_maint_notifications, cluster_endpoint_config + + @pytest.fixture(autouse=True) + def setup_and_cleanup( + self, + ): + self.maintenance_ops_threads = [] + self._bdb_name = None + + # Yield control to the test + yield + + # Cleanup code - this will run even if the test fails + logging.info("Starting cleanup...") + if self._bdb_name: + self.delete_prev_db(_FAULT_INJECTOR_CLIENT_OSS_API, self._bdb_name) + + logging.info("Waiting for maintenance 
operations threads to finish...") + for thread in self.maintenance_ops_threads: + thread.join() + + logging.info("Cleanup finished") + + +class TestClusterClientPushNotificationsHandlingWithEffectTrigger( + TestClusterClientPushNotificationsWithEffectTriggerBase +): + @pytest.mark.timeout(300) # 5 minutes timeout for this test + @pytest.mark.parametrize( + "effect_name, trigger, db_config, db_name", + generate_params( + _FAULT_INJECTOR_CLIENT_OSS_API, [SlotMigrateEffects.SLOT_SHUFFLE] + ), + ) + def test_notification_handling_during_node_shuffle_no_node_replacement( + self, + fault_injector_client_oss_api: FaultInjectorClient, + effect_name: SlotMigrateEffects, + trigger: str, + db_config: dict[str, Any], + db_name: str, + ): + """ + Test the push notifications are received when executing re cluster operations. + The test validates the behavior when during the operations the slots are moved + between the nodes, but no new nodes are appearing and no nodes are disappearing + + """ + logging.info(f"DB name: {db_name}") + + cluster_client_maint_notifications, cluster_endpoint_config = self.setup_env( + fault_injector_client_oss_api, db_config ) - cluster_nodes = ( + logging.info("Creating one connection in each node's pool.") + initial_cluster_nodes = ( cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() ) - - threads = [] - for i in range(10): - thread = Thread( - target=execute_commands, - name=f"command_execution_thread_{i}", - args=( - execution_duration, - errors, - ), + in_use_connections = {} + for node in initial_cluster_nodes.values(): + in_use_connections[node] = ( + node.redis_connection.connection_pool.get_connection() ) - thread.start() - threads.append(thread) - logging.info("Executing migration flow...") - migrate_thread = Thread( - target=self._execute_migration, - name="migration_thread", + logging.info("Executing FI command that triggers the desired effect...") + trigger_effect_thread = Thread( + target=self._trigger_effect, + name="trigger_effect_thread", args=( fault_injector_client_oss_api, - self.cluster_endpoint_config, - cluster_op_target_node.node_id, - cluster_op_empty_node.node_id, + cluster_endpoint_config, + effect_name, + trigger, ), ) - self.maintenance_ops_threads.append(migrate_thread) - migrate_thread.start() + self.maintenance_ops_threads.append(trigger_effect_thread) + trigger_effect_thread.start() - for thread in threads: - thread.join() + logging.info("Waiting for SMIGRATING push notifications in all connections...") + for conn in in_use_connections.values(): + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=int(SLOT_SHUFFLE_TIMEOUT / 2), + connection=conn, + ) - migrate_thread.join() - self.maintenance_ops_threads.remove(migrate_thread) + logging.info("Validating connection maintenance state...") + for conn in in_use_connections.values(): + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False - # validate cluster nodes - assert len(cluster_nodes) == len( + assert len(initial_cluster_nodes) == len( cluster_client_maint_notifications.nodes_manager.nodes_cache ) - assert ( - target_node.name - not in cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for ( - node - ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): - # validate connections settings - self._validate_default_state( - node.redis_connection, - expected_matching_conns_count=10, - 
configured_timeout=socket_timeout, + for node_key in initial_cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache ) - # validate no errors were raised in the command execution threads - assert errors.empty(), f"Errors occurred in threads: {errors.queue}" + logging.info("Waiting for SMIGRATED push notifications...") + con_to_read_smigrated = random.choice(list(in_use_connections.values())) + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATED_TIMEOUT, + connection=con_to_read_smigrated, + ) + + logging.info("Validating connection state after SMIGRATED ...") + + updated_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) + + removed_nodes = set(initial_cluster_nodes.values()) - set( + updated_cluster_nodes.values() + ) + assert len(removed_nodes) == 0 + assert len(initial_cluster_nodes) == len(updated_cluster_nodes) + + marked_conns_for_reconnect = 0 + for conn in in_use_connections.values(): + if conn.should_reconnect(): + marked_conns_for_reconnect += 1 + # only one connection should be marked for reconnect + # onle the one that belongs to the node that was from + # the src address of the maintenance + assert marked_conns_for_reconnect == 1 + + logging.info("Releasing connections back to the pool...") + for node, conn in in_use_connections.items(): + if node.redis_connection is None: + continue + node.redis_connection.connection_pool.release(conn) + + trigger_effect_thread.join() + self.maintenance_ops_threads.remove(trigger_effect_thread) @pytest.mark.timeout(300) # 5 minutes timeout for this test - @pytest.mark.skipif( - use_mock_proxy(), - reason="Mock proxy doesn't support sending notifications to new connections.", + @pytest.mark.parametrize( + "effect_name, trigger, db_config, db_name", + generate_params( + _FAULT_INJECTOR_CLIENT_OSS_API, + [ + SlotMigrateEffects.REMOVE_ADD, + ], + ), ) - def test_new_connections_receive_last_notification_with_migrating( + def test_notification_handling_with_node_replace( self, fault_injector_client_oss_api: FaultInjectorClient, + effect_name: SlotMigrateEffects, + trigger: str, + db_config: dict[str, Any], + db_name: str, ): """ - Test the push notifications are sent to the newly created connections. + Test the push notifications are received when executing re cluster operations. + The test validates the behavior when during the operations the slots are moved + between the nodes, and as a result a node is removed and a new node is added to the cluster """ - cluster_op_target_node, cluster_op_empty_node = ( - self.extract_target_node_and_empty_node( - fault_injector_client_oss_api, self.cluster_endpoint_config - ) - ) - db_port = ( - self.cluster_endpoint_config["raw_endpoints"][0]["port"] - if self.cluster_endpoint_config - else None - ) - # get the node that will be migrated - target_node = self.cluster_client_maint_notifications.nodes_manager.get_node( - host=cluster_op_target_node.external_address, - port=db_port, + logging.info(f"DB name: {db_name}") + + cluster_client_maint_notifications, cluster_endpoint_config = self.setup_env( + fault_injector_client_oss_api, db_config ) - logging.info( - f"Creating one connection in the pool using node {target_node.name}." 
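The "only one connection should be marked for reconnect" check above leans on the SMIGRATED payload carrying a source address per slot-range movement. A way to state that expectation explicitly, assuming only the notification fields and connection attributes these tests already use, might look like:

def connections_on_source_nodes(notification, in_use_connections):
    # Connections whose node shows up as a src address in the SMIGRATED
    # payload are the only ones expected to be flagged for reconnect.
    src_addresses = {m.src_node_address for m in notification.nodes_to_slots_mapping}
    return [
        conn
        for conn in in_use_connections.values()
        if f"{conn.host}:{conn.port}" in src_addresses
    ]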
+ + logging.info("Creating one connection in each node's pool.") + + initial_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() ) - conn = target_node.redis_connection.connection_pool.get_connection() + in_use_connections = {} + for node in initial_cluster_nodes.values(): + in_use_connections[node] = ( + node.redis_connection.connection_pool.get_connection() + ) - logging.info("Executing migrate all data from one node to another ...") - migrate_thread = Thread( - target=self._execute_migration, - name="migrate_thread", + logging.info("Executing FI command that triggers the desired effect...") + trigger_effect_thread = Thread( + target=self._trigger_effect, + name="trigger_effect_thread", args=( fault_injector_client_oss_api, - self.cluster_endpoint_config, - cluster_op_target_node.node_id, - cluster_op_empty_node.node_id, + cluster_endpoint_config, + effect_name, + trigger, ), ) + self.maintenance_ops_threads.append(trigger_effect_thread) + trigger_effect_thread.start() - self.maintenance_ops_threads.append(migrate_thread) - migrate_thread.start() + logging.info("Waiting for SMIGRATING push notifications in all connections...") + for conn in in_use_connections.values(): + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATING_TIMEOUT, + connection=conn, + ) - logging.info( - f"Waiting for SMIGRATING push notifications with the existing connection: {conn}..." - ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATING_TIMEOUT, - connection=conn, - ) + logging.info("Validating connection maintenance state...") + for conn in in_use_connections.values(): + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False - new_conn = target_node.redis_connection.connection_pool.get_connection() - logging.info( - f"Validating newly created connection will also receive the notification: {new_conn}..." 
- ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=1, # the notification should have already been sent once, so new conn should receive it almost immediately - connection=new_conn, - fail_on_timeout=False, + assert len(initial_cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache ) - logging.info("Validating connections maintenance state...") - assert conn.maintenance_state == MaintenanceState.MAINTENANCE - assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.should_reconnect() is False - - assert new_conn.maintenance_state == MaintenanceState.MAINTENANCE - assert new_conn._sock.gettimeout() == RELAXED_TIMEOUT - assert new_conn.should_reconnect() is False + for node_key in initial_cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) - logging.info(f"Waiting for SMIGRATED push notifications with {conn}...") + logging.info("Waiting for SMIGRATED push notifications...") + con_to_read_smigrated = random.choice(list(in_use_connections.values())) ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, + cluster_client_maint_notifications, timeout=SMIGRATED_TIMEOUT, - connection=conn, + connection=con_to_read_smigrated, ) logging.info("Validating connection state after SMIGRATED ...") - assert conn.should_reconnect() is True - assert new_conn.should_reconnect() is True - - new_conn_after_smigrated = ( - target_node.redis_connection.connection_pool.get_connection() + updated_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() ) - assert new_conn_after_smigrated.maintenance_state == MaintenanceState.NONE - assert new_conn_after_smigrated._sock.gettimeout() == CLIENT_TIMEOUT - assert not new_conn_after_smigrated.should_reconnect() - logging.info( - f"Waiting for SMIGRATED push notifications with another new connection: {new_conn_after_smigrated}..." 
+ removed_nodes = set(initial_cluster_nodes.values()) - set( + updated_cluster_nodes.values() ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=1, - connection=new_conn_after_smigrated, - fail_on_timeout=False, + assert len(removed_nodes) == 1 + removed_node = removed_nodes.pop() + assert removed_node is not None + + added_nodes = set(updated_cluster_nodes.values()) - set( + initial_cluster_nodes.values() ) + assert len(added_nodes) == 1 + + conn = in_use_connections.get(removed_node) + # connection will be dropped, but it is marked + # to be disconnected before released to the pool + # we don't waste time to update the timeouts and state + # so it is pointless to check those configs + assert conn is not None + assert conn.should_reconnect() is True logging.info("Releasing connections back to the pool...") - target_node.redis_connection.connection_pool.release(conn) - target_node.redis_connection.connection_pool.release(new_conn) - target_node.redis_connection.connection_pool.release(new_conn_after_smigrated) + for node, conn in in_use_connections.items(): + if node.redis_connection is None: + continue + node.redis_connection.connection_pool.release(conn) - migrate_thread.join() - self.maintenance_ops_threads.remove(migrate_thread) + trigger_effect_thread.join() + self.maintenance_ops_threads.remove(trigger_effect_thread) @pytest.mark.timeout(300) # 5 minutes timeout for this test - @pytest.mark.skipif( - use_mock_proxy(), - reason="Mock proxy doesn't support sending notifications to new connections.", + @pytest.mark.parametrize( + "effect_name, trigger, db_config, db_name", + generate_params( + _FAULT_INJECTOR_CLIENT_OSS_API, + [ + SlotMigrateEffects.REMOVE, + ], + ), ) - def test_new_connections_receive_last_notification_with_failover( + def test_notification_handling_with_node_remove( self, fault_injector_client_oss_api: FaultInjectorClient, + effect_name: SlotMigrateEffects, + trigger: str, + db_config: dict[str, Any], + db_name: str, ): """ - Test the push notifications are sent to the newly created connections. + Test the push notifications are received when executing re cluster operations. + The test validates the behavior when during the operations the slots are moved + between the nodes, and as a result a node is removed. """ - # get the node that will be migrated - target_node = ( - self.cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) + logging.info(f"DB name: {db_name}") + + cluster_client_maint_notifications, cluster_endpoint_config = self.setup_env( + fault_injector_client_oss_api, db_config ) - logging.info( - f"Creating one connection in the pool using node {target_node.name}." 
+ logging.info("Creating one connection in each node's pool.") + + initial_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() ) - conn = target_node.redis_connection.connection_pool.get_connection() - logging.info(f"Connection conn: {conn._get_socket().getsockname()}") + in_use_connections = {} + for node in initial_cluster_nodes.values(): + in_use_connections[node] = ( + node.redis_connection.connection_pool.get_connection() + ) - logging.info("Trigerring failover for node covering first shard...") - failover_thread = Thread( - target=self._execute_failover, - name="failover_thread", - args=(fault_injector_client_oss_api, self.cluster_endpoint_config), + logging.info("Executing FI command that triggers the desired effect...") + trigger_effect_thread = Thread( + target=self._trigger_effect, + name="trigger_effect_thread", + args=( + fault_injector_client_oss_api, + cluster_endpoint_config, + effect_name, + trigger, + ), ) + self.maintenance_ops_threads.append(trigger_effect_thread) + trigger_effect_thread.start() - self.maintenance_ops_threads.append(failover_thread) - failover_thread.start() + logging.info("Waiting for SMIGRATING push notifications in all connections...") + for conn in in_use_connections.values(): + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=int(SLOT_SHUFFLE_TIMEOUT / 2), + connection=conn, + ) - logging.info( - f"Waiting for SMIGRATING push notifications with the existing connection: {conn}, {conn._get_socket().getsockname()}..." - ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATING_TIMEOUT, - connection=conn, - ) + logging.info("Validating connection maintenance state...") + for conn in in_use_connections.values(): + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False - logging.info( - f"Creating another connection in the pool using node {target_node.name}. " - "Validating it will also receive the notification(should be received as part of the connection setup)..." 
+ assert len(initial_cluster_nodes) == len( + cluster_client_maint_notifications.nodes_manager.nodes_cache ) - new_conn = target_node.redis_connection.connection_pool.get_connection() - - logging.info("Validating connections maintenance state...") - assert conn.maintenance_state == MaintenanceState.MAINTENANCE - assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.should_reconnect() is False - assert new_conn.maintenance_state == MaintenanceState.MAINTENANCE - assert new_conn._sock.gettimeout() == RELAXED_TIMEOUT - assert new_conn.should_reconnect() is False + for node_key in initial_cluster_nodes.keys(): + assert ( + node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache + ) - logging.info(f"Waiting for SMIGRATED push notifications with {conn}...") + logging.info("Waiting for SMIGRATED push notifications...") + con_to_read_smigrated = random.choice(list(in_use_connections.values())) ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, + cluster_client_maint_notifications, timeout=SMIGRATED_TIMEOUT, - connection=conn, + connection=con_to_read_smigrated, ) logging.info("Validating connection state after SMIGRATED ...") - assert conn.should_reconnect() is True - assert new_conn.should_reconnect() is True - - new_conn_after_smigrated = ( - target_node.redis_connection.connection_pool.get_connection() + updated_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() ) - # TODO check what would be the correct behaviour here !!! The SMIGRATED is alreday handled soon - # so we don't fix the connections again, then what should be validated???? - # Maybe new client instance???? - assert new_conn_after_smigrated.maintenance_state == MaintenanceState.NONE - assert new_conn_after_smigrated._sock.gettimeout() == CLIENT_TIMEOUT - assert not new_conn_after_smigrated.should_reconnect() - - logging.info("Releasing connections back to the pool...") - target_node.redis_connection.connection_pool.release(conn) - target_node.redis_connection.connection_pool.release(new_conn) - target_node.redis_connection.connection_pool.release(new_conn_after_smigrated) - - failover_thread.join() - self.maintenance_ops_threads.remove(failover_thread) - -class TestClusterClientPushNotificationsWithEffectTrigger(TestPushNotificationsBase): - def extract_target_node_and_empty_node( - self, fault_injector_client, endpoints_config - ): - target_node, empty_node = ClusterOperations.find_target_node_and_empty_node( - fault_injector_client, endpoints_config + removed_nodes = set(initial_cluster_nodes.values()) - set( + updated_cluster_nodes.values() ) - logging.info(f"Using target_node: {target_node}, empty_node: {empty_node}") - return target_node, empty_node - - def delete_prev_db( - self, - fault_injector_client_oss_api: FaultInjectorClient, - db_name: str, - ): - try: - logging.info(f"Deleting database if exists: {db_name}") - existing_db_id = None - existing_db_id = ClusterOperations.find_database_id_by_name( - fault_injector_client_oss_api, db_name - ) - - if existing_db_id: - fault_injector_client_oss_api.delete_database(existing_db_id) - logging.info(f"Deleted database: {db_name}") - else: - logging.info(f"Database {db_name} does not exist.") - except Exception as e: - logging.error(f"Failed to delete database {db_name}: {e}") + assert len(removed_nodes) == 1 + removed_node = removed_nodes.pop() + assert removed_node is not None - def create_db( - self, - fault_injector_client_oss_api: FaultInjectorClient, - bdb_config: Dict[str, 
Any], - ): - try: - cluster_endpoint_config = fault_injector_client_oss_api.create_database( - bdb_config - ) - logging.info(f"Created database: {bdb_config['name']}") - return cluster_endpoint_config - except Exception as e: - pytest.fail(f"Failed to create database: {e}") + assert len(initial_cluster_nodes) == len(updated_cluster_nodes) + 1 - @pytest.fixture(autouse=True) - def setup_and_cleanup( - self, - fault_injector_client_oss_api: FaultInjectorClient, - ): - self.maintenance_ops_threads = [] + conn = in_use_connections.get(removed_node) + # connection will be dropped, but it is marked + # to be disconnected before released to the pool + # we don't waste time to update the timeouts and state + # so it is pointless to check those configs + assert conn is not None + assert conn.should_reconnect() is True - # Yield control to the test - yield + # validate no other connections are marked for reconnect + marked_conns_for_reconnect = 0 + for conn in in_use_connections.values(): + if conn.should_reconnect(): + marked_conns_for_reconnect += 1 + # only one connection should be marked for reconnect + # onle the one that belongs to the node that was from + # the src address of the maintenance + assert marked_conns_for_reconnect == 1 - # Cleanup code - this will run even if the test fails - logging.info("Starting cleanup...") + logging.info("Releasing connections back to the pool...") + for node, conn in in_use_connections.items(): + if node.redis_connection is None: + continue + node.redis_connection.connection_pool.release(conn) - logging.info("Waiting for maintenance operations threads to finish...") - for thread in self.maintenance_ops_threads: - thread.join() + trigger_effect_thread.join() + self.maintenance_ops_threads.remove(trigger_effect_thread) - logging.info("Cleanup finished") +class TestClusterClientCommandsExecutionWithPushNotificationsWithEffectTrigger( + TestClusterClientPushNotificationsWithEffectTriggerBase +): @pytest.mark.timeout(300) # 5 minutes timeout for this test @pytest.mark.parametrize( - "effect_name, trigger, db_config_name", - [ - ( + "effect_name, trigger, db_config, db_name", + generate_params( + _FAULT_INJECTOR_CLIENT_OSS_API, + [ SlotMigrateEffects.SLOT_SHUFFLE, - "migrate", - "maint-notifications-oss-api-slot-shuffle", - ), - ( + SlotMigrateEffects.REMOVE, + SlotMigrateEffects.ADD, SlotMigrateEffects.SLOT_SHUFFLE, - "failover", - "maint-notifications-oss-api-slot-shuffle", - ), - ], + ], + ), ) - def test_command_execution_during_slot_shuffle_without_node_replacement( + def test_command_execution_during_slot_shuffle_no_node_replacement( self, fault_injector_client_oss_api: FaultInjectorClient, effect_name: SlotMigrateEffects, trigger: str, - db_config_name: str, + db_config: dict[str, Any], + db_name: str, ): """ Test the push notifications are received when executing re cluster operations. 
""" - self.delete_prev_db(fault_injector_client_oss_api, db_config_name) + logging.info(f"DB name: {db_name}") - maint_notifications_cluster_bdb_config = get_bdbs_config(db_config_name) - cluster_endpoint_config = self.create_db( - fault_injector_client_oss_api, maint_notifications_cluster_bdb_config + cluster_client_maint_notifications, cluster_endpoint_config = self.setup_env( + fault_injector_client_oss_api, db_config ) + shards_count = db_config["shards_count"] + logging.info(f"Shards count: {shards_count}") + errors = Queue() if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): execution_duration = 20 else: execution_duration = 180 - socket_timeout = 3 - - cluster_client_maint_notifications = get_cluster_client_maint_notifications( - endpoints_config=cluster_endpoint_config, - disable_retries=True, - socket_timeout=socket_timeout, - enable_maintenance_notifications=True, - ) - def execute_commands(duration: int, errors: Queue): start = time.time() - shards_count = maint_notifications_cluster_bdb_config.get("shards_count", 3) + executed_commands_count = 0 keys_for_all_shards = KeyGenerationHelpers.generate_keys_for_all_shards( shards_count, prefix=f"{threading.current_thread().name}_{effect_name}_{trigger}_key", ) + + logging.info("Starting commands execution...") while time.time() - start < duration: for key in keys_for_all_shards: try: # the slot is covered by the first shard - this one will have slots migrated cluster_client_maint_notifications.set(key, "value") cluster_client_maint_notifications.get(key) + executed_commands_count += 2 except Exception as e: logging.error( f"Error in thread {threading.current_thread().name}: {e}" @@ -2040,12 +1959,12 @@ def execute_commands(duration: int, errors: Queue): errors.put( f"Command failed in thread {threading.current_thread().name}: {e}" ) + if executed_commands_count % 500 == 0: + logging.debug( + f"Executed {executed_commands_count} commands in {threading.current_thread().name}" + ) logging.debug(f"{threading.current_thread().name}: Thread ended") - cluster_nodes = ( - cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() - ) - threads = [] for i in range(10): thread = Thread( @@ -2059,6 +1978,9 @@ def execute_commands(duration: int, errors: Queue): thread.start() threads.append(thread) + logging.info("Waiting for threads to start and have a few cycles executed ...") + time.sleep(3) + logging.info("Executing FI command that triggers the desired effect...") trigger_effect_thread = Thread( target=self._trigger_effect, @@ -2079,142 +2001,19 @@ def execute_commands(duration: int, errors: Queue): trigger_effect_thread.join() self.maintenance_ops_threads.remove(trigger_effect_thread) - # validate cluster nodes - assert len(cluster_nodes) == len( - cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for node_key in cluster_nodes.keys(): - assert ( - node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for ( node ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): # validate connections settings self._validate_default_state( node.redis_connection, - expected_matching_conns_count=10, - configured_timeout=socket_timeout, + expected_matching_conns_count="all", + configured_timeout=DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT, ) logging.info( - f"Node successfully validated: {node.name}, connections: {len(self._get_all_connections_in_pool(node.redis_connection))}" + f"Node successfully validated: {node.name}, " + f"connections: 
{len(self._get_all_connections_in_pool(node.redis_connection))}" ) # validate no errors were raised in the command execution threads assert errors.empty(), f"Errors occurred in threads: {errors.queue}" - - @pytest.mark.timeout(300) # 5 minutes timeout for this test - @pytest.mark.parametrize( - "effect_name, trigger, db_config_name", - [ - ( - SlotMigrateEffects.SLOT_SHUFFLE, - "migrate", - "maint-notifications-oss-api-slot-shuffle", - ), - ( - SlotMigrateEffects.SLOT_SHUFFLE, - "failover", - "maint-notifications-oss-api-slot-shuffle", - ), - ], - ) - def test_notification_handling_during_node_shuffle_without_node_replacement( - self, - fault_injector_client_oss_api: FaultInjectorClient, - effect_name: SlotMigrateEffects, - trigger: str, - db_config_name: str, - ): - """ - Test the push notifications are received when executing re cluster operations. - The test validates the behavior when during the operations the slots are moved - between the nodes, but no new nodes are appearing and no nodes are disappearing - - """ - self.delete_prev_db(fault_injector_client_oss_api, db_config_name) - - maint_notifications_cluster_bdb_config = get_bdbs_config(db_config_name) - cluster_endpoint_config = self.create_db( - fault_injector_client_oss_api, maint_notifications_cluster_bdb_config - ) - - socket_timeout = 3 - - cluster_client_maint_notifications = get_cluster_client_maint_notifications( - endpoints_config=cluster_endpoint_config, - disable_retries=True, - socket_timeout=socket_timeout, - enable_maintenance_notifications=True, - ) - - logging.info("Creating one connection in the pool.") - # get the node covering first shard - it is the node we will have migrated slots - target_node = ( - cluster_client_maint_notifications.nodes_manager.get_node_from_slot(0) - ) - conn = target_node.redis_connection.connection_pool.get_connection() - cluster_nodes = ( - cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() - ) - - logging.info("Executing FI command that triggers the desired effect...") - trigger_effect_thread = Thread( - target=self._trigger_effect, - name="trigger_effect_thread", - args=( - fault_injector_client_oss_api, - cluster_endpoint_config, - effect_name, - trigger, - ), - ) - self.maintenance_ops_threads.append(trigger_effect_thread) - trigger_effect_thread.start() - - logging.info("Waiting for SMIGRATING push notifications...") - ClientValidations.wait_push_notification( - cluster_client_maint_notifications, - timeout=int(SLOT_SHUFFLE_TIMEOUT / 2), - connection=conn, - ) - - logging.info("Validating connection maintenance state...") - assert conn.maintenance_state == MaintenanceState.MAINTENANCE - assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.should_reconnect() is False - - assert len(cluster_nodes) == len( - cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for node_key in cluster_nodes.keys(): - assert ( - node_key in cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - - logging.info("Waiting for SMIGRATED push notifications...") - ClientValidations.wait_push_notification( - cluster_client_maint_notifications, - timeout=SMIGRATED_TIMEOUT, - connection=conn, - ) - - logging.info("Validating connection state after SMIGRATED ...") - - assert conn.should_reconnect() is True - - # the overall number of nodes should be the same - one removed and one added - assert len(cluster_nodes) == len( - cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - for node_key in cluster_nodes.keys(): - assert ( - node_key in 
cluster_client_maint_notifications.nodes_manager.nodes_cache - ) - - logging.info("Releasing connection back to the pool...") - target_node.redis_connection.connection_pool.release(conn) - - trigger_effect_thread.join() - self.maintenance_ops_threads.remove(trigger_effect_thread) From 7bd6f1bec07584a262d26e3810857933710b3ed0 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 4 Feb 2026 14:35:13 +0200 Subject: [PATCH 2/5] Applying review comments and adding additional test step to ensure connections states --- redis/_parsers/base.py | 3 +- .../maint_notifications_helpers.py | 2 +- .../test_scenario/test_maint_notifications.py | 58 +++++++++++++------ 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f4f91549b2..6fb49563a7 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -202,8 +202,7 @@ def parse_oss_maintenance_completed_msg(response): src_node_str = safe_str(src_node) node_str = safe_str(node) slots_str = safe_str(slots) - # The src_node_address is not provided in the SMIGRATED message, - # so we use an empty string as a placeholder + mapping = NodesToSlotsMapping( src_node_address=src_node_str, dest_node_address=node_str, diff --git a/tests/test_scenario/maint_notifications_helpers.py b/tests/test_scenario/maint_notifications_helpers.py index eb23f7cc7b..4bd1319aac 100644 --- a/tests/test_scenario/maint_notifications_helpers.py +++ b/tests/test_scenario/maint_notifications_helpers.py @@ -45,7 +45,7 @@ def release_connection( @staticmethod def wait_push_notification( redis_client: Union[Redis, RedisCluster], - timeout: int = 120, + timeout: float = 120, fail_on_timeout: bool = True, connection: Optional[Connection] = None, ): diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index e51eb55d1c..9d56eb42d7 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -1422,28 +1422,31 @@ def generate_params( ): # params should produce list of tuples: (effect_name, trigger_name, bdb_config, bdb_name) params = [] - logging.info(f"Extracting params for test with effect_names: {effect_names}") - for effect_name in effect_names: - triggers_data = ClusterOperations.get_slot_migrate_triggers( - fault_injector_client, effect_name - ) + try: + logging.info(f"Extracting params for test with effect_names: {effect_names}") + for effect_name in effect_names: + triggers_data = ClusterOperations.get_slot_migrate_triggers( + fault_injector_client, effect_name + ) - for trigger_info in triggers_data["triggers"]: - trigger = trigger_info["name"] - if trigger == "maintenance_mode": - continue - trigger_requirements = trigger_info["requirements"] - for requirement in trigger_requirements: - dbconfig = requirement["dbconfig"] - ip_type = requirement["oss_cluster_api"]["ip_type"] - if ip_type == "internal": + for trigger_info in triggers_data["triggers"]: + trigger = trigger_info["name"] + if trigger == "maintenance_mode": continue - db_name_pattern = dbconfig.get("name").rsplit("-", 1)[0] - dbconfig["name"] = ( - db_name_pattern # this will ensure dbs will be deleted - ) + trigger_requirements = trigger_info["requirements"] + for requirement in trigger_requirements: + dbconfig = requirement["dbconfig"] + ip_type = requirement["oss_cluster_api"]["ip_type"] + if ip_type == "internal": + continue + db_name_pattern = dbconfig.get("name").rsplit("-", 1)[0] + dbconfig["name"] = ( + db_name_pattern # this will 
ensure dbs will be deleted + ) - params.append((effect_name, trigger, dbconfig, db_name_pattern)) + params.append((effect_name, trigger, dbconfig, db_name_pattern)) + except Exception as e: + logging.error(f"Failed to extract params for test: {e}") return params @@ -2001,6 +2004,23 @@ def execute_commands(duration: int, errors: Queue): trigger_effect_thread.join() self.maintenance_ops_threads.remove(trigger_effect_thread) + # go through all nodes and all their connections and consume the buffers - to validate no + # notifications were left unconsumed + logging.info( + "Consuming all buffers to validate no notifications were left unconsumed..." + ) + for ( + node + ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): + if node.redis_connection is None: + continue + for conn in self._get_all_connections_in_pool(node.redis_connection): + if conn._sock: + while conn.can_read(timeout=0.2): + conn.read_response(push_request=True) + logging.info(f"Consumed all buffers for node: {node.name}") + logging.info("All buffers consumed.") + for ( node ) in cluster_client_maint_notifications.nodes_manager.nodes_cache.values(): From c31ee8e4171b7d0358ff8ffca0fb1df2434fcf07 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 4 Feb 2026 15:19:11 +0200 Subject: [PATCH 3/5] Migarting the last test to use the FI new actions structure --- .../test_scenario/test_maint_notifications.py | 315 ++++++++---------- 1 file changed, 136 insertions(+), 179 deletions(-) diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index 9d56eb42d7..475f4bba85 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -1235,185 +1235,11 @@ def execute_commands(duration: int, errors: Queue): assert errors.empty(), f"Errors occurred in threads: {errors.queue}" -class TestClusterClientPushNotifications(TestPushNotificationsBase): - def extract_target_node_and_empty_node( - self, fault_injector_client, endpoints_config - ): - target_node, empty_node = ClusterOperations.find_target_node_and_empty_node( - fault_injector_client, endpoints_config - ) - logging.info(f"Using target_node: {target_node}, empty_node: {empty_node}") - return target_node, empty_node - - @pytest.fixture(autouse=True) - def setup_and_cleanup( - self, - fault_injector_client_oss_api: FaultInjectorClient, - maint_notifications_cluster_bdb_config: Dict[str, Any], - ): - # Initialize cleanup flags first to ensure they exist even if setup fails - self.cluster_endpoint_config = None - self.maintenance_ops_threads = [] - - self._bdb_config = maint_notifications_cluster_bdb_config.copy() - self._bdb_name = self._bdb_config["name"] - - try: - logging.info(f"Test Setup: Deleting database if exists: {self._bdb_name}") - existing_db_id = None - existing_db_id = ClusterOperations.find_database_id_by_name( - fault_injector_client_oss_api, self._bdb_name - ) - - if existing_db_id: - fault_injector_client_oss_api.delete_database(existing_db_id) - logging.info(f"Deleting database if exists: {self._bdb_name}") - else: - logging.info(f"Database {self._bdb_name} does not exist.") - except Exception as e: - logging.error(f"Failed to delete database {self._bdb_name}: {e}") - - try: - self.cluster_endpoint_config = ( - fault_injector_client_oss_api.create_database(self._bdb_config) - ) - logging.info(f"Test Setup: Created database: {self._bdb_name}") - except Exception as e: - pytest.fail(f"Failed to create database: {e}") - - 
self.cluster_client_maint_notifications = ( - get_cluster_client_maint_notifications(self.cluster_endpoint_config) - ) - - # Yield control to the test - yield - - # Cleanup code - this will run even if the test fails - logging.info("Starting cleanup...") - try: - self.cluster_client_maint_notifications.close() - except Exception as e: - logging.error(f"Failed to close client: {e}") - - logging.info("Waiting for maintenance operations threads to finish...") - for thread in self.maintenance_ops_threads: - thread.join() - - logging.info("Cleanup finished") - - @pytest.mark.timeout(300) # 5 minutes timeout for this test - @pytest.mark.skipif( - use_mock_proxy(), - reason="Mock proxy doesn't support sending notifications to new connections.", - ) - def test_new_connections_receive_last_notification_with_migrating( - self, - fault_injector_client_oss_api: FaultInjectorClient, - ): - """ - Test the push notifications are sent to the newly created connections. - - """ - cluster_op_target_node, cluster_op_empty_node = ( - self.extract_target_node_and_empty_node( - fault_injector_client_oss_api, self.cluster_endpoint_config - ) - ) - db_port = ( - self.cluster_endpoint_config["raw_endpoints"][0]["port"] - if self.cluster_endpoint_config - else None - ) - # get the node that will be migrated - target_node = self.cluster_client_maint_notifications.nodes_manager.get_node( - host=cluster_op_target_node.external_address, - port=db_port, - ) - logging.info( - f"Creating one connection in the pool using node {target_node.name}." - ) - conn = target_node.redis_connection.connection_pool.get_connection() - - logging.info("Executing migrate all data from one node to another ...") - migrate_thread = Thread( - target=self._execute_migration, - name="migrate_thread", - args=( - fault_injector_client_oss_api, - self.cluster_endpoint_config, - cluster_op_target_node.node_id, - cluster_op_empty_node.node_id, - ), - ) - - self.maintenance_ops_threads.append(migrate_thread) - migrate_thread.start() - - logging.info( - f"Waiting for SMIGRATING push notifications with the existing connection: {conn}..." - ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATING_TIMEOUT, - connection=conn, - ) - - new_conn = target_node.redis_connection.connection_pool.get_connection() - logging.info( - f"Validating newly created connection will also receive the notification: {new_conn}..." 
- ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=1, # the notification should have already been sent once, so new conn should receive it almost immediately - connection=new_conn, - fail_on_timeout=False, - ) - - logging.info("Validating connections maintenance state...") - assert conn.maintenance_state == MaintenanceState.MAINTENANCE - assert conn._sock.gettimeout() == RELAXED_TIMEOUT - assert conn.should_reconnect() is False - - assert new_conn.maintenance_state == MaintenanceState.MAINTENANCE - assert new_conn._sock.gettimeout() == RELAXED_TIMEOUT - assert new_conn.should_reconnect() is False - - logging.info(f"Waiting for SMIGRATED push notifications with {conn}...") - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=SMIGRATED_TIMEOUT, - connection=conn, - ) - - logging.info("Validating connection state after SMIGRATED ...") - - assert conn.should_reconnect() is True - assert new_conn.should_reconnect() is True - - new_conn_after_smigrated = ( - target_node.redis_connection.connection_pool.get_connection() - ) - assert new_conn_after_smigrated.maintenance_state == MaintenanceState.NONE - assert new_conn_after_smigrated._sock.gettimeout() == CLIENT_TIMEOUT - assert not new_conn_after_smigrated.should_reconnect() - - logging.info( - f"Waiting for SMIGRATED push notifications with another new connection: {new_conn_after_smigrated}..." - ) - ClientValidations.wait_push_notification( - self.cluster_client_maint_notifications, - timeout=1, - connection=new_conn_after_smigrated, - fail_on_timeout=False, - ) - - logging.info("Releasing connections back to the pool...") - target_node.redis_connection.connection_pool.release(conn) - target_node.redis_connection.connection_pool.release(new_conn) - target_node.redis_connection.connection_pool.release(new_conn_after_smigrated) - - migrate_thread.join() - self.maintenance_ops_threads.remove(migrate_thread) +# 5 minutes timeout for this test +# @pytest.mark.skipif( +# use_mock_proxy(), +# reason="Mock proxy doesn't support sending notifications to new connections.", +# ) def generate_params( @@ -1896,6 +1722,137 @@ def test_notification_handling_with_node_remove( trigger_effect_thread.join() self.maintenance_ops_threads.remove(trigger_effect_thread) + @pytest.mark.timeout(300) # 5 minutes timeout for this test + @pytest.mark.skipif( + use_mock_proxy(), + reason="Mock proxy doesn't support sending notifications to new connections.", + ) + @pytest.mark.parametrize( + "effect_name, trigger, db_config, db_name", + generate_params( + _FAULT_INJECTOR_CLIENT_OSS_API, + [ + SlotMigrateEffects.SLOT_SHUFFLE, + SlotMigrateEffects.REMOVE_ADD, + SlotMigrateEffects.REMOVE, + SlotMigrateEffects.ADD, + ], + ), + ) + def test_new_connections_receive_last_notification_with_migrating( + self, + fault_injector_client_oss_api: FaultInjectorClient, + effect_name: SlotMigrateEffects, + trigger: str, + db_config: dict[str, Any], + db_name: str, + ): + """ + Test the push notifications are sent to the newly created connections. 
+ + """ + logging.info(f"DB name: {db_name}") + + cluster_client_maint_notifications, cluster_endpoint_config = self.setup_env( + fault_injector_client_oss_api, db_config + ) + + logging.info("Creating one connection in each node's pool.") + initial_cluster_nodes = ( + cluster_client_maint_notifications.nodes_manager.nodes_cache.copy() + ) + in_use_connections = {} + for node in initial_cluster_nodes.values(): + in_use_connections[node] = [ + node.redis_connection.connection_pool.get_connection() + ] + + logging.info("Executing FI command that triggers the desired effect...") + trigger_effect_thread = Thread( + target=self._trigger_effect, + name="trigger_effect_thread", + args=( + fault_injector_client_oss_api, + cluster_endpoint_config, + effect_name, + trigger, + ), + ) + + self.maintenance_ops_threads.append(trigger_effect_thread) + trigger_effect_thread.start() + + logging.info("Waiting for SMIGRATING push notifications in all connections...") + for conns_per_node in in_use_connections.values(): + for conn in conns_per_node: + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=int(SLOT_SHUFFLE_TIMEOUT / 2), + connection=conn, + ) + logging.info("Validating connection maintenance state...") + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False + + logging.info("Validating newly created connections receive the notification...") + for node in initial_cluster_nodes.values(): + conn = node.redis_connection.connection_pool.get_connection() + in_use_connections[node].append(conn) + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=1, + connection=conn, + fail_on_timeout=False, # it might get read during handshake + ) + logging.info("Validating new connection maintenance state...") + assert conn.maintenance_state == MaintenanceState.MAINTENANCE + assert conn._sock.gettimeout() == RELAXED_TIMEOUT + assert conn.should_reconnect() is False + + logging.info("Waiting for SMIGRATED push notifications in all connections...") + marked_conns_for_reconnect = 0 + for conns_per_node in in_use_connections.values(): + for conn in conns_per_node: + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=SMIGRATED_TIMEOUT, + connection=conn, + ) + logging.info("Validating connection state after SMIGRATED ...") + if conn.should_reconnect(): + marked_conns_for_reconnect += 1 + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.socket_timeout == CLIENT_TIMEOUT + assert conn.socket_connect_timeout == CLIENT_TIMEOUT + assert ( + marked_conns_for_reconnect >= 1 + ) # at least one should be marked for reconnect + + logging.info( + "Validating newly created connections receive the SMIGRATED notification..." + ) + for node in initial_cluster_nodes.values(): + conn = node.redis_connection.connection_pool.get_connection() + in_use_connections[node].append(conn) + ClientValidations.wait_push_notification( + cluster_client_maint_notifications, + timeout=1, + connection=conn, + fail_on_timeout=True, + ) + # how to detect it???? 
+ + logging.info("Releasing connections back to the pool...") + for node, conns in in_use_connections.items(): + if node.redis_connection is None: + continue + for conn in conns: + node.redis_connection.connection_pool.release(conn) + + trigger_effect_thread.join() + self.maintenance_ops_threads.remove(trigger_effect_thread) + class TestClusterClientCommandsExecutionWithPushNotificationsWithEffectTrigger( TestClusterClientPushNotificationsWithEffectTriggerBase From c2158c82c8b0142356ca7175036d230a8a546ce4 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 4 Feb 2026 18:31:34 +0200 Subject: [PATCH 4/5] Completed migration of all oss api hitless tests to use the new FI helpers. --- redis/maint_notifications.py | 6 ++- .../maint_notifications_helpers.py | 5 +- .../test_scenario/test_maint_notifications.py | 48 +++++++++---------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index d7642fd9d5..f861e9a710 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -564,17 +564,19 @@ def add_debug_log_for_notification( notification: Union[str, MaintenanceNotification], ): if logging.getLogger().isEnabledFor(logging.DEBUG): + socket_address = None try: socket_address = ( connection._sock.getsockname() if connection._sock else None ) + socket_address = socket_address[1] if socket_address else None except (AttributeError, OSError): - socket_address = None + pass logging.debug( f"Handling maintenance notification: {notification}, " f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, " - f"socket_address: {socket_address}", + f"local socket port: {socket_address}", ) diff --git a/tests/test_scenario/maint_notifications_helpers.py b/tests/test_scenario/maint_notifications_helpers.py index 4bd1319aac..93ad01b49f 100644 --- a/tests/test_scenario/maint_notifications_helpers.py +++ b/tests/test_scenario/maint_notifications_helpers.py @@ -57,7 +57,10 @@ def wait_push_notification( if connection else ClientValidations.get_default_connection(redis_client) ) - logging.info(f"Waiting for push notification on connection: {test_conn}") + logging.info( + f"Waiting for push notification on connection: {test_conn}, " + f"local socket port: {test_conn._sock.getsockname()[1] if test_conn._sock else None}" + ) try: while time.time() - start_time < timeout: diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index 475f4bba85..aa30ddefa0 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -1790,12 +1790,16 @@ def test_new_connections_receive_last_notification_with_migrating( timeout=int(SLOT_SHUFFLE_TIMEOUT / 2), connection=conn, ) - logging.info("Validating connection maintenance state...") + logging.info( + f"Validating connection MAINTENANCE state and RELAXED timeout for conn: {conn}..." + ) assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn._sock.gettimeout() == RELAXED_TIMEOUT assert conn.should_reconnect() is False - logging.info("Validating newly created connections receive the notification...") + logging.info( + "Validating newly created connections will receive the SMIGRATING notification..." 
+ ) for node in initial_cluster_nodes.values(): conn = node.redis_connection.connection_pool.get_connection() in_use_connections[node].append(conn) @@ -1805,12 +1809,16 @@ def test_new_connections_receive_last_notification_with_migrating( connection=conn, fail_on_timeout=False, # it might get read during handshake ) - logging.info("Validating new connection maintenance state...") + logging.info( + f"Validating new connection MAINTENANCE state and RELAXED timeout for conn: {conn}..." + ) assert conn.maintenance_state == MaintenanceState.MAINTENANCE assert conn._sock.gettimeout() == RELAXED_TIMEOUT assert conn.should_reconnect() is False - logging.info("Waiting for SMIGRATED push notifications in all connections...") + logging.info( + "Waiting for SMIGRATED push notifications in ALL EXISTING connections..." + ) marked_conns_for_reconnect = 0 for conns_per_node in in_use_connections.values(): for conn in conns_per_node: @@ -1819,30 +1827,22 @@ def test_new_connections_receive_last_notification_with_migrating( timeout=SMIGRATED_TIMEOUT, connection=conn, ) - logging.info("Validating connection state after SMIGRATED ...") - if conn.should_reconnect(): - marked_conns_for_reconnect += 1 - assert conn.maintenance_state == MaintenanceState.NONE - assert conn.socket_timeout == CLIENT_TIMEOUT - assert conn.socket_connect_timeout == CLIENT_TIMEOUT + logging.info( + f"Validating connection state after SMIGRATED for conn: {conn}, " + f"local socket port: {conn._sock.getsockname()[1] if conn._sock else None}..." + ) + if conn.should_reconnect(): + logging.info(f"Connection marked for reconnect: {conn}") + marked_conns_for_reconnect += 1 + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.socket_timeout == DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT + assert ( + conn.socket_connect_timeout == DEFAULT_OSS_API_CLIENT_SOCKET_TIMEOUT + ) assert ( marked_conns_for_reconnect >= 1 ) # at least one should be marked for reconnect - logging.info( - "Validating newly created connections receive the SMIGRATED notification..." - ) - for node in initial_cluster_nodes.values(): - conn = node.redis_connection.connection_pool.get_connection() - in_use_connections[node].append(conn) - ClientValidations.wait_push_notification( - cluster_client_maint_notifications, - timeout=1, - connection=conn, - fail_on_timeout=True, - ) - # how to detect it???? - logging.info("Releasing connections back to the pool...") for node, conns in in_use_connections.items(): if node.redis_connection is None: From 9ab9e71bebc25d97e649753b261684497f253894 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 9 Feb 2026 10:52:07 +0200 Subject: [PATCH 5/5] Applying review comments --- redis/_parsers/base.py | 27 ++-- redis/maint_notifications.py | 59 ++++--- .../proxy_server_helpers.py | 2 - .../test_maint_notifications.py | 149 +++--------------- .../test_scenario/test_maint_notifications.py | 2 +- 5 files changed, 81 insertions(+), 158 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 6fb49563a7..c2fa13f88d 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -11,7 +11,6 @@ NodeMigratedNotification, NodeMigratingNotification, NodeMovingNotification, - NodesToSlotsMapping, OSSNodeMigratedNotification, OSSNodeMigratingNotification, ) @@ -193,22 +192,26 @@ def parse_oss_maintenance_start_msg(response): @staticmethod def parse_oss_maintenance_completed_msg(response): # Expected message format is: - # SMIGRATED [ , ...] + # SMIGRATED [[ ], ...] 
id = response[1] nodes_to_slots_mapping_data = response[2] - nodes_to_slots_mapping = [] - for src_node, node, slots in nodes_to_slots_mapping_data: - # Parse the node address to extract host and port + # Build the nodes_to_slots_mapping dict structure: + # { + # "src_host:port": [ + # {"dest_host:port": "slot_range"}, + # ... + # ], + # ... + # } + nodes_to_slots_mapping = {} + for src_node, dest_node, slots in nodes_to_slots_mapping_data: src_node_str = safe_str(src_node) - node_str = safe_str(node) + dest_node_str = safe_str(dest_node) slots_str = safe_str(slots) - mapping = NodesToSlotsMapping( - src_node_address=src_node_str, - dest_node_address=node_str, - slots=slots_str, - ) - nodes_to_slots_mapping.append(mapping) + if src_node_str not in nodes_to_slots_mapping: + nodes_to_slots_mapping[src_node_str] = [] + nodes_to_slots_mapping[src_node_str].append({dest_node_str: slots_str}) return OSSNodeMigratedNotification(id, nodes_to_slots_mapping) diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index f861e9a710..32a53fd571 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -5,8 +5,7 @@ import threading import time from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union from redis.typing import Number @@ -455,13 +454,6 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, int(self.id))) -@dataclass -class NodesToSlotsMapping: - src_node_address: str - dest_node_address: str - slots: str - - class OSSNodeMigratedNotification(MaintenanceNotification): """ Notification for when a Redis OSS API client is used and a node has completed migrating slots. @@ -471,7 +463,25 @@ class OSSNodeMigratedNotification(MaintenanceNotification): Args: id (int): Unique identifier for this notification - nodes_to_slots_mapping (List[NodesToSlotsMapping]): List of node-to-slots mappings + nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Map of source node address + to list of destination mappings. Each destination mapping is a dict with + the destination node address as key and the slot range as value. + + Structure example: + { + "127.0.0.1:6379": [ + {"127.0.0.1:6380": "1-100"}, + {"127.0.0.1:6381": "101-200"} + ], + "127.0.0.1:6382": [ + {"127.0.0.1:6383": "201-300"} + ] + } + + Where: + - Key (str): Source node address in "host:port" format + - Value (List[Dict[str, str]]): List of destination mappings where each dict + contains destination node address as key and slot range as value """ DEFAULT_TTL = 120 @@ -479,7 +489,7 @@ class OSSNodeMigratedNotification(MaintenanceNotification): def __init__( self, id: int, - nodes_to_slots_mapping: List[NodesToSlotsMapping], + nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]], ): super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL) self.nodes_to_slots_mapping = nodes_to_slots_mapping @@ -1035,21 +1045,34 @@ def handle_oss_maintenance_completed_notification( logging.debug(f"Handling SMIGRATED notification: {notification}") self._in_progress.add(notification) - # Extract the information about the src and destination nodes that are affected by the maintenance + # Extract the information about the src and destination nodes that are affected + # by the maintenance. nodes_to_slots_mapping structure: + # { + # "src_host:port": [ + # {"dest_host:port": "slot_range"}, + # ... + # ], + # ... 
+ # } additional_startup_nodes_info = [] affected_nodes = set() - for mapping in notification.nodes_to_slots_mapping: - new_node_host, new_node_port = mapping.dest_node_address.split(":") - src_host, src_port = mapping.src_node_address.split(":") + for ( + src_address, + dest_mappings, + ) in notification.nodes_to_slots_mapping.items(): + src_host, src_port = src_address.split(":") src_node = self.cluster_client.nodes_manager.get_node( host=src_host, port=src_port ) if src_node is not None: affected_nodes.add(src_node) - additional_startup_nodes_info.append( - (new_node_host, int(new_node_port)) - ) + for dest_mapping in dest_mappings: + for dest_address in dest_mapping.keys(): + dest_host, dest_port = dest_address.split(":") + additional_startup_nodes_info.append( + (dest_host, int(dest_port)) + ) # Updates the cluster slots cache with the new slots mapping # This will also update the nodes cache with the new nodes mapping diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index fa117510cb..036d7b257e 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -1,7 +1,5 @@ import base64 from dataclasses import dataclass -import logging -from typing import Union from redis.http.http_client import HttpClient, HttpError diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index f01637b757..0b485c257a 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -11,7 +11,6 @@ NodeMigratedNotification, NodeFailingOverNotification, NodeFailedOverNotification, - NodesToSlotsMapping, OSSNodeMigratingNotification, OSSNodeMigratedNotification, MaintNotificationsConfig, @@ -494,13 +493,7 @@ class TestOSSNodeMigratedNotification: def test_init_with_defaults(self): """Test OSSNodeMigratedNotification initialization with default values.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = [ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ] + nodes_to_slots_mapping = {"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]} notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping ) @@ -512,18 +505,12 @@ def test_init_with_defaults(self): def test_init_with_all_parameters(self): """Test OSSNodeMigratedNotification initialization with all parameters.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = [ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ), - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6381", - slots="101-200", - ), - ] + nodes_to_slots_mapping = { + "127.0.0.1:6379": [ + {"127.0.0.1:6380": "1-100"}, + {"127.0.0.1:6381": "101-200"}, + ] + } notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping, @@ -538,26 +525,14 @@ def test_default_ttl(self): assert OSSNodeMigratedNotification.DEFAULT_TTL == 120 notification = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) assert notification.ttl == 120 def test_repr(self): 
"""Test OSSNodeMigratedNotification string representation.""" with patch("time.monotonic", return_value=1000): - nodes_to_slots_mapping = [ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ] + nodes_to_slots_mapping = {"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]} notification = OSSNodeMigratedNotification( id=1, nodes_to_slots_mapping=nodes_to_slots_mapping, @@ -575,23 +550,11 @@ def test_equality_same_id_and_type(self): """Test equality for notifications with same id and type.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6381", - slots="101-200", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6381": "101-200"}]}, ) # Should be equal because id and type are the same assert notification1 == notification2 @@ -600,23 +563,11 @@ def test_equality_different_id(self): """Test inequality for notifications with different id.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = OSSNodeMigratedNotification( id=2, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) assert notification1 != notification2 @@ -624,13 +575,7 @@ def test_equality_different_type(self): """Test inequality for notifications of different types.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = NodeMigratedNotification(id=1) assert notification1 != notification2 @@ -639,23 +584,11 @@ def test_hash_same_id_and_type(self): """Test hash for notifications with same id and type.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6381", - slots="101-200", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6381": "101-200"}]}, ) # Should have same hash because id and type are the same assert hash(notification1) == hash(notification2) @@ -664,23 +597,11 @@ def test_hash_different_id(self): """Test hash for notifications with different id.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + 
nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = OSSNodeMigratedNotification( id=2, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) assert hash(notification1) != hash(notification2) @@ -688,43 +609,19 @@ def test_in_set(self): """Test that notifications can be used in sets.""" notification1 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification2 = OSSNodeMigratedNotification( id=1, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6380", - slots="1-100", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"}]}, ) notification3 = OSSNodeMigratedNotification( id=2, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6381", - slots="101-200", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6381": "101-200"}]}, ) notification4 = OSSNodeMigratedNotification( id=2, - nodes_to_slots_mapping=[ - NodesToSlotsMapping( - src_node_address="127.0.0.1:6379", - dest_node_address="127.0.0.1:6381", - slots="101-200", - ) - ], + nodes_to_slots_mapping={"127.0.0.1:6379": [{"127.0.0.1:6381": "101-200"}]}, ) notification_set = {notification1, notification2, notification3, notification4} @@ -961,6 +858,8 @@ class TestMaintNotificationsConnectionHandler: def setup_method(self): """Set up test fixtures.""" self.mock_connection = Mock() + # Configure _sock.getsockname() to return a proper tuple (host, port) + self.mock_connection._sock.getsockname.return_value = ("127.0.0.1", 12345) self.config = MaintNotificationsConfig(enabled=True, relaxed_timeout=20) self.handler = MaintNotificationsConnectionHandler( self.mock_connection, self.config diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index aa30ddefa0..2ea4b7a7d5 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -1894,7 +1894,7 @@ def test_command_execution_during_slot_shuffle_no_node_replacement( if isinstance(fault_injector_client_oss_api, ProxyServerFaultInjector): execution_duration = 20 else: - execution_duration = 180 + execution_duration = 40 def execute_commands(duration: int, errors: Queue): start = time.time()
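---

Reviewer note (illustration only, not part of the patches): the sketch below mirrors, with plain dictionaries, the nodes_to_slots_mapping shape that parse_oss_maintenance_completed_msg builds for OSSNodeMigratedNotification after PATCH 5/5, and the way handle_oss_maintenance_completed_notification walks it to collect affected source nodes and new startup-node candidates. The sample SMIGRATED payload and variable names are hypothetical; only the structure {"src_host:port": [{"dest_host:port": "slot_range"}, ...]} and the triple layout of response[2] come from the patch. The real code uses safe_str() and nodes_manager.get_node() where this sketch uses .decode() and plain (host, port) tuples.

    # Hypothetical SMIGRATED push payload: [type, id, [[src, dest, slots], ...]]
    raw = [
        b"SMIGRATED",
        42,
        [
            [b"127.0.0.1:6379", b"127.0.0.1:6380", b"1-100"],
            [b"127.0.0.1:6379", b"127.0.0.1:6381", b"101-200"],
        ],
    ]

    # Group destination mappings by source node, as the parser now does.
    nodes_to_slots_mapping = {}
    for src_node, dest_node, slots in raw[2]:
        nodes_to_slots_mapping.setdefault(src_node.decode(), []).append(
            {dest_node.decode(): slots.decode()}
        )
    # -> {"127.0.0.1:6379": [{"127.0.0.1:6380": "1-100"},
    #                        {"127.0.0.1:6381": "101-200"}]}

    # Walk the mapping the way the cluster handler does: the keys identify the
    # affected source nodes, the nested keys supply additional startup nodes.
    affected_sources = set()
    additional_startup_nodes = []
    for src_address, dest_mappings in nodes_to_slots_mapping.items():
        host, port = src_address.split(":")
        affected_sources.add((host, int(port)))
        for dest_mapping in dest_mappings:
            for dest_address in dest_mapping:
                dest_host, dest_port = dest_address.split(":")
                additional_startup_nodes.append((dest_host, int(dest_port)))

    print(affected_sources)
    print(additional_startup_nodes)

Running this prints one affected source node and two destination candidates, which is the information the handler feeds into the slots-cache refresh.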