5
5
import warnings
6
6
from typing import (
7
7
Any ,
8
+ Callable ,
8
9
Deque ,
9
10
Dict ,
10
11
Generator ,
11
12
List ,
12
13
Mapping ,
13
14
Optional ,
15
+ Tuple ,
14
16
Type ,
15
17
TypeVar ,
16
18
Union ,
@@ -250,6 +252,7 @@ def __init__(
250
252
ssl_certfile : Optional [str ] = None ,
251
253
ssl_check_hostname : bool = False ,
252
254
ssl_keyfile : Optional [str ] = None ,
255
+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
253
256
) -> None :
254
257
if db :
255
258
raise RedisClusterException (
@@ -337,7 +340,12 @@ def __init__(
337
340
if host and port :
338
341
startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
339
342
340
- self .nodes_manager = NodesManager (startup_nodes , require_full_coverage , kwargs )
343
+ self .nodes_manager = NodesManager (
344
+ startup_nodes ,
345
+ require_full_coverage ,
346
+ kwargs ,
347
+ host_port_remap = host_port_remap ,
348
+ )
341
349
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
342
350
self .read_from_replicas = read_from_replicas
343
351
self .reinitialize_steps = reinitialize_steps
@@ -1059,17 +1067,20 @@ class NodesManager:
1059
1067
"require_full_coverage" ,
1060
1068
"slots_cache" ,
1061
1069
"startup_nodes" ,
1070
+ "host_port_remap" ,
1062
1071
)
1063
1072
1064
1073
def __init__ (
1065
1074
self ,
1066
1075
startup_nodes : List ["ClusterNode" ],
1067
1076
require_full_coverage : bool ,
1068
1077
connection_kwargs : Dict [str , Any ],
1078
+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
1069
1079
) -> None :
1070
1080
self .startup_nodes = {node .name : node for node in startup_nodes }
1071
1081
self .require_full_coverage = require_full_coverage
1072
1082
self .connection_kwargs = connection_kwargs
1083
+ self .host_port_remap = host_port_remap
1073
1084
1074
1085
self .default_node : "ClusterNode" = None
1075
1086
self .nodes_cache : Dict [str , "ClusterNode" ] = {}
@@ -1228,6 +1239,7 @@ async def initialize(self) -> None:
1228
1239
if host == "" :
1229
1240
host = startup_node .host
1230
1241
port = int (primary_node [1 ])
1242
+ host , port = self .remap_host_port (host , port )
1231
1243
1232
1244
target_node = tmp_nodes_cache .get (get_node_name (host , port ))
1233
1245
if not target_node :
@@ -1246,6 +1258,7 @@ async def initialize(self) -> None:
1246
1258
for replica_node in replica_nodes :
1247
1259
host = replica_node [0 ]
1248
1260
port = replica_node [1 ]
1261
+ host , port = self .remap_host_port (host , port )
1249
1262
1250
1263
target_replica_node = tmp_nodes_cache .get (
1251
1264
get_node_name (host , port )
@@ -1319,6 +1332,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
1319
1332
)
1320
1333
)
1321
1334
1335
+ def remap_host_port (self , host : str , port : int ) -> Tuple [str , int ]:
1336
+ """
1337
+ Remap the host and port returned from the cluster to a different
1338
+ internal value. Useful if the client is not connecting directly
1339
+ to the cluster.
1340
+ """
1341
+ if self .host_port_remap :
1342
+ return self .host_port_remap (host , port )
1343
+ return host , port
1344
+
1322
1345
1323
1346
class ClusterPipeline (AbstractRedis , AbstractRedisCluster , AsyncRedisClusterCommands ):
1324
1347
"""
0 commit comments