Skip to content

Commit 8f550b6

Browse files
committed
Refactoring the SMIGRATED flow - the notification is changed to contain the src node address for each slot range movement.
1 parent e07381c commit 8f550b6

File tree

10 files changed

+377
-267
lines changed

10 files changed

+377
-267
lines changed

redis/_parsers/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
NodeMigratedNotification,
1212
NodeMigratingNotification,
1313
NodeMovingNotification,
14+
NodesToSlotsMapping,
1415
OSSNodeMigratedNotification,
1516
OSSNodeMigratingNotification,
1617
)
@@ -195,9 +196,18 @@ def parse_oss_maintenance_completed_msg(response):
195196
# SMIGRATED <seq_number> [<src_host:port> <dest_host:port> <slot, range1-range2,...>, ...]
196197
id = response[1]
197198
nodes_to_slots_mapping_data = response[2]
198-
nodes_to_slots_mapping = {}
199-
for node, slots in nodes_to_slots_mapping_data:
200-
nodes_to_slots_mapping[safe_str(node)] = safe_str(slots)
199+
nodes_to_slots_mapping = []
200+
for src_node, node, slots in nodes_to_slots_mapping_data:
201+
# Normalize the raw source address, destination address and slots to str
202+
src_node_str = safe_str(src_node)
203+
node_str = safe_str(node)
204+
slots_str = safe_str(slots)
205+
# Each SMIGRATED entry carries the source node address first,
206+
# followed by the destination node address and the migrated slots
207+
mapping = NodesToSlotsMapping(
208+
src_node_address=src_node_str, node_address=node_str, slots=slots_str
209+
)
210+
nodes_to_slots_mapping.append(mapping)
201211

202212
return OSSNodeMigratedNotification(id, nodes_to_slots_mapping)
203213

redis/asyncio/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,9 @@ def __repr__(self) -> str:
12121212
def __eq__(self, obj: Any) -> bool:
12131213
return isinstance(obj, ClusterNode) and obj.name == self.name
12141214

1215+
def __hash__(self) -> int:
1216+
return hash(self.name)
1217+
12151218
_DEL_MESSAGE = "Unclosed ClusterNode object"
12161219

12171220
def __del__(

redis/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,9 @@ def __repr__(self):
16761676
def __eq__(self, obj):
16771677
return isinstance(obj, ClusterNode) and obj.name == self.name
16781678

1679+
def __hash__(self):
1680+
return hash(self.name)
1681+
16791682

16801683
class LoadBalancingStrategy(Enum):
16811684
ROUND_ROBIN = "round_robin"

redis/connection.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -440,23 +440,17 @@ def _configure_maintenance_notifications(
440440
else:
441441
self._maint_notifications_pool_handler = None
442442

443+
self._maint_notifications_connection_handler = (
444+
MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
445+
)
446+
443447
if oss_cluster_maint_notifications_handler:
444-
# Extract a reference to a new handler that copies all properties
445-
# of the original one and has a different connection reference
446-
# This is needed because when we attach the handler to the parser
447-
# we need to make sure that the handler has a reference to the
448-
# connection that the parser is attached to.
449448
self._oss_cluster_maint_notifications_handler = (
450-
oss_cluster_maint_notifications_handler.get_handler_for_connection()
449+
oss_cluster_maint_notifications_handler
451450
)
452-
self._oss_cluster_maint_notifications_handler.set_connection(self)
453451
else:
454452
self._oss_cluster_maint_notifications_handler = None
455453

456-
self._maint_notifications_connection_handler = (
457-
MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
458-
)
459-
460454
# Set up OSS cluster handler to parser if available
461455
if self._oss_cluster_maint_notifications_handler:
462456
parser.set_oss_cluster_maint_push_handler(
@@ -521,21 +515,12 @@ def set_maint_notifications_pool_handler_for_connection(
521515
def set_maint_notifications_cluster_handler_for_connection(
522516
self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
523517
):
524-
# Deep copy the cluster handler to avoid sharing the same handler
525-
# between multiple connections, because otherwise each connection will override
526-
# the connection reference and the handler will only hold a reference
527-
# to the last connection that was set.
528-
maint_notifications_cluster_handler_copy = (
529-
oss_cluster_maint_notifications_handler.get_handler_for_connection()
530-
)
531-
532-
maint_notifications_cluster_handler_copy.set_connection(self)
533518
self._get_parser().set_oss_cluster_maint_push_handler(
534-
maint_notifications_cluster_handler_copy.handle_notification
519+
oss_cluster_maint_notifications_handler.handle_notification
535520
)
536521

537522
self._oss_cluster_maint_notifications_handler = (
538-
maint_notifications_cluster_handler_copy
523+
oss_cluster_maint_notifications_handler
539524
)
540525

541526
# Update maintenance notification connection handler if it doesn't exist

redis/maint_notifications.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import threading
66
import time
77
from abc import ABC, abstractmethod
8-
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
8+
from dataclasses import dataclass
9+
from typing import TYPE_CHECKING, List, Literal, Optional, Union
910

1011
from redis.typing import Number
1112

@@ -454,6 +455,13 @@ def __hash__(self) -> int:
454455
return hash((self.__class__.__name__, int(self.id)))
455456

456457

458+
@dataclass
459+
class NodesToSlotsMapping:
460+
src_node_address: str
461+
node_address: str
462+
slots: str
463+
464+
457465
class OSSNodeMigratedNotification(MaintenanceNotification):
458466
"""
459467
Notification for when a Redis OSS API client is used and a node has completed migrating slots.
@@ -463,15 +471,15 @@ class OSSNodeMigratedNotification(MaintenanceNotification):
463471
464472
Args:
465473
id (int): Unique identifier for this notification
466-
nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots
474+
nodes_to_slots_mapping (List[NodesToSlotsMapping]): List of node-to-slots mappings
467475
"""
468476

469477
DEFAULT_TTL = 30
470478

471479
def __init__(
472480
self,
473481
id: int,
474-
nodes_to_slots_mapping: Dict[str, str],
482+
nodes_to_slots_mapping: List[NodesToSlotsMapping],
475483
):
476484
super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
477485
self.nodes_to_slots_mapping = nodes_to_slots_mapping
@@ -967,10 +975,6 @@ def __init__(
967975
self._processed_notifications = set()
968976
self._in_progress = set()
969977
self._lock = threading.RLock()
970-
self.connection = None
971-
972-
def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
973-
self.connection = connection
974978

975979
def get_handler_for_connection(self):
976980
# Copy all data that should be shared between connections
@@ -980,7 +984,6 @@ def get_handler_for_connection(self):
980984
copy._processed_notifications = self._processed_notifications
981985
copy._in_progress = self._in_progress
982986
copy._lock = self._lock
983-
copy.connection = None
984987
return copy
985988

986989
def remove_expired_notifications(self):
@@ -1011,55 +1014,56 @@ def handle_oss_maintenance_completed_notification(
10111014
# that also has the notification and we don't want to
10121015
# process the same notification twice
10131016
return
1014-
if self.connection is None:
1015-
logging.error(
1016-
"Connection is not set for OSSMaintNotificationsHandler. "
1017-
f"Failed to handle notification: {notification}"
1018-
)
1019-
return
10201017

1021-
logging.debug(
1022-
f"Handling SMIGRATED notification: {notification} with connection: {self.connection}, connected to ip {self.connection.get_resolved_ip()}"
1023-
)
1018+
logging.debug(f"Handling SMIGRATED notification: {notification}")
10241019
self._in_progress.add(notification)
10251020

1026-
# get the node to which the connection is connected
1027-
# before refreshing the cluster topology
1028-
current_node = self.cluster_client.nodes_manager.get_node(
1029-
host=self.connection.host, port=self.connection.port
1030-
)
1031-
1032-
# Updates the cluster slots cache with the new slots mapping
1033-
# This will also update the nodes cache with the new nodes mapping
1021+
# Extract the information about the src and destination nodes that are affected by the maintenance
10341022
additional_startup_nodes_info = []
1035-
for node_address, _ in notification.nodes_to_slots_mapping.items():
1036-
new_node_host, new_node_port = node_address.split(":")
1023+
affected_nodes = set()
1024+
for mapping in notification.nodes_to_slots_mapping:
1025+
new_node_host, new_node_port = mapping.node_address.split(":")
1026+
src_host, src_port = mapping.src_node_address.split(":")
1027+
src_node = self.cluster_client.nodes_manager.get_node(
1028+
host=src_host, port=src_port
1029+
)
1030+
if src_node is not None:
1031+
affected_nodes.add(src_node)
1032+
10371033
additional_startup_nodes_info.append(
10381034
(new_node_host, int(new_node_port))
10391035
)
10401036

1037+
# Updates the cluster slots cache with the new slots mapping
1038+
# This will also update the nodes cache with the new nodes mapping
10411039
self.cluster_client.nodes_manager.initialize(
10421040
disconnect_startup_nodes_pools=False,
10431041
additional_startup_nodes_info=additional_startup_nodes_info,
10441042
)
10451043

1046-
with current_node.redis_connection.connection_pool._lock:
1047-
# mark for reconnect all in use connections to the node - this will force them to
1048-
# disconnect after they complete their current commands
1049-
# Some of them might be used by sub sub and we don't know which ones - so we disconnect
1050-
# all in flight connections after they are done with current command execution
1051-
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
1052-
conn.mark_for_reconnect()
1044+
all_nodes = self.cluster_client.nodes_manager.nodes_cache.values()
1045+
1046+
for current_node in all_nodes:
1047+
if current_node.redis_connection is None:
1048+
continue
1049+
with current_node.redis_connection.connection_pool._lock:
1050+
if current_node in affected_nodes:
1051+
# mark for reconnect all in use connections to the node - this will force them to
1052+
# disconnect after they complete their current commands
1053+
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
1054+
# all in flight connections after they are done with current command execution
1055+
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
1056+
conn.mark_for_reconnect()
1057+
1058+
# if (
1059+
# current_node
1060+
# not in self.cluster_client.nodes_manager.nodes_cache.values()
1061+
# ):
1062+
# # disconnect all free connections to the node - this node will be dropped
1063+
# # from the cluster, so we don't need to revert the timeouts
1064+
# for conn in current_node.redis_connection.connection_pool._get_free_connections():
1065+
# conn.disconnect()
10531066

1054-
if (
1055-
current_node
1056-
not in self.cluster_client.nodes_manager.nodes_cache.values()
1057-
):
1058-
# disconnect all free connections to the node - this node will be dropped
1059-
# from the cluster, so we don't need to revert the timeouts
1060-
for conn in current_node.redis_connection.connection_pool._get_free_connections():
1061-
conn.disconnect()
1062-
else:
10631067
if self.config.is_relaxed_timeouts_enabled():
10641068
# reset the timeouts for the node to which the connection is connected
10651069
# Perform check if other maintenance ops are in progress for the same node

tests/maint_notifications/proxy_server_helpers.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ def oss_maint_notification_to_resp(txt: str) -> str:
4848
">3\r\n" # Push message with 3 elements
4949
f"+{notification}\r\n" # Element 1: Command
5050
f":{seq_id}\r\n" # Element 2: SeqID
51-
f"*{len(hosts_and_slots) // 2}\r\n" # Element 3: Array of host:port, slots pairs
51+
f"*{len(hosts_and_slots) // 3}\r\n" # Element 3: Array of (src_host:src_port, dest_host:dest_port, slots) triplets
5252
)
53-
for i in range(0, len(hosts_and_slots), 2):
54-
resp += "*2\r\n"
53+
for i in range(0, len(hosts_and_slots), 3):
54+
resp += "*3\r\n"
5555
resp += f"+{hosts_and_slots[i]}\r\n"
5656
resp += f"+{hosts_and_slots[i + 1]}\r\n"
57+
resp += f"+{hosts_and_slots[i + 2]}\r\n"
5758
else:
5859
# SMIGRATING
5960
# Format: SMIGRATING SeqID slot,range1-range2
@@ -211,20 +212,12 @@ def get_connections(self) -> dict:
211212

212213
def send_notification(
213214
self,
214-
connected_to_port: Union[int, str],
215215
notification: str,
216216
) -> dict:
217217
"""
218-
Send a notification to all connections connected to
219-
a specific node(identified by port number).
220-
221-
This method:
222-
1. Fetches stats from the interceptor server
223-
2. Finds all connection IDs connected to the specified node
224-
3. Sends the notification to each connection
218+
Send a notification to all connections.
225219
226220
Args:
227-
connected_to_port: Port number of the node to send the notification to
228221
notification: The notification message to send (in RESP format)
229222
230223
Returns:
@@ -233,32 +226,12 @@ def send_notification(
233226
Example:
234227
interceptor = ProxyInterceptorHelper(None, "http://localhost:4000")
235228
result = interceptor.send_notification(
236-
"6379",
237229
"KjENCiQ0DQpQSU5HDQo=" # PING command in base64
238230
)
239231
"""
240-
# Get stats to find connection IDs for the node
241-
stats = self.get_stats()
242-
243-
# Extract connection IDs for the specified node
244-
conn_ids = []
245-
for node_key, node_info in stats.items():
246-
node_port = node_key.split("@")[1]
247-
if int(node_port) == int(connected_to_port):
248-
for conn in node_info.get("connections", []):
249-
conn_ids.append(conn["id"])
250-
251-
if not conn_ids:
252-
raise RuntimeError(
253-
f"No connections found for node {node_port}. "
254-
f"Available nodes: {list(set(c.get('node') for c in stats.get('connections', {}).values()))}"
255-
)
256-
257-
# Send notification to each connection
232+
# Send notification to all connections
258233
results = {}
259-
logging.info(f"Sending notification to {len(conn_ids)} connections: {conn_ids}")
260-
connections_query = f"connectionIds={','.join(conn_ids)}"
261-
url = f"{self.server_url}/send-to-clients?{connections_query}&encoding=base64"
234+
url = f"{self.server_url}/send-to-all-clients?encoding=base64"
262235
# Encode notification to base64
263236
data = base64.b64encode(notification.encode("utf-8"))
264237

@@ -271,8 +244,6 @@ def send_notification(
271244
results = {"error": str(e)}
272245

273246
return {
274-
"node_address": node_port,
275-
"connection_ids": conn_ids,
276247
"results": results,
277248
}
278249

0 commit comments

Comments
 (0)