29 changes: 22 additions & 7 deletions redis/_parsers/base.py
@@ -11,6 +11,7 @@
NodeMigratedNotification,
NodeMigratingNotification,
NodeMovingNotification,
NodesToSlotsMapping,
OSSNodeMigratedNotification,
OSSNodeMigratingNotification,
)
@@ -195,9 +196,20 @@ def parse_oss_maintenance_completed_msg(response):
# SMIGRATED <seq_number> [<host:port> <slot, range1-range2,...>, ...]
id = response[1]
nodes_to_slots_mapping_data = response[2]
nodes_to_slots_mapping = {}
for node, slots in nodes_to_slots_mapping_data:
nodes_to_slots_mapping[safe_str(node)] = safe_str(slots)
nodes_to_slots_mapping = []
for src_node, node, slots in nodes_to_slots_mapping_data:
# Normalize the source node, destination node and slot ranges to strings
src_node_str = safe_str(src_node)
node_str = safe_str(node)
slots_str = safe_str(slots)
mapping = NodesToSlotsMapping(
src_node_address=src_node_str,
dest_node_address=node_str,
slots=slots_str,
)
nodes_to_slots_mapping.append(mapping)

return OSSNodeMigratedNotification(id, nodes_to_slots_mapping)
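
For reference, a minimal sketch of the payload shape the reworked parser expects, assuming parse_oss_maintenance_completed_msg is in scope and each mapping entry now carries a source node, a destination node and a slot-range string, as the new loop unpacks them; the sequence number, addresses and slot ranges below are made up:

# Hypothetical SMIGRATED push, shaped the way the new loop unpacks it.
response = [
    b"SMIGRATED",
    b"42",
    [
        (b"10.0.0.1:6379", b"10.0.0.2:6379", b"100,200-300"),
        (b"10.0.0.3:6379", b"10.0.0.4:6379", b"500-600"),
    ],
]

notification = parse_oss_maintenance_completed_msg(response)
# notification.nodes_to_slots_mapping would then hold entries such as:
#   NodesToSlotsMapping(src_node_address="10.0.0.1:6379",
#                       dest_node_address="10.0.0.2:6379",
#                       slots="100,200-300")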

@@ -341,17 +353,20 @@ def handle_push_response(self, response, **kwargs):

if notification is not None:
return self.maintenance_push_handler_func(notification)
if (
msg_type == _SMIGRATED_MESSAGE
and self.oss_cluster_maint_push_handler_func
if msg_type == _SMIGRATED_MESSAGE and (
self.oss_cluster_maint_push_handler_func
or self.maintenance_push_handler_func
):
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
msg_type
][1]
notification = parser_function(response)

if notification is not None:
return self.oss_cluster_maint_push_handler_func(notification)
if self.maintenance_push_handler_func:
self.maintenance_push_handler_func(notification)
if self.oss_cluster_maint_push_handler_func:
self.oss_cluster_maint_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
3 changes: 3 additions & 0 deletions redis/asyncio/cluster.py
@@ -1212,6 +1212,9 @@ def __repr__(self) -> str:
def __eq__(self, obj: Any) -> bool:
return isinstance(obj, ClusterNode) and obj.name == self.name

def __hash__(self) -> int:
return hash(self.name)

_DEL_MESSAGE = "Unclosed ClusterNode object"

def __del__(
3 changes: 3 additions & 0 deletions redis/cluster.py
@@ -1676,6 +1676,9 @@ def __repr__(self):
def __eq__(self, obj):
return isinstance(obj, ClusterNode) and obj.name == self.name

def __hash__(self):
return hash(self.name)


class LoadBalancingStrategy(Enum):
ROUND_ROBIN = "round_robin"
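
The matching __hash__ additions in redis/asyncio/cluster.py and redis/cluster.py are what let the new SMIGRATED handler collect ClusterNode objects into sets: a Python class that defines __eq__ without __hash__ becomes unhashable. A minimal illustration with a stripped-down stand-in (the real ClusterNode carries more state):

class ClusterNode:
    # stripped-down stand-in: only the identity-relevant attribute is kept
    def __init__(self, host, port):
        self.name = f"{host}:{port}"

    def __eq__(self, obj):
        return isinstance(obj, ClusterNode) and obj.name == self.name

    def __hash__(self):
        # hash on the same key used by __eq__ so equal nodes collapse in sets
        return hash(self.name)


# Without __hash__, building this set would raise
# "TypeError: unhashable type: 'ClusterNode'".
affected_nodes = {ClusterNode("10.0.0.1", 6379), ClusterNode("10.0.0.1", 6379)}
assert len(affected_nodes) == 1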
41 changes: 19 additions & 22 deletions redis/connection.py
@@ -440,23 +440,17 @@ def _configure_maintenance_notifications(
else:
self._maint_notifications_pool_handler = None

self._maint_notifications_connection_handler = (
MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
)

if oss_cluster_maint_notifications_handler:
# Extract a reference to a new handler that copies all properties
# of the original one and has a different connection reference
# This is needed because when we attach the handler to the parser
# we need to make sure that the handler has a reference to the
# connection that the parser is attached to.
self._oss_cluster_maint_notifications_handler = (
oss_cluster_maint_notifications_handler.get_handler_for_connection()
oss_cluster_maint_notifications_handler
)
self._oss_cluster_maint_notifications_handler.set_connection(self)
else:
self._oss_cluster_maint_notifications_handler = None

self._maint_notifications_connection_handler = (
MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
)

# Set up OSS cluster handler to parser if available
if self._oss_cluster_maint_notifications_handler:
parser.set_oss_cluster_maint_push_handler(
@@ -521,21 +515,12 @@ def set_maint_notifications_pool_handler_for_connection(
def set_maint_notifications_cluster_handler_for_connection(
self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
):
# Deep copy the cluster handler to avoid sharing the same handler
# between multiple connections, because otherwise each connection will override
# the connection reference and the handler will only hold a reference
# to the last connection that was set.
maint_notifications_cluster_handler_copy = (
oss_cluster_maint_notifications_handler.get_handler_for_connection()
)

maint_notifications_cluster_handler_copy.set_connection(self)
self._get_parser().set_oss_cluster_maint_push_handler(
maint_notifications_cluster_handler_copy.handle_notification
oss_cluster_maint_notifications_handler.handle_notification
)

self._oss_cluster_maint_notifications_handler = (
maint_notifications_cluster_handler_copy
oss_cluster_maint_notifications_handler
)

# Update maintenance notification connection handler if it doesn't exist
@@ -1142,6 +1127,7 @@ def disconnect(self, *args):
self._sock = None
# reset the reconnect flag
self.reset_should_reconnect()

if conn_sock is None:
return

@@ -1156,6 +1142,17 @@
except OSError:
pass

if self.maintenance_state == MaintenanceState.MAINTENANCE:
# This block runs only when a connection in MAINTENANCE state is closed.
# The state change is not applied to connections in MOVING state because
# their state and configuration are handled when the moving ttl expires.
self.reset_tmp_settings(reset_relaxed_timeout=True)
self.maintenance_state = MaintenanceState.NONE
# reset the sets that keep track of received start maint
# notifications and skipped end maint notifications
self.reset_received_notifications()

def mark_for_reconnect(self):
self._should_reconnect = True

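
The connection.py changes above stop copying the OSS cluster maintenance handler per connection: every parser is now wired to the same handler instance, so its shared bookkeeping (processed and in-progress notification sets) applies across all connections. A toy sketch of that design with a stand-in class, not the real OSSMaintNotificationsHandler API:

class SharedHandler:
    """Stand-in for the shared OSS cluster maintenance handler."""

    def __init__(self):
        # shared across every connection whose parser points at this handler
        self.processed = set()

    def handle_notification(self, notification_id):
        if notification_id in self.processed:
            return  # already handled when it arrived on another connection
        self.processed.add(notification_id)
        # ... refresh topology, mark connections for reconnect, etc.


handler = SharedHandler()
handler.handle_notification(42)  # first connection to deliver the push does the work
handler.handle_notification(42)  # duplicates from other connections are ignored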
147 changes: 73 additions & 74 deletions redis/maint_notifications.py
@@ -5,7 +5,8 @@
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Literal, Optional, Union

from redis.typing import Number

@@ -454,6 +455,13 @@ def __hash__(self) -> int:
return hash((self.__class__.__name__, int(self.id)))


@dataclass
class NodesToSlotsMapping:
src_node_address: str
dest_node_address: str
slots: str


class OSSNodeMigratedNotification(MaintenanceNotification):
"""
Notification for when a Redis OSS API client is used and a node has completed migrating slots.
@@ -463,15 +471,15 @@ class OSSNodeMigratedNotification(MaintenanceNotification):

Args:
id (int): Unique identifier for this notification
nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots
nodes_to_slots_mapping (List[NodesToSlotsMapping]): List of node-to-slots mappings
"""

DEFAULT_TTL = 30
DEFAULT_TTL = 120

def __init__(
self,
id: int,
nodes_to_slots_mapping: Dict[str, str],
nodes_to_slots_mapping: List[NodesToSlotsMapping],
):
super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
self.nodes_to_slots_mapping = nodes_to_slots_mapping
@@ -551,6 +559,25 @@ def _is_private_fqdn(host: str) -> bool:
return False


def add_debug_log_for_notification(
connection: "MaintNotificationsAbstractConnection",
notification: Union[str, MaintenanceNotification],
):
if logging.getLogger().isEnabledFor(logging.DEBUG):
try:
socket_address = (
connection._sock.getsockname() if connection._sock else None
)
except (AttributeError, OSError):
socket_address = None

logging.debug(
f"Handling maintenance notification: {notification}, "
f"with connection: {connection}, connected to ip {connection.get_resolved_ip()}, "
f"socket_address: {socket_address}",
)


class MaintNotificationsConfig:
"""
Configuration class for maintenance notifications handling behaviour. Notifications are received through
@@ -885,6 +912,7 @@ class MaintNotificationsConnectionHandler:
OSSNodeMigratingNotification: 1,
NodeMigratedNotification: 0,
NodeFailedOverNotification: 0,
OSSNodeMigratedNotification: 0,
}

def __init__(
Expand Down Expand Up @@ -913,10 +941,8 @@ def handle_notification(self, notification: MaintenanceNotification):
def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
):
logging.debug(
f"Handling start maintenance notification: {notification}, "
f"with connection: {self.connection}, connected to ip {self.connection.get_resolved_ip()}"
)
add_debug_log_for_notification(self.connection, notification)

if (
self.connection.maintenance_state == MaintenanceState.MOVING
or not self.config.is_relaxed_timeouts_enabled()
@@ -942,10 +968,7 @@ def handle_maintenance_completed_notification(self):
or not self.config.is_relaxed_timeouts_enabled()
):
return
logging.debug(
f"Handling end maintenance notification with connection: {self.connection}, "
f"connected to ip {self.connection.get_resolved_ip()}"
)
add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED")
self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
# Maintenance completed - reset the connection
# timeouts by providing -1 as the relaxed timeout
@@ -967,10 +990,6 @@ def __init__(
self._processed_notifications = set()
self._in_progress = set()
self._lock = threading.RLock()
self.connection = None

def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
self.connection = connection

def get_handler_for_connection(self):
# Copy all data that should be shared between connections
@@ -980,7 +999,6 @@ def get_handler_for_connection(self):
copy._processed_notifications = self._processed_notifications
copy._in_progress = self._in_progress
copy._lock = self._lock
copy.connection = None
return copy

def remove_expired_notifications(self):
Expand Down Expand Up @@ -1011,77 +1029,58 @@ def handle_oss_maintenance_completed_notification(
# that has also has the notification and we don't want to
# process the same notification twice
return
if self.connection is None:
logging.error(
"Connection is not set for OSSMaintNotificationsHandler. "
f"Failed to handle notification: {notification}"
)
return

logging.debug(
f"Handling SMIGRATED notification: {notification} with connection: {self.connection}, connected to ip {self.connection.get_resolved_ip()}"
)
logging.debug(f"Handling SMIGRATED notification: {notification}")
self._in_progress.add(notification)

# get the node to which the connection is connected
# before refreshing the cluster topology
current_node = self.cluster_client.nodes_manager.get_node(
host=self.connection.host, port=self.connection.port
)

# Updates the cluster slots cache with the new slots mapping
# This will also update the nodes cache with the new nodes mapping
# Extract the information about the src and destination nodes that are affected by the maintenance
additional_startup_nodes_info = []
for node_address, _ in notification.nodes_to_slots_mapping.items():
new_node_host, new_node_port = node_address.split(":")
affected_nodes = set()
for mapping in notification.nodes_to_slots_mapping:
new_node_host, new_node_port = mapping.dest_node_address.split(":")
src_host, src_port = mapping.src_node_address.split(":")
src_node = self.cluster_client.nodes_manager.get_node(
host=src_host, port=src_port
)
if src_node is not None:
affected_nodes.add(src_node)

additional_startup_nodes_info.append(
(new_node_host, int(new_node_port))
)

# Updates the cluster slots cache with the new slots mapping
# This will also update the nodes cache with the new nodes mapping
self.cluster_client.nodes_manager.initialize(
disconnect_startup_nodes_pools=False,
additional_startup_nodes_info=additional_startup_nodes_info,
)

with current_node.redis_connection.connection_pool._lock:
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
conn.mark_for_reconnect()
all_nodes = set(affected_nodes)
all_nodes = all_nodes.union(
self.cluster_client.nodes_manager.nodes_cache.values()
)

if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()
else:
if self.config.is_relaxed_timeouts_enabled():
# reset the timeouts for the node to which the connection is connected
# Perform check if other maintenance ops are in progress for the same node
# and if so, don't reset the timeouts and wait for the last maintenance
# to complete
for conn in (
*current_node.redis_connection.connection_pool._get_in_use_connections(),
*current_node.redis_connection.connection_pool._get_free_connections(),
):
if (
len(conn.get_processed_start_notifications())
> len(conn.get_skipped_end_notifications()) + 1
):
# we have received more start notifications than end notifications
# for this connection - we should not reset the timeouts
# and add the notification id to the set of skipped end notifications
conn.add_skipped_end_notification(notification.id)
else:
conn.reset_tmp_settings(reset_relaxed_timeout=True)
conn.update_current_socket_timeout(relaxed_timeout=-1)
conn.maintenance_state = MaintenanceState.NONE
conn.reset_received_notifications()
for current_node in all_nodes:
if current_node.redis_connection is None:
continue
with current_node.redis_connection.connection_pool._lock:
if current_node in affected_nodes:
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
conn.mark_for_reconnect()

if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()

# mark the notification as processed
self._processed_notifications.add(notification)
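
That de-duplication works because MaintenanceNotification hashes on (class name, id), as shown in the __hash__ hunk above, so the same SMIGRATED event delivered on several connections lands on the same entry in _in_progress / _processed_notifications. A small check, assuming __eq__ compares on the same key (not shown in this diff); the id 42 and the empty mapping are made up:

from redis.maint_notifications import OSSNodeMigratedNotification

# Two notification objects built from the same push message, e.g. parsed
# on two different connections.
n1 = OSSNodeMigratedNotification(42, [])
n2 = OSSNodeMigratedNotification(42, [])

assert hash(n1) == hash(n2)  # __hash__ keys on (class name, int(id))

processed = {n1}
# Assuming __eq__ also compares by id, the second copy is found in the set
# and the handler returns early instead of reinitializing the topology again.
assert n2 in processed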