55import threading
66import time
77from abc import ABC , abstractmethod
8- from typing import TYPE_CHECKING , Dict , List , Literal , Optional , Union
8+ from dataclasses import dataclass
9+ from typing import TYPE_CHECKING , List , Literal , Optional , Union
910
1011from redis .typing import Number
1112
@@ -454,6 +455,13 @@ def __hash__(self) -> int:
454455 return hash ((self .__class__ .__name__ , int (self .id )))
455456
456457
458+ @dataclass
459+ class NodesToSlotsMapping :
460+ src_node_address : str
461+ node_address : str
462+ slots : str
463+
464+
457465class OSSNodeMigratedNotification (MaintenanceNotification ):
458466 """
459467 Notification for when a Redis OSS API client is used and a node has completed migrating slots.
@@ -463,15 +471,15 @@ class OSSNodeMigratedNotification(MaintenanceNotification):
463471
464472 Args:
465473 id (int): Unique identifier for this notification
466- nodes_to_slots_mapping (Dict[str, str ]): Mapping of node addresses to slots
474+ nodes_to_slots_mapping (List[NodesToSlotsMapping ]): List of node-to- slots mappings
467475 """
468476
469477 DEFAULT_TTL = 30
470478
471479 def __init__ (
472480 self ,
473481 id : int ,
474- nodes_to_slots_mapping : Dict [ str , str ],
482+ nodes_to_slots_mapping : List [ NodesToSlotsMapping ],
475483 ):
476484 super ().__init__ (id , OSSNodeMigratedNotification .DEFAULT_TTL )
477485 self .nodes_to_slots_mapping = nodes_to_slots_mapping
@@ -967,10 +975,6 @@ def __init__(
967975 self ._processed_notifications = set ()
968976 self ._in_progress = set ()
969977 self ._lock = threading .RLock ()
970- self .connection = None
971-
972- def set_connection (self , connection : "MaintNotificationsAbstractConnection" ):
973- self .connection = connection
974978
975979 def get_handler_for_connection (self ):
976980 # Copy all data that should be shared between connections
@@ -980,7 +984,6 @@ def get_handler_for_connection(self):
980984 copy ._processed_notifications = self ._processed_notifications
981985 copy ._in_progress = self ._in_progress
982986 copy ._lock = self ._lock
983- copy .connection = None
984987 return copy
985988
986989 def remove_expired_notifications (self ):
@@ -1011,55 +1014,56 @@ def handle_oss_maintenance_completed_notification(
10111014 # that has also has the notification and we don't want to
10121015 # process the same notification twice
10131016 return
1014- if self .connection is None :
1015- logging .error (
1016- "Connection is not set for OSSMaintNotificationsHandler. "
1017- f"Failed to handle notification: { notification } "
1018- )
1019- return
10201017
1021- logging .debug (
1022- f"Handling SMIGRATED notification: { notification } with connection: { self .connection } , connected to ip { self .connection .get_resolved_ip ()} "
1023- )
1018+ logging .debug (f"Handling SMIGRATED notification: { notification } " )
10241019 self ._in_progress .add (notification )
10251020
1026- # get the node to which the connection is connected
1027- # before refreshing the cluster topology
1028- current_node = self .cluster_client .nodes_manager .get_node (
1029- host = self .connection .host , port = self .connection .port
1030- )
1031-
1032- # Updates the cluster slots cache with the new slots mapping
1033- # This will also update the nodes cache with the new nodes mapping
1021+ # Extract the information about the src and destination nodes that are affected by the maintenance
10341022 additional_startup_nodes_info = []
1035- for node_address , _ in notification .nodes_to_slots_mapping .items ():
1036- new_node_host , new_node_port = node_address .split (":" )
1023+ affected_nodes = set ()
1024+ for mapping in notification .nodes_to_slots_mapping :
1025+ new_node_host , new_node_port = mapping .node_address .split (":" )
1026+ src_host , src_port = mapping .src_node_address .split (":" )
1027+ src_node = self .cluster_client .nodes_manager .get_node (
1028+ host = src_host , port = src_port
1029+ )
1030+ if src_node is not None :
1031+ affected_nodes .add (src_node )
1032+
10371033 additional_startup_nodes_info .append (
10381034 (new_node_host , int (new_node_port ))
10391035 )
10401036
1037+ # Updates the cluster slots cache with the new slots mapping
1038+ # This will also update the nodes cache with the new nodes mapping
10411039 self .cluster_client .nodes_manager .initialize (
10421040 disconnect_startup_nodes_pools = False ,
10431041 additional_startup_nodes_info = additional_startup_nodes_info ,
10441042 )
10451043
1046- with current_node .redis_connection .connection_pool ._lock :
1047- # mark for reconnect all in use connections to the node - this will force them to
1048- # disconnect after they complete their current commands
1049- # Some of them might be used by sub sub and we don't know which ones - so we disconnect
1050- # all in flight connections after they are done with current command execution
1051- for conn in current_node .redis_connection .connection_pool ._get_in_use_connections ():
1052- conn .mark_for_reconnect ()
1044+ all_nodes = self .cluster_client .nodes_manager .nodes_cache .values ()
1045+
1046+ for current_node in all_nodes :
1047+ if current_node .redis_connection is None :
1048+ continue
1049+ with current_node .redis_connection .connection_pool ._lock :
1050+ if current_node in affected_nodes :
1051+ # mark for reconnect all in use connections to the node - this will force them to
1052+ # disconnect after they complete their current commands
1053+ # Some of them might be used by sub sub and we don't know which ones - so we disconnect
1054+ # all in flight connections after they are done with current command execution
1055+ for conn in current_node .redis_connection .connection_pool ._get_in_use_connections ():
1056+ conn .mark_for_reconnect ()
1057+
1058+ # if (
1059+ # current_node
1060+ # not in self.cluster_client.nodes_manager.nodes_cache.values()
1061+ # ):
1062+ # # disconnect all free connections to the node - this node will be dropped
1063+ # # from the cluster, so we don't need to revert the timeouts
1064+ # for conn in current_node.redis_connection.connection_pool._get_free_connections():
1065+ # conn.disconnect()
10531066
1054- if (
1055- current_node
1056- not in self .cluster_client .nodes_manager .nodes_cache .values ()
1057- ):
1058- # disconnect all free connections to the node - this node will be dropped
1059- # from the cluster, so we don't need to revert the timeouts
1060- for conn in current_node .redis_connection .connection_pool ._get_free_connections ():
1061- conn .disconnect ()
1062- else :
10631067 if self .config .is_relaxed_timeouts_enabled ():
10641068 # reset the timeouts for the node to which the connection is connected
10651069 # Perform check if other maintenance ops are in progress for the same node
0 commit comments