Skip to content

Commit 3356388

Browse files
committed
add cluster "host_port_remap" feature
1 parent 896f087 commit 3356388

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

redis/asyncio/cluster.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import warnings
66
from typing import (
77
Any,
8+
Callable,
89
Deque,
910
Dict,
1011
Generator,
1112
List,
1213
Mapping,
1314
Optional,
15+
Tuple,
1416
Type,
1517
TypeVar,
1618
Union,
@@ -250,6 +252,7 @@ def __init__(
250252
ssl_certfile: Optional[str] = None,
251253
ssl_check_hostname: bool = False,
252254
ssl_keyfile: Optional[str] = None,
255+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
253256
) -> None:
254257
if db:
255258
raise RedisClusterException(
@@ -337,7 +340,12 @@ def __init__(
337340
if host and port:
338341
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
339342

340-
self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
343+
self.nodes_manager = NodesManager(
344+
startup_nodes,
345+
require_full_coverage,
346+
kwargs,
347+
host_port_remap=host_port_remap,
348+
)
341349
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
342350
self.read_from_replicas = read_from_replicas
343351
self.reinitialize_steps = reinitialize_steps
@@ -1059,17 +1067,20 @@ class NodesManager:
10591067
"require_full_coverage",
10601068
"slots_cache",
10611069
"startup_nodes",
1070+
"host_port_remap",
10621071
)
10631072

10641073
def __init__(
10651074
self,
10661075
startup_nodes: List["ClusterNode"],
10671076
require_full_coverage: bool,
10681077
connection_kwargs: Dict[str, Any],
1078+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
10691079
) -> None:
10701080
self.startup_nodes = {node.name: node for node in startup_nodes}
10711081
self.require_full_coverage = require_full_coverage
10721082
self.connection_kwargs = connection_kwargs
1083+
self.host_port_remap = host_port_remap
10731084

10741085
self.default_node: "ClusterNode" = None
10751086
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1228,6 +1239,7 @@ async def initialize(self) -> None:
12281239
if host == "":
12291240
host = startup_node.host
12301241
port = int(primary_node[1])
1242+
host, port = self.remap_host_port(host, port)
12311243

12321244
target_node = tmp_nodes_cache.get(get_node_name(host, port))
12331245
if not target_node:
@@ -1246,6 +1258,7 @@ async def initialize(self) -> None:
12461258
for replica_node in replica_nodes:
12471259
host = replica_node[0]
12481260
port = replica_node[1]
1261+
host, port = self.remap_host_port(host, port)
12491262

12501263
target_replica_node = tmp_nodes_cache.get(
12511264
get_node_name(host, port)
@@ -1319,6 +1332,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
13191332
)
13201333
)
13211334

1335+
def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1336+
"""
1337+
Remap the host and port returned from the cluster to a different
1338+
internal value. Useful if the client is not connecting directly
1339+
to the cluster.
1340+
"""
1341+
if self.host_port_remap:
1342+
return self.host_port_remap(host, port)
1343+
return host, port
1344+
13221345

13231346
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
13241347
"""

0 commit comments

Comments
 (0)