@@ -97,6 +97,55 @@ class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
97
97
pass
98
98
99
99
100
+ class SentinelConnectionPoolProxy :
101
+ def __init__ (
102
+ self ,
103
+ connection_pool ,
104
+ is_master ,
105
+ check_connection ,
106
+ service_name ,
107
+ sentinel_manager ,
108
+ ):
109
+ self .connection_pool_ref = weakref .ref (connection_pool )
110
+ self .is_master = is_master
111
+ self .check_connection = check_connection
112
+ self .service_name = service_name
113
+ self .sentinel_manager = sentinel_manager
114
+ self .reset ()
115
+
116
+ def reset (self ):
117
+ self .master_address = None
118
+ self .slave_rr_counter = None
119
+
120
+ async def get_master_address (self ):
121
+ master_address = await self .sentinel_manager .discover_master (self .service_name )
122
+ if self .is_master and self .master_address != master_address :
123
+ self .master_address = master_address
124
+ # disconnect any idle connections so that they reconnect
125
+ # to the new master the next time that they are used.
126
+ connection_pool = self .connection_pool_ref ()
127
+ if connection_pool is not None :
128
+ await connection_pool .disconnect (inuse_connections = False )
129
+ return master_address
130
+
131
+ async def rotate_slaves (self ) -> AsyncIterator :
132
+ """Round-robin slave balancer"""
133
+ slaves = await self .sentinel_manager .discover_slaves (self .service_name )
134
+ if slaves :
135
+ if self .slave_rr_counter is None :
136
+ self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
137
+ for _ in range (len (slaves )):
138
+ self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
139
+ slave = slaves [self .slave_rr_counter ]
140
+ yield slave
141
+ # Fallback to the master connection
142
+ try :
143
+ yield await self .get_master_address ()
144
+ except MasterNotFoundError :
145
+ pass
146
+ raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
147
+
148
+
100
149
class SentinelConnectionPool (ConnectionPool ):
101
150
"""
102
151
Sentinel backed connection pool.
@@ -116,6 +165,44 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
116
165
)
117
166
self .is_master = kwargs .pop ("is_master" , True )
118
167
self .check_connection = kwargs .pop ("check_connection" , False )
168
+ self .proxy = SentinelConnectionPoolProxy (
169
+ connection_pool = self ,
170
+ is_master = self .is_master ,
171
+ check_connection = self .check_connection ,
172
+ service_name = service_name ,
173
+ sentinel_manager = sentinel_manager ,
174
+ )
175
+ super ().__init__ (** kwargs )
176
+ self .connection_kwargs ["connection_pool" ] = self .proxy
177
+ self .service_name = service_name
178
+ self .sentinel_manager = sentinel_manager
179
+
180
+ def __repr__ (self ):
181
+ return (
182
+ f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } "
183
+ f"(service={ self .service_name } ({ self .is_master and 'master' or 'slave' } ))>"
184
+ )
185
+
186
+ def reset (self ):
187
+ super ().reset ()
188
+ self .proxy .reset ()
189
+
190
+ @property
191
+ def master_address (self ):
192
+ return self .proxy .master_address
193
+
194
+ def owns_connection (self , connection : Connection ):
195
+ check = not self .is_master or (
196
+ self .is_master and self .master_address == (connection .host , connection .port )
197
+ )
198
+ return check and super ().owns_connection (connection )
199
+
200
+ async def get_master_address (self ):
201
+ return await self .proxy .get_master_address ()
202
+
203
+ def rotate_slaves (self ) -> AsyncIterator :
204
+ """Round-robin slave balancer"""
205
+ return self .proxy .rotate_slaves ()
119
206
super ().__init__ (** kwargs )
120
207
self .connection_kwargs ["connection_pool" ] = weakref .proxy (self )
121
208
self .service_name = service_name
0 commit comments