5252 WatchError ,
5353)
5454from redis .lock import Lock
55- from redis .maint_notifications import MaintNotificationsConfig
55+ from redis .maint_notifications import (
56+ MaintNotificationsConfig ,
57+ OSSMaintNotificationsHandler ,
58+ )
5659from redis .retry import Retry
5760from redis .utils import (
61+ check_protocol_version ,
5862 deprecated_args ,
5963 dict_merge ,
6064 list_keys_to_dict ,
@@ -215,6 +219,67 @@ def cleanup_kwargs(**kwargs):
215219 return connection_kwargs
216220
217221
222+ class MaintNotificationsAbstractRedisCluster :
223+ """
224+ Abstract class for handling maintenance notifications logic.
225+ This class is expected to be used as base class together with RedisCluster.
226+
227+ This class is intended to be used with multiple inheritance!
228+
229+ All logic related to maintenance notifications is encapsulated in this class.
230+ """
231+
232+ def __init__ (
233+ self ,
234+ maint_notifications_config : Optional [MaintNotificationsConfig ],
235+ ** kwargs ,
236+ ):
237+ # Initialize maintenance notifications
238+ is_protocol_supported = check_protocol_version (kwargs .get ("protocol" ), 3 )
239+
240+ if (
241+ maint_notifications_config
242+ and maint_notifications_config .enabled
243+ and not is_protocol_supported
244+ ):
245+ raise RedisError (
246+ "Maintenance notifications handlers on connection are only supported with RESP version 3"
247+ )
248+ if maint_notifications_config is None and is_protocol_supported :
249+ maint_notifications_config = MaintNotificationsConfig ()
250+
251+ self .maint_notifications_config = maint_notifications_config
252+
253+ if self .maint_notifications_config and self .maint_notifications_config .enabled :
254+ self ._oss_cluster_maint_notifications_handler = (
255+ OSSMaintNotificationsHandler (self , self .maint_notifications_config )
256+ )
257+ # Update connection kwargs for all future nodes connections
258+ self ._update_connection_kwargs_for_maint_notifications (
259+ self ._oss_cluster_maint_notifications_handler
260+ )
261+ # Update existing nodes connections - they are created as part of the RedisCluster constructor
262+ for node in self .get_nodes ():
263+ node .redis_connection .connection_pool .update_maint_notifications_config (
264+ self .maint_notifications_config ,
265+ oss_cluster_maint_notifications_handler = self ._oss_cluster_maint_notifications_handler ,
266+ )
267+ else :
268+ self ._oss_cluster_maint_notifications_handler = None
269+
270+ def _update_connection_kwargs_for_maint_notifications (
271+ self , oss_cluster_maint_notifications_handler : OSSMaintNotificationsHandler
272+ ):
273+ """
274+ Update the connection kwargs for all future connections.
275+ """
276+ self .nodes_manager .connection_kwargs .update (
277+ {
278+ "oss_cluster_maint_notifications_handler" : oss_cluster_maint_notifications_handler ,
279+ }
280+ )
281+
282+
218283class AbstractRedisCluster :
219284 RedisClusterRequestTTL = 16
220285
@@ -462,7 +527,9 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
462527 self .nodes_manager .default_node = random .choice (replicas )
463528
464529
465- class RedisCluster (AbstractRedisCluster , RedisClusterCommands ):
530+ class RedisCluster (
531+ AbstractRedisCluster , MaintNotificationsAbstractRedisCluster , RedisClusterCommands
532+ ):
466533 @classmethod
467534 def from_url (cls , url : str , ** kwargs : Any ) -> "RedisCluster" :
468535 """
@@ -613,8 +680,7 @@ def __init__(
613680 `redis.maint_notifications.MaintNotificationsConfig` for details.
614681 Only supported with RESP3.
615682 If not provided and protocol is RESP3, the maintenance notifications
616- will be enabled by default (logic is included in the NodesManager
617- initialization).
683+ will be enabled by default.
618684 :**kwargs:
619685 Extra arguments that will be sent into Redis instance when created
620686 (See Official redis-py doc for supported kwargs - the only limitation
@@ -696,9 +762,16 @@ def __init__(
696762 kwargs .get ("decode_responses" , False ),
697763 )
698764 protocol = kwargs .get ("protocol" , None )
699- if (cache_config or cache ) and protocol not in [ 3 , "3" ] :
765+ if (cache_config or cache ) and not check_protocol_version ( protocol , 3 ) :
700766 raise RedisError ("Client caching is only supported with RESP version 3" )
701767
768+ if maint_notifications_config and not check_protocol_version (protocol , 3 ):
769+ raise RedisError (
770+ "Maintenance notifications are only supported with RESP version 3"
771+ )
772+ if check_protocol_version (protocol , 3 ) and maint_notifications_config is None :
773+ maint_notifications_config = MaintNotificationsConfig ()
774+
702775 self .command_flags = self .__class__ .COMMAND_FLAGS .copy ()
703776 self .node_flags = self .__class__ .NODE_FLAGS .copy ()
704777 self .read_from_replicas = read_from_replicas
@@ -710,6 +783,7 @@ def __init__(
710783 else :
711784 self ._event_dispatcher = event_dispatcher
712785 self .startup_nodes = startup_nodes
786+
713787 self .nodes_manager = NodesManager (
714788 startup_nodes = startup_nodes ,
715789 from_url = from_url ,
@@ -764,6 +838,10 @@ def __init__(
764838 self ._aggregate_nodes = None
765839 self ._lock = threading .RLock ()
766840
841+ MaintNotificationsAbstractRedisCluster .__init__ (
842+ self , maint_notifications_config , ** kwargs
843+ )
844+
767845 def __enter__ (self ):
768846 return self
769847
@@ -1639,9 +1717,7 @@ def __init__(
16391717 cache_config : Optional [CacheConfig ] = None ,
16401718 cache_factory : Optional [CacheFactoryInterface ] = None ,
16411719 event_dispatcher : Optional [EventDispatcher ] = None ,
1642- maint_notifications_config : Optional [
1643- MaintNotificationsConfig
1644- ] = MaintNotificationsConfig (),
1720+ maint_notifications_config : Optional [MaintNotificationsConfig ] = None ,
16451721 ** kwargs ,
16461722 ):
16471723 self .nodes_cache : Dict [str , Redis ] = {}
@@ -1886,11 +1962,29 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache):
18861962
18871963 return target_node
18881964
1889- def initialize (self ):
1965+ def initialize (
1966+ self ,
1967+ additional_startup_nodes_info : List [Tuple [str , int ]] = [],
1968+ disconnect_startup_nodes_pools : bool = True ,
1969+ ):
18901970 """
18911971 Initializes the nodes cache, slots cache and redis connections.
18921972 :startup_nodes:
18931973 Responsible for discovering other nodes in the cluster
1974+ :disconnect_startup_nodes_pools:
1975+ Whether to disconnect the connection pool of the startup nodes
1976+ after the initialization is complete. This is useful when the
1977+ startup nodes are not part of the cluster and we want to avoid
1978+ keeping the connection open.
1979+ :additional_startup_nodes_info:
1980+ Additional nodes to add temporarily to the startup nodes.
1981+ The additional nodes will be used just in the process of extraction of the slots
1982+ and nodes information from the cluster.
1983+ This is useful when we want to add new nodes to the cluster
1984+ and initialize the client
1985+ with them.
1986+ The format of the list is a list of tuples, where each tuple contains
1987+ the host and port of the node.
18941988 """
18951989 self .reset ()
18961990 tmp_nodes_cache = {}
@@ -1900,9 +1994,25 @@ def initialize(self):
19001994 fully_covered = False
19011995 kwargs = self .connection_kwargs
19021996 exception = None
1997+
1998+ # Create cache if it's not provided and cache config is set
1999+ # should be done before initializing the first connection
2000+ # so that it will be applied to all connections
2001+ if self ._cache is None and self ._cache_config is not None :
2002+ if self ._cache_factory is None :
2003+ self ._cache = CacheFactory (self ._cache_config ).get_cache ()
2004+ else :
2005+ self ._cache = self ._cache_factory .get_cache ()
2006+
2007+ additional_startup_nodes = [
2008+ ClusterNode (host , port ) for host , port in additional_startup_nodes_info
2009+ ]
19032010 # Convert to tuple to prevent RuntimeError if self.startup_nodes
19042011 # is modified during iteration
1905- for startup_node in tuple (self .startup_nodes .values ()):
2012+ for startup_node in (
2013+ * self .startup_nodes .values (),
2014+ * additional_startup_nodes ,
2015+ ):
19062016 try :
19072017 if startup_node .redis_connection :
19082018 r = startup_node .redis_connection
@@ -1918,7 +2028,11 @@ def initialize(self):
19182028 # Make sure cluster mode is enabled on this node
19192029 try :
19202030 cluster_slots = str_if_bytes (r .execute_command ("CLUSTER SLOTS" ))
1921- r .connection_pool .disconnect ()
2031+ if disconnect_startup_nodes_pools :
2032+ # Disconnect the connection pool to avoid keeping the connection open
2033+ # For some cases we might not want to disconnect current pool and
2034+ # lose in flight commands responses
2035+ r .connection_pool .disconnect ()
19222036 except ResponseError :
19232037 raise RedisClusterException (
19242038 "Cluster mode is not enabled on this node"
@@ -1999,12 +2113,6 @@ def initialize(self):
19992113 f"one reachable node: { str (exception )} "
20002114 ) from exception
20012115
2002- if self ._cache is None and self ._cache_config is not None :
2003- if self ._cache_factory is None :
2004- self ._cache = CacheFactory (self ._cache_config ).get_cache ()
2005- else :
2006- self ._cache = self ._cache_factory .get_cache ()
2007-
20082116 # Create Redis connections to all nodes
20092117 self .create_redis_connections (list (tmp_nodes_cache .values ()))
20102118
0 commit comments