Skip to content

Commit a7857e1

Browse files
add "address_remap" feature to RedisCluster (#2726)
* add cluster "host_port_remap" feature for asyncio.RedisCluster
* Add a unittest for asyncio.RedisCluster
* Add host_port_remap to _sync_ RedisCluster
* add synchronous tests
* rename arg to `address_remap` and take and return an address tuple.
* Add class documentation
* Add CHANGES
1 parent ac15d52 commit a7857e1

File tree

5 files changed

+291
-2
lines changed

5 files changed

+291
-2
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Add `address_remap` parameter to `RedisCluster`
12
* Fix incorrect usage of once flag in async Sentinel
23
* asyncio: Fix memory leak caused by hiredis (#2693)
34
* Allow data to drain from async PythonParser when reading during a disconnect()

redis/asyncio/cluster.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import warnings
66
from typing import (
77
Any,
8+
Callable,
89
Deque,
910
Dict,
1011
Generator,
1112
List,
1213
Mapping,
1314
Optional,
15+
Tuple,
1416
Type,
1517
TypeVar,
1618
Union,
@@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
147149
maximum number of connections are already created, a
148150
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
149151
by :attr:`connection_error_retry_attempts`
152+
:param address_remap:
153+
| An optional callable which, when provided with an internal network
154+
address of a node, e.g. a `(host, port)` tuple, will return the address
155+
where the node is reachable. This can be used to map the addresses at
156+
which the nodes _think_ they are, to addresses at which a client may
157+
reach them, such as when they sit behind a proxy.
150158
151159
| Rest of the arguments will be passed to the
152160
:class:`~redis.asyncio.connection.Connection` instances when created
@@ -250,6 +258,7 @@ def __init__(
250258
ssl_certfile: Optional[str] = None,
251259
ssl_check_hostname: bool = False,
252260
ssl_keyfile: Optional[str] = None,
261+
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
253262
) -> None:
254263
if db:
255264
raise RedisClusterException(
@@ -337,7 +346,12 @@ def __init__(
337346
if host and port:
338347
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
339348

340-
self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
349+
self.nodes_manager = NodesManager(
350+
startup_nodes,
351+
require_full_coverage,
352+
kwargs,
353+
address_remap=address_remap,
354+
)
341355
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
342356
self.read_from_replicas = read_from_replicas
343357
self.reinitialize_steps = reinitialize_steps
@@ -1059,17 +1073,20 @@ class NodesManager:
10591073
"require_full_coverage",
10601074
"slots_cache",
10611075
"startup_nodes",
1076+
"address_remap",
10621077
)
10631078

10641079
def __init__(
10651080
self,
10661081
startup_nodes: List["ClusterNode"],
10671082
require_full_coverage: bool,
10681083
connection_kwargs: Dict[str, Any],
1084+
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
10691085
) -> None:
10701086
self.startup_nodes = {node.name: node for node in startup_nodes}
10711087
self.require_full_coverage = require_full_coverage
10721088
self.connection_kwargs = connection_kwargs
1089+
self.address_remap = address_remap
10731090

10741091
self.default_node: "ClusterNode" = None
10751092
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1228,6 +1245,7 @@ async def initialize(self) -> None:
12281245
if host == "":
12291246
host = startup_node.host
12301247
port = int(primary_node[1])
1248+
host, port = self.remap_host_port(host, port)
12311249

12321250
target_node = tmp_nodes_cache.get(get_node_name(host, port))
12331251
if not target_node:
@@ -1246,6 +1264,7 @@ async def initialize(self) -> None:
12461264
for replica_node in replica_nodes:
12471265
host = replica_node[0]
12481266
port = replica_node[1]
1267+
host, port = self.remap_host_port(host, port)
12491268

12501269
target_replica_node = tmp_nodes_cache.get(
12511270
get_node_name(host, port)
@@ -1319,6 +1338,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
13191338
)
13201339
)
13211340

1341+
def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1342+
"""
1343+
Remap the host and port returned from the cluster to a different
1344+
internal value. Useful if the client is not connecting directly
1345+
to the cluster.
1346+
"""
1347+
if self.address_remap:
1348+
return self.address_remap((host, port))
1349+
return host, port
1350+
13221351

13231352
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
13241353
"""

redis/cluster.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def __init__(
466466
read_from_replicas: bool = False,
467467
dynamic_startup_nodes: bool = True,
468468
url: Optional[str] = None,
469+
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
469470
**kwargs,
470471
):
471472
"""
@@ -514,6 +515,12 @@ def __init__(
514515
reinitialize_steps to 1.
515516
To avoid reinitializing the cluster on moved errors, set
516517
reinitialize_steps to 0.
518+
:param address_remap:
519+
An optional callable which, when provided with an internal network
520+
address of a node, e.g. a `(host, port)` tuple, will return the address
521+
where the node is reachable. This can be used to map the addresses at
522+
which the nodes _think_ they are, to addresses at which a client may
523+
reach them, such as when they sit behind a proxy.
517524
518525
:**kwargs:
519526
Extra arguments that will be sent into Redis instance when created
@@ -594,6 +601,7 @@ def __init__(
594601
from_url=from_url,
595602
require_full_coverage=require_full_coverage,
596603
dynamic_startup_nodes=dynamic_startup_nodes,
604+
address_remap=address_remap,
597605
**kwargs,
598606
)
599607

@@ -1269,6 +1277,7 @@ def __init__(
12691277
lock=None,
12701278
dynamic_startup_nodes=True,
12711279
connection_pool_class=ConnectionPool,
1280+
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
12721281
**kwargs,
12731282
):
12741283
self.nodes_cache = {}
@@ -1280,6 +1289,7 @@ def __init__(
12801289
self._require_full_coverage = require_full_coverage
12811290
self._dynamic_startup_nodes = dynamic_startup_nodes
12821291
self.connection_pool_class = connection_pool_class
1292+
self.address_remap = address_remap
12831293
self._moved_exception = None
12841294
self.connection_kwargs = kwargs
12851295
self.read_load_balancer = LoadBalancer()
@@ -1502,6 +1512,7 @@ def initialize(self):
15021512
if host == "":
15031513
host = startup_node.host
15041514
port = int(primary_node[1])
1515+
host, port = self.remap_host_port(host, port)
15051516

15061517
target_node = self._get_or_create_cluster_node(
15071518
host, port, PRIMARY, tmp_nodes_cache
@@ -1518,6 +1529,7 @@ def initialize(self):
15181529
for replica_node in replica_nodes:
15191530
host = str_if_bytes(replica_node[0])
15201531
port = replica_node[1]
1532+
host, port = self.remap_host_port(host, port)
15211533

15221534
target_replica_node = self._get_or_create_cluster_node(
15231535
host, port, REPLICA, tmp_nodes_cache
@@ -1591,6 +1603,16 @@ def reset(self):
15911603
# The read_load_balancer is None, do nothing
15921604
pass
15931605

1606+
def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1607+
"""
1608+
Remap the host and port returned from the cluster to a different
1609+
internal value. Useful if the client is not connecting directly
1610+
to the cluster.
1611+
"""
1612+
if self.address_remap:
1613+
return self.address_remap((host, port))
1614+
return host, port
1615+
15941616

15951617
class ClusterPubSub(PubSub):
15961618
"""

tests/test_asyncio/test_cluster.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from _pytest.fixtures import FixtureRequest
1212

1313
from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
14-
from redis.asyncio.connection import Connection, SSLConnection
14+
from redis.asyncio.connection import Connection, SSLConnection, async_timeout
1515
from redis.asyncio.parser import CommandsParser
1616
from redis.asyncio.retry import Retry
1717
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
@@ -49,6 +49,71 @@
4949
]
5050

5151

52+
class NodeProxy:
53+
"""A class to proxy a node connection to a different port"""
54+
55+
def __init__(self, addr, redis_addr):
56+
self.addr = addr
57+
self.redis_addr = redis_addr
58+
self.send_event = asyncio.Event()
59+
self.server = None
60+
self.task = None
61+
self.n_connections = 0
62+
63+
async def start(self):
64+
# test that we can connect to redis
65+
async with async_timeout(2):
66+
_, redis_writer = await asyncio.open_connection(*self.redis_addr)
67+
redis_writer.close()
68+
self.server = await asyncio.start_server(
69+
self.handle, *self.addr, reuse_address=True
70+
)
71+
self.task = asyncio.create_task(self.server.serve_forever())
72+
73+
async def handle(self, reader, writer):
74+
# establish connection to redis
75+
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
76+
try:
77+
self.n_connections += 1
78+
pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
79+
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
80+
await asyncio.gather(pipe1, pipe2)
81+
finally:
82+
redis_writer.close()
83+
84+
async def aclose(self):
85+
self.task.cancel()
86+
try:
87+
await self.task
88+
except asyncio.CancelledError:
89+
pass
90+
await self.server.wait_closed()
91+
92+
async def pipe(
93+
self,
94+
reader: asyncio.StreamReader,
95+
writer: asyncio.StreamWriter,
96+
):
97+
while True:
98+
data = await reader.read(1000)
99+
if not data:
100+
break
101+
writer.write(data)
102+
await writer.drain()
103+
104+
105+
@pytest.fixture
106+
def redis_addr(request):
107+
redis_url = request.config.getoption("--redis-url")
108+
scheme, netloc = urlparse(redis_url)[:2]
109+
assert scheme == "redis"
110+
if ":" in netloc:
111+
host, port = netloc.split(":")
112+
return host, int(port)
113+
else:
114+
return netloc, 6379
115+
116+
52117
@pytest_asyncio.fixture()
53118
async def slowlog(r: RedisCluster) -> None:
54119
"""
@@ -809,6 +874,49 @@ async def test_default_node_is_replaced_after_exception(self, r):
809874
# Rollback to the old default node
810875
r.replace_default_node(curr_default_node)
811876

877+
async def test_address_remap(self, create_redis, redis_addr):
878+
"""Test that we can create a rediscluster object with
879+
a host-port remapper and map connections through proxy objects
880+
"""
881+
882+
# we remap the first n nodes
883+
offset = 1000
884+
n = 6
885+
ports = [redis_addr[1] + i for i in range(n)]
886+
887+
def address_remap(address):
888+
# remap first three nodes to our local proxy
889+
# old = host, port
890+
host, port = address
891+
if int(port) in ports:
892+
host, port = "127.0.0.1", int(port) + offset
893+
# print(f"{old} {host, port}")
894+
return host, port
895+
896+
# create the proxies
897+
proxies = [
898+
NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
899+
for port in ports
900+
]
901+
await asyncio.gather(*[p.start() for p in proxies])
902+
try:
903+
# create cluster:
904+
r = await create_redis(
905+
cls=RedisCluster, flushdb=False, address_remap=address_remap
906+
)
907+
try:
908+
assert await r.ping() is True
909+
assert await r.set("byte_string", b"giraffe")
910+
assert await r.get("byte_string") == b"giraffe"
911+
finally:
912+
await r.close()
913+
finally:
914+
await asyncio.gather(*[p.aclose() for p in proxies])
915+
916+
# verify that the proxies were indeed used
917+
n_used = sum((1 if p.n_connections else 0) for p in proxies)
918+
assert n_used > 1
919+
812920

813921
class TestClusterRedisCommands:
814922
"""

0 commit comments

Comments (0)