From e7fc06b53db6ea2265a7c4d16be7fc7e10e4636f Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sun, 10 Apr 2022 23:17:18 +0530 Subject: [PATCH 01/23] Copy Cluster Client, Commands, Commands Parser, Tests for asyncio --- redis/asyncio/client.py | 10 +- redis/asyncio/cluster.py | 2251 ++++++++++++++++ redis/asyncio/parser.py | 143 + redis/cluster.py | 168 +- redis/commands/cluster.py | 582 ++++- tests/test_asyncio/test_cluster.py | 2729 ++++++++++++++++++++ tests/test_asyncio/test_commands.py | 127 +- tests/test_asyncio/test_connection_pool.py | 70 +- 8 files changed, 5873 insertions(+), 207 deletions(-) create mode 100644 redis/asyncio/cluster.py create mode 100644 redis/asyncio/parser.py create mode 100644 tests/test_asyncio/test_cluster.py diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 559b74c9f6..1fd46e5b16 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -263,11 +263,7 @@ def get_connection_kwargs(self): """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -426,9 +422,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __del__(self, _warnings: Any = warnings) -> None: if self.connection is not None: _warnings.warn( - f"Unclosed client session {self!r}", - ResourceWarning, - source=self, + f"Unclosed client session {self!r}", ResourceWarning, source=self ) context = {"client": self, "message": self._DEL_MESSAGE} asyncio.get_event_loop().call_exception_handler(context) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py new file mode 100644 index 0000000000..92469e4f37 --- /dev/null +++ b/redis/asyncio/cluster.py @@ -0,0 +1,2251 @@ +import copy +import logging +import random +import socket +import sys +import threading +import time +from collections import OrderedDict + +from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan +from redis.commands import CommandsParser, RedisClusterCommands +from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + ConnectionError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + RedisError, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.lock import Lock +from redis.utils import ( + dict_merge, + list_keys_to_dict, + merge_result, + safe_str, + str_if_bytes, +) + +log = logging.getLogger(__name__) + + +def get_node_name(host, port): + return f"{host}:{port}" + + +def get_connection(redis_node, *args, **options): + return redis_node.connection or redis_node.connection_pool.get_connection( + args[0], **options + ) + + +def parse_scan_result(command, res, **options): + cursors = {} + ret = [] + for node_name, response in res.items(): + cursor, r = parse_scan(response, **options) + cursors[node_name] = cursor + ret += r + + return cursors, ret + + +def parse_pubsub_numsub(command, res, **options): + numsub_d = OrderedDict() + for numsub_tups in res.values(): + for channel, numsubbed in numsub_tups: + try: + numsub_d[channel] += numsubbed + except KeyError: + numsub_d[channel] = numsubbed + + ret_numsub = [(channel, numsub) for channel, numsub in numsub_d.items()] + return ret_numsub + + +def parse_cluster_slots(resp, **options): + current_host = options.get("current_host", "") + + def fix_server(*args): + return str_if_bytes(args[0]) or current_host, args[1] + + slots = {} + for slot in resp: + start, end, primary = slot[:3] + replicas = slot[3:] + slots[start, end] = { + "primary": fix_server(*primary), + "replicas": [fix_server(*replica) for replica in replicas], + } + + return slots + + +PRIMARY = "primary" +REPLICA = "replica" +SLOT_ID = "slot-id" + +REDIS_ALLOWED_KEYS = ( + "charset", + "connection_class", + "connection_pool", + "client_name", + "db", + "decode_responses", + "encoding", + "encoding_errors", + "errors", + "host", + "max_connections", + "nodes_flag", + "redis_connect_func", + "password", + "port", + "retry", + "retry_on_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_timeout", + "ssl", + "ssl_ca_certs", + "ssl_ca_data", + "ssl_certfile", + "ssl_cert_reqs", + "ssl_keyfile", + "ssl_password", + "unix_socket_path", + "username", +) +KWARGS_DISABLED_KEYS = ("host", "port") + +# Not complete, but covers the major ones +# https://redis.io/commands +READ_COMMANDS = frozenset( + [ + "BITCOUNT", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", + ] +) + + +def cleanup_kwargs(**kwargs): + """ + Remove unsupported or disabled keys from kwargs + """ + connection_kwargs = { + k: v + for k, v in kwargs.items() + if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS + } + + return connection_kwargs + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) + + +class RedisCluster(RedisClusterCommands): + RedisClusterRequestTTL = 16 + + PRIMARIES = "primaries" + REPLICAS = "replicas" + ALL_NODES = "all" + RANDOM = "random" + DEFAULT_NODE = "default-node" + + NODE_FLAGS = {PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE} + + COMMAND_FLAGS = dict_merge( + list_keys_to_dict( + [ + "ACL CAT", + "ACL DELUSER", + "ACL GENPASS", + "ACL GETUSER", + "ACL HELP", + "ACL LIST", + "ACL LOG", + "ACL LOAD", + "ACL SAVE", + "ACL SETUSER", + "ACL USERS", + "ACL WHOAMI", + "AUTH", + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT GETNAME", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "TIME", + "PUBSUB CHANNELS", + "PUBSUB NUMPAT", + "PUBSUB NUMSUB", + "PING", + "INFO", + "SHUTDOWN", + "KEYS", + "DBSIZE", + "BGSAVE", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "WAIT", + "SAVE", + "MEMORY PURGE", + "MEMORY MALLOC-STATS", + "MEMORY STATS", + "LASTSAVE", + "CLIENT TRACKINGINFO", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + "CLIENT UNBLOCK", + "CLIENT ID", + "CLIENT REPLY", + "CLIENT GETREDIR", + "CLIENT INFO", + "CLIENT KILL", + "READONLY", + "READWRITE", + "CLUSTER INFO", + "CLUSTER MEET", + "CLUSTER NODES", + "CLUSTER REPLICAS", + "CLUSTER RESET", + "CLUSTER SET-CONFIG-EPOCH", + "CLUSTER SLOTS", + "CLUSTER COUNT-FAILURE-REPORTS", + "CLUSTER KEYSLOT", + "COMMAND", + "COMMAND COUNT", + "COMMAND GETKEYS", + "CONFIG GET", + "DEBUG", + "RANDOMKEY", + "READONLY", + "READWRITE", + "TIME", + "GRAPH.CONFIG", + ], + DEFAULT_NODE, + ), + list_keys_to_dict( + [ + "FLUSHALL", + "FLUSHDB", + "FUNCTION DELETE", + "FUNCTION FLUSH", + "FUNCTION LIST", + "FUNCTION LOAD", + "FUNCTION RESTORE", + "SCAN", + "SCRIPT EXISTS", + "SCRIPT FLUSH", + "SCRIPT LOAD", + ], + PRIMARIES, + ), + list_keys_to_dict(["FUNCTION DUMP"], RANDOM), + list_keys_to_dict( + [ + "CLUSTER COUNTKEYSINSLOT", + "CLUSTER DELSLOTS", + "CLUSTER DELSLOTSRANGE", + "CLUSTER GETKEYSINSLOT", + "CLUSTER SETSLOT", + ], + SLOT_ID, + ), + ) + + SEARCH_COMMANDS = ( + [ + "FT.CREATE", + "FT.SEARCH", + "FT.AGGREGATE", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT,PROFILE", + "FT.ALTER", + "FT.DROPINDEX", + "FT.ALIASADD", + "FT.ALIASUPDATE", + "FT.ALIASDEL", + "FT.TAGVALS", + "FT.SUGADD", + "FT.SUGGET", + "FT.SUGDEL", + "FT.SUGLEN", + "FT.SYNUPDATE", + "FT.SYNDUMP", + "FT.SPELLCHECK", + "FT.DICTADD", + "FT.DICTDEL", + "FT.DICTDUMP", + "FT.INFO", + "FT._LIST", + "FT.CONFIG", + "FT.ADD", + "FT.DEL", + "FT.DROP", + "FT.GET", + "FT.MGET", + "FT.SYNADD", + ], + ) + + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { + "CLUSTER ADDSLOTS": bool, + "CLUSTER ADDSLOTSRANGE": bool, + "CLUSTER COUNT-FAILURE-REPORTS": int, + "CLUSTER COUNTKEYSINSLOT": int, + "CLUSTER DELSLOTS": bool, + "CLUSTER DELSLOTSRANGE": bool, + "CLUSTER FAILOVER": bool, + "CLUSTER FORGET": bool, + "CLUSTER GETKEYSINSLOT": list, + "CLUSTER KEYSLOT": int, + "CLUSTER MEET": bool, + "CLUSTER REPLICATE": bool, + "CLUSTER RESET": bool, + "CLUSTER SAVECONFIG": bool, + "CLUSTER SET-CONFIG-EPOCH": bool, + "CLUSTER SETSLOT": bool, + "CLUSTER SLOTS": parse_cluster_slots, + "ASKING": bool, + "READONLY": bool, + "READWRITE": bool, + } + + RESULT_CALLBACKS = dict_merge( + list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), + list_keys_to_dict( + ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) + ), + list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), + list_keys_to_dict( + [ + "PING", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "CLIENT SETNAME", + "BGSAVE", + "SLOWLOG RESET", + "SAVE", + "MEMORY PURGE", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + ], + lambda command, res: all(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + ["DBSIZE", "WAIT"], + lambda command, res: sum(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + ["CLIENT UNBLOCK"], lambda command, res: 1 if sum(res.values()) > 0 else 0 + ), + list_keys_to_dict(["SCAN"], parse_scan_result), + list_keys_to_dict( + ["SCRIPT LOAD"], lambda command, res: list(res.values()).pop() + ), + list_keys_to_dict( + ["SCRIPT EXISTS"], lambda command, res: [all(k) for k in zip(*res.values())] + ), + list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), + ) + + ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) + + def __init__( + self, + host=None, + port=6379, + startup_nodes=None, + cluster_error_retry_attempts=3, + require_full_coverage=False, + reinitialize_steps=10, + read_from_replicas=False, + url=None, + **kwargs, + ): + """ + Initialize a new RedisCluster client. + + :startup_nodes: 'list[ClusterNode]' + List of nodes from which initial bootstrapping can be done + :host: 'str' + Can be used to point to a startup node + :port: 'int' + Can be used to point to a startup node + :require_full_coverage: 'bool' + When set to False (default value): the client will not require a + full coverage of the slots. However, if not all slots are covered, + and at least one node has 'cluster-require-full-coverage' set to + 'yes,' the server will throw a ClusterDownError for some key-based + commands. See - + https://redis.io/topics/cluster-tutorial#redis-cluster-configuration-parameters + When set to True: all slots must be covered to construct the + cluster client. If not all slots are covered, RedisClusterException + will be thrown. + :read_from_replicas: 'bool' + Enable read from replicas in READONLY mode. You can read possibly + stale data. + When set to true, read commands will be assigned between the + primary and its replications in a Round-Robin manner. + :cluster_error_retry_attempts: 'int' + Retry command execution attempts when encountering ClusterDownError + or ConnectionError + :reinitialize_steps: 'int' + Specifies the number of MOVED errors that need to occur before + reinitializing the whole cluster topology. If a MOVED error occurs + and the cluster does not need to be reinitialized on this current + error handling, only the MOVED slot will be patched with the + redirected node. + To reinitialize the cluster on every MOVED error, set + reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set + reinitialize_steps to 0. + + :**kwargs: + Extra arguments that will be sent into Redis instance when created + (See Official redis-py doc for supported kwargs + [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) + Some kwargs are not supported and will raise a + RedisClusterException: + - db (Redis do not support database SELECT in cluster mode) + """ + if startup_nodes is None: + startup_nodes = [] + + if "db" in kwargs: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "Argument 'db' is not possible to use in cluster mode" + ) + + # Get the startup node/s + from_url = False + if url is not None: + from_url = True + url_options = parse_url(url) + if "path" in url_options: + raise RedisClusterException( + "RedisCluster does not currently support Unix Domain " + "Socket connections" + ) + if "db" in url_options and url_options["db"] != 0: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "A ``db`` querystring option can only be 0 in cluster mode" + ) + kwargs.update(url_options) + host = kwargs.get("host") + port = kwargs.get("port", port) + startup_nodes.append(ClusterNode(host, port)) + elif host is not None and port is not None: + startup_nodes.append(ClusterNode(host, port)) + elif len(startup_nodes) == 0: + # No startup node was provided + raise RedisClusterException( + "RedisCluster requires at least one node to discover the " + "cluster. Please provide one of the followings:\n" + "1. host and port, for example:\n" + " RedisCluster(host='localhost', port=6379)\n" + "2. list of startup nodes, for example:\n" + " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," + " ClusterNode('localhost', 6378)])" + ) + log.debug(f"startup_nodes : {startup_nodes}") + # Update the connection arguments + # Whenever a new connection is established, RedisCluster's on_connect + # method should be run + # If the user passed on_connect function we'll save it and run it + # inside the RedisCluster.on_connect() function + self.user_on_connect_func = kwargs.pop("redis_connect_func", None) + kwargs.update({"redis_connect_func": self.on_connect}) + kwargs = cleanup_kwargs(**kwargs) + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.read_from_replicas = read_from_replicas + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.nodes_manager = None + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + from_url=from_url, + require_full_coverage=require_full_coverage, + **kwargs, + ) + + self.cluster_response_callbacks = CaseInsensitiveDict( + self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS + ) + self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + self.commands_parser = CommandsParser(self) + self._lock = threading.Lock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def disconnect_connection_pools(self): + for node in self.get_nodes(): + if node.redis_connection: + try: + node.redis_connection.connection_pool.disconnect() + except OSError: + # Client was already disconnected. do nothing + pass + + def on_connect(self, connection): + """ + Initialize the connection, authenticate and select a database and send + READONLY if it is set during object initialization. + """ + connection.set_parser(ClusterParser) + connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + connection.send_command("READONLY") + if str_if_bytes(connection.read_response()) != "OK": + raise ConnectionError("READONLY command failed") + + if self.user_on_connect_func is not None: + self.user_on_connect_func(connection) + + def get_redis_connection(self, node): + if not node.redis_connection: + with self._lock: + if not node.redis_connection: + self.nodes_manager.create_redis_connections([node]) + return node.redis_connection + + def get_node(self, host=None, port=None, node_name=None): + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self): + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self): + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self): + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self): + return list(self.nodes_manager.nodes_cache.values()) + + def get_node_from_key(self, key, replica=False): + """ + Get the node that holds the key's slot. + If replica set to True but the slot doesn't have any replicas, None is + returned. + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if slot_cache is None or len(slot_cache) == 0: + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') + if replica and len(self.nodes_manager.slots_cache[slot]) < 2: + return None + elif replica: + node_idx = 1 + else: + # primary + node_idx = 0 + + return slot_cache[node_idx] + + def get_default_node(self): + """ + Get the cluster's default node + """ + return self.nodes_manager.default_node + + def set_default_node(self, node): + """ + Set the default node of the cluster. + :param node: 'ClusterNode' + :return True if the default node was set, else False + """ + if node is None or self.get_node(node_name=node.name) is None: + log.info( + "The requested node does not exist in the cluster, so " + "the default node was not changed." + ) + return False + self.nodes_manager.default_node = node + log.info(f"Changed the default cluster node to {node}") + return True + + def monitor(self, target_node=None): + """ + Returns a Monitor object for the specified target node. + The default cluster node will be selected if no target node was + specified. + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + if target_node is None: + target_node = self.get_default_node() + if target_node.redis_connection is None: + raise RedisClusterException( + f"Cluster Node {target_node.name} has no redis_connection" + ) + return target_node.redis_connection.monitor() + + def pubsub(self, node=None, host=None, port=None, **kwargs): + """ + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + + def pipeline(self, transaction=None, shard_hint=None): + """ + Cluster impl: + Pipelines do not work in cluster mode the same way they + do in normal mode. Create a clone of this object so + that simulating pipelines will work correctly. Each + command will be called directly when used and + when calling execute() will only return the result stack. + """ + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline( + nodes_manager=self.nodes_manager, + commands_parser=self.commands_parser, + startup_nodes=self.nodes_manager.startup_nodes, + result_callbacks=self.result_callbacks, + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + read_from_replicas=self.read_from_replicas, + reinitialize_steps=self.reinitialize_steps, + ) + + def lock( + self, + name, + timeout=None, + sleep=0.1, + blocking_timeout=None, + lock_class=None, + thread_local=True, + ): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def set_response_callback(self, command, callback): + """Set a custom Response Callback""" + self.cluster_response_callbacks[command] = callback + + def _determine_nodes(self, *args, **kwargs): + command = args[0] + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the nodes group for this command if it was predefined + command_flag = self.command_flags.get(command) + if command_flag: + log.debug(f"Target node/s for {command}: {command_flag}") + if command_flag == self.__class__.RANDOM: + # return a random node + return [self.get_random_node()] + elif command_flag == self.__class__.PRIMARIES: + # return all primaries + return self.get_primaries() + elif command_flag == self.__class__.REPLICAS: + # return all replicas + return self.get_replicas() + elif command_flag == self.__class__.ALL_NODES: + # return all nodes + return self.get_nodes() + elif command_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + elif command in self.__class__.SEARCH_COMMANDS[0]: + return [self.nodes_manager.default_node] + else: + # get the node that holds the key's slot + slot = self.determine_slot(*args) + node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) + log.debug(f"Target for {args}: slot {slot}") + return [node] + + def _should_reinitialized(self): + # To reinitialize the cluster on every MOVED error, + # set reinitialize_steps to 1. + # To avoid reinitializing the cluster on moved errors, set + # reinitialize_steps to 0. + if self.reinitialize_steps == 0: + return False + else: + return self.reinitialize_counter % self.reinitialize_steps == 0 + + def keyslot(self, key): + """ + Calculate keyslot for a given key. + See Keys distribution model in https://redis.io/topics/cluster-spec + """ + k = self.encoder.encode(key) + return key_slot(k) + + def _get_command_keys(self, *args): + """ + Get the keys in the command. If the command has no keys in in, None is + returned. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + redis_conn = self.get_default_node().redis_connection + return self.commands_parser.get_keys(redis_conn, *args) + + def determine_slot(self, *args): + """ + Figure out what slot to use based on args. + + Raises a RedisClusterException if there's a missing key and we can't + determine what slots to map the command to; or, if the keys don't + all map to the same key slot. + """ + command = args[0] + if self.command_flags.get(command) == SLOT_ID: + # The command contains the slot ID + return args[1] + + # Get the keys in the command + + # EVAL and EVALSHA are common enough that it's wasteful to go to the + # redis server to parse the keys. Besides, there is a bug in redis<7.0 + # where `self._get_command_keys()` fails anyway. So, we special case + # EVAL/EVALSHA. + if command in ("EVAL", "EVALSHA"): + # command syntax: EVAL "script body" num_keys ... + if len(args) <= 2: + raise RedisClusterException(f"Invalid args in command: {args}") + num_actual_keys = args[2] + eval_keys = args[3 : 3 + num_actual_keys] + # if there are 0 keys, that means the script can be run on any node + # so we can just return a random slot + if len(eval_keys) == 0: + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + keys = eval_keys + else: + keys = self._get_command_keys(*args) + if keys is None or len(keys) == 0: + # FCALL can call a function with 0 keys, that means the function + # can be run on any node so we can just return a random slot + if command in ("FCALL", "FCALL_RO"): + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + f"target nodes.\nCommand: {args}" + ) + + # single key command + if len(keys) == 1: + return self.keyslot(keys[0]) + + # multi-key command; we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException( + f"{command} - all keys must map to the same key slot" + ) + + return slots.pop() + + def reinitialize_caches(self): + self.nodes_manager.initialize() + + def get_encoder(self): + """ + Get the connections' encoder + """ + return self.encoder + + def get_connection_kwargs(self): + """ + Get the connections' key-word arguments + """ + return self.nodes_manager.connection_kwargs + + def _is_nodes_flag(self, target_nodes): + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes(self, target_nodes): + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or dict. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + def execute_command(self, *args, **kwargs): + """ + Wrapper for ERRORS_ALLOW_RETRY error handling. + + It will try the number of times specified by the config option + "self.cluster_error_retry_attempts" which defaults to 3 unless manually + configured. + + If it reaches the number of times, the command will raise the exception + + Key argument :target_nodes: can be passed with the following types: + nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM + ClusterNode + list + dict + """ + target_nodes_specified = False + target_nodes = None + passed_targets = kwargs.pop("target_nodes", None) + if passed_targets is not None and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + # If an error that allows retrying was thrown, the nodes and slots + # cache were reinitialized. We will retry executing the command with + # the updated cluster setup only when the target nodes can be + # determined again with the new cache tables. Therefore, when target + # nodes were passed to this function, we cannot retry the command + # execution since the nodes may not be valid anymore after the tables + # were reinitialized. So in case of passed target nodes, + # retry_attempts will be set to 1. + retry_attempts = ( + 1 if target_nodes_specified else self.cluster_error_retry_attempts + ) + exception = None + for _ in range(0, retry_attempts): + try: + res = {} + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = self._determine_nodes( + *args, **kwargs, nodes_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {args} command on" + ) + for node in target_nodes: + res[node.name] = self._execute_command(node, *args, **kwargs) + # Return the processed result + return self._process_result(args[0], res, **kwargs) + except BaseException as e: + if type(e) in RedisCluster.ERRORS_ALLOW_RETRY: + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. + exception = e + else: + # All other errors should be raised. + raise e + + # If it fails the configured number of times then raise exception back + # to caller of this method + raise exception + + def _execute_command(self, target_node, *args, **kwargs): + """ + Send a command to a node in the cluster + """ + command = args[0] + redis_node = None + connection = None + redirect_addr = None + asking = False + moved = False + ttl = int(self.RedisClusterRequestTTL) + connection_error_retry_counter = 0 + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = self.determine_slot(*args) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) + moved = False + + log.debug( + f"Executing command {command} on target node: " + f"{target_node.server_type} {target_node.name}" + ) + redis_node = self.get_redis_connection(target_node) + connection = get_connection(redis_node, *args, **kwargs) + if asking: + connection.send_command("ASKING") + redis_node.parse_response(connection, "ASKING", **kwargs) + asking = False + + connection.send_command(*args) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + return response + + except (RedisClusterException, BusyLoadingError) as e: + log.exception(type(e)) + raise + except (ConnectionError, TimeoutError) as e: + log.exception(type(e)) + # ConnectionError can also be raised if we couldn't get a + # connection from the pool before timing out, so check that + # this is an actual connection before attempting to disconnect. + if connection is not None: + connection.disconnect() + connection_error_retry_counter += 1 + + # Give the node 0.25 seconds to get back up and retry again + # with same node and configuration. After 5 attempts then try + # to reinitialize the cluster and see if the nodes + # configuration has changed or not + if connection_error_retry_counter < 5: + time.sleep(0.25) + else: + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + self.nodes_manager.initialize() + raise + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when + # the same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + log.exception("MovedError") + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + # Reset the counter + self.reinitialize_counter = 0 + else: + self.nodes_manager.update_moved_exception(e) + moved = True + except TryAgainError: + log.exception("TryAgainError") + + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except AskError as e: + log.exception("AskError") + + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError as e: + log.exception("ClusterDownError") + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + time.sleep(0.25) + self.nodes_manager.initialize() + raise e + except ResponseError as e: + message = e.__str__() + log.exception(f"ResponseError: {message}") + raise e + except BaseException as e: + log.exception("BaseException") + if connection: + connection.disconnect() + raise e + finally: + if connection is not None: + redis_node.connection_pool.release(connection) + + raise ClusterError("TTL exhausted.") + + def close(self): + try: + with self._lock: + if self.nodes_manager: + self.nodes_manager.close() + except AttributeError: + # RedisCluster's __init__ can fail before nodes_manager is set + pass + + def _process_result(self, command, res, **kwargs): + """ + Process the result of the executed command. + The function would return a dict or a single value. + + :type command: str + :type res: dict + + `res` should be in the following format: + Dict + """ + if command in self.result_callbacks: + return self.result_callbacks[command](command, res, **kwargs) + elif len(res) == 1: + # When we execute the command on a single node, we can + # remove the dictionary and return a single response + return list(res.values())[0] + else: + return res + + def load_external_module(self, funcname, func): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + ``funcname`` - A string containing the name of the function to create + ``func`` - The function, being added to this class. + """ + setattr(self, funcname, func) + + +class ClusterNode: + def __init__(self, host, port, server_type=None, redis_connection=None): + if host == "localhost": + host = socket.gethostbyname(host) + + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + self.redis_connection = redis_connection + + def __repr__(self): + return ( + f"[host={self.host}," + f"port={self.port}," + f"name={self.name}," + f"server_type={self.server_type}," + f"redis_connection={self.redis_connection}]" + ) + + def __eq__(self, obj): + return isinstance(obj, ClusterNode) and obj.name == self.name + + def __del__(self): + if self.redis_connection is not None: + self.redis_connection.close() + + +class LoadBalancer: + """ + Round-Robin Load Balancing + """ + + def __init__(self, start_index=0): + self.primary_to_idx = {} + self.start_index = start_index + + def get_server_index(self, primary, list_size): + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + # Update the index + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index + + def reset(self): + self.primary_to_idx.clear() + + +class NodesManager: + def __init__( + self, + startup_nodes, + from_url=False, + require_full_coverage=False, + lock=None, + **kwargs, + ): + self.nodes_cache = {} + self.slots_cache = {} + self.startup_nodes = {} + self.default_node = None + self.populate_startup_nodes(startup_nodes) + self.from_url = from_url + self._require_full_coverage = require_full_coverage + self._moved_exception = None + self.connection_kwargs = kwargs + self.read_load_balancer = LoadBalancer() + if lock is None: + lock = threading.Lock() + self._lock = lock + self.initialize() + + def get_node(self, host=None, port=None, node_name=None): + """ + Get the requested node from the cluster's nodes. + nodes. + :return: ClusterNode if the node exists, else None + """ + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + log.error( + "get_node requires one of the following: " + "1. node name " + "2. host and port" + ) + return None + + def update_moved_exception(self, exception): + self._moved_exception = exception + + def _update_moved_slots(self): + """ + Update the slot's node with the redirected one + """ + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node is not None: + # The node already exists + if redirected_node.server_type is not PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode(e.host, e.port, PRIMARY) + self.nodes_cache[redirected_node.name] = redirected_node + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. + old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + """ + Gets a node that servers this hash slot + """ + if self._moved_exception: + with self._lock: + if self._moved_exception: + self._update_moved_slots() + + if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + f'"require_full_coverage={self._require_full_coverage}"' + ) + + if read_from_replicas is True: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + elif ( + server_type is None + or server_type == PRIMARY + or len(self.slots_cache[slot]) == 1 + ): + # return a primary + node_idx = 0 + else: + # return a replica + # randomly choose one of the replicas + node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) + + return self.slots_cache[slot][node_idx] + + def get_nodes_by_server_type(self, server_type): + """ + Get all nodes with the specified server type + :param server_type: 'primary' or 'replica' + :return: list of ClusterNode + """ + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + def populate_startup_nodes(self, nodes): + """ + Populate all startup nodes and filters out any duplicates + """ + for n in nodes: + self.startup_nodes[n.name] = n + + def check_slots_coverage(self, slots_cache): + # Validate if all slots are covered or if we should try next + # startup node + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in slots_cache: + return False + return True + + def create_redis_connections(self, nodes): + """ + This function will create a redis connection to all nodes in :nodes: + """ + for node in nodes: + if node.redis_connection is None: + node.redis_connection = self.create_redis_node( + host=node.host, port=node.port, **self.connection_kwargs + ) + + def create_redis_node(self, host, port, **kwargs): + if self.from_url: + # Create a redis node with a costumed connection pool + kwargs.update({"host": host}) + kwargs.update({"port": port}) + r = Redis(connection_pool=ConnectionPool(**kwargs)) + else: + r = Redis(host=host, port=port, **kwargs) + return r + + def initialize(self): + """ + Initializes the nodes cache, slots cache and redis connections. + :startup_nodes: + Responsible for discovering other nodes in the cluster + """ + log.debug("Initializing the nodes' topology of the cluster") + self.reset() + tmp_nodes_cache = {} + tmp_slots = {} + disagreements = [] + startup_nodes_reachable = False + fully_covered = False + kwargs = self.connection_kwargs + for startup_node in self.startup_nodes.values(): + try: + if startup_node.redis_connection: + r = startup_node.redis_connection + else: + # Create a new Redis connection and let Redis decode the + # responses so we won't need to handle that + copy_kwargs = copy.deepcopy(kwargs) + copy_kwargs.update({"decode_responses": True, "encoding": "utf-8"}) + r = self.create_redis_node( + startup_node.host, startup_node.port, **copy_kwargs + ) + self.startup_nodes[startup_node.name].redis_connection = r + # Make sure cluster mode is enabled on this node + if bool(r.info().get("cluster_enabled")) is False: + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + startup_nodes_reachable = True + except (ConnectionError, TimeoutError) as e: + msg = e.__str__ + log.exception( + "An exception occurred while trying to" + " initialize the cluster using the seed node" + f" {startup_node.name}:\n{msg}" + ) + continue + except ResponseError as e: + log.exception('ReseponseError sending "cluster slots" to redis server') + + # Isn't a cluster connection, so it won't parse these + # exceptions automatically + message = e.__str__() + if "CLUSTERDOWN" in message or "MASTERDOWN" in message: + continue + else: + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server: {startup_node}. error: {message}" + ) + except Exception as e: + message = e.__str__() + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server {startup_node.name}. error: {message}" + ) + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if target_node is None: + target_node = ClusterNode(host, port, PRIMARY) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if target_replica_node is None: + target_replica_node = ClusterNode(host, port, REPLICA) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" + ) + + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f'slots cache: {", ".join(disagreements)}' + ) + + fully_covered = self.check_slots_coverage(tmp_slots) + if fully_covered: + # Don't need to continue to the next startup node if all + # slots are covered + break + + if not startup_nodes_reachable: + raise RedisClusterException( + "Redis Cluster cannot be connected. Please provide at least " + "one reachable node. " + ) + + # Create Redis connections to all nodes + self.create_redis_connections(list(tmp_nodes_cache.values())) + + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) + + # Set the tmp variables to the real variables + self.nodes_cache = tmp_nodes_cache + self.slots_cache = tmp_slots + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + # Populate the startup nodes with all discovered nodes + self.populate_startup_nodes(self.nodes_cache.values()) + # If initialize was called after a MovedError, clear it + self._moved_exception = None + + def close(self): + self.default_node = None + for node in self.nodes_cache.values(): + if node.redis_connection: + node.redis_connection.close() + + def reset(self): + try: + self.read_load_balancer.reset() + except TypeError: + # The read_load_balancer is None, do nothing + pass + + +class ClusterPubSub(PubSub): + """ + Wrapper for PubSub class. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + """ + + def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true, a replica can be selected. + + :type redis_cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int + """ + self.node = None + self.set_pubsub_node(redis_cluster, node, host, port) + connection_pool = ( + None + if self.node is None + else redis_cluster.get_redis_connection(self.node).connection_pool + ) + self.cluster = redis_cluster + super().__init__( + **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + ) + + def set_pubsub_node(self, cluster, node=None, host=None, port=None): + """ + The pubsub node will be set according to the passed node, host and port + When none of the node, host, or port are specified - the node is set + to None and will be determined by the keyslot of the channel in the + first command to be executed. + RedisClusterException will be thrown if the passed node does not exist + in the cluster. + If host is passed without port, or vice versa, a DataError will be + thrown. + :type cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int + """ + if node is not None: + # node is passed by the user + self._raise_on_invalid_node(cluster, node, node.host, node.port) + pubsub_node = node + elif host is not None and port is not None: + # host and port passed by the user + node = cluster.get_node(host=host, port=port) + self._raise_on_invalid_node(cluster, node, host, port) + pubsub_node = node + elif any([host, port]) is True: + # only 'host' or 'port' passed + raise DataError("Passing a host requires passing a port, " "and vice versa") + else: + # nothing passed by the user. set node to None + pubsub_node = None + + self.node = pubsub_node + + def get_pubsub_node(self): + """ + Get the node that is being used as the pubsub connection + """ + return self.node + + def _raise_on_invalid_node(self, redis_cluster, node, host, port): + """ + Raise a RedisClusterException if the node is None or doesn't exist in + the cluster. + """ + if node is None or redis_cluster.get_node(node_name=node.name) is None: + raise RedisClusterException( + f"Node {host}:{port} doesn't exist in the cluster" + ) + + def execute_command(self, *args, **kwargs): + """ + Execute a publish/subscribe command. + + Taken code from redis-py and tweak to make it work within a cluster. + """ + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + if self.connection_pool is None: + if len(args) > 1: + # Hash the first channel and get one of the nodes holding + # this slot + channel = args[1] + slot = self.cluster.keyslot(channel) + node = self.cluster.nodes_manager.get_node_from_slot( + slot, self.cluster.read_from_replicas + ) + else: + # Get a random node + node = self.cluster.get_random_node() + self.node = node + redis_connection = self.cluster.get_redis_connection(node) + self.connection_pool = redis_connection.connection_pool + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + self._execute(connection, connection.send_command, *args) + + def get_redis_connection(self): + """ + Get the Redis connection of the pubsub connected node. + """ + if self.node is not None: + return self.node.redis_connection + + +class ClusterPipeline(RedisCluster): + """ + Support for Redis pipeline + in cluster mode + """ + + ERRORS_ALLOW_RETRY = ( + ConnectionError, + TimeoutError, + MovedError, + AskError, + TryAgainError, + ) + + def __init__( + self, + nodes_manager, + commands_parser, + result_callbacks=None, + cluster_response_callbacks=None, + startup_nodes=None, + read_from_replicas=False, + cluster_error_retry_attempts=5, + reinitialize_steps=10, + **kwargs, + ): + """ """ + self.command_stack = [] + self.nodes_manager = nodes_manager + self.commands_parser = commands_parser + self.refresh_table_asap = False + self.result_callbacks = ( + result_callbacks or self.__class__.RESULT_CALLBACKS.copy() + ) + self.startup_nodes = startup_nodes if startup_nodes else [] + self.read_from_replicas = read_from_replicas + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.cluster_response_callbacks = cluster_response_callbacks + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + + def __repr__(self): + """ """ + return f"{type(self).__name__}" + + def __enter__(self): + """ """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ """ + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self): + """ """ + return len(self.command_stack) + + def __nonzero__(self): + "Pipeline instances should always evaluate to True on Python 2.7" + return True + + def __bool__(self): + "Pipeline instances should always evaluate to True on Python 3+" + return True + + def execute_command(self, *args, **kwargs): + """ + Wrapper function for pipeline_execute_command + """ + return self.pipeline_execute_command(*args, **kwargs) + + def pipeline_execute_command(self, *args, **options): + """ + Appends the executed command to the pipeline's command stack + """ + self.command_stack.append( + PipelineCommand(args, options, len(self.command_stack)) + ) + return self + + def raise_first_error(self, stack): + """ + Raise the first exception on the stack + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r + + def annotate_exception(self, exception, number, command): + """ + Provides extra context to the exception prior to it being handled + """ + cmd = " ".join(map(safe_str, command)) + msg = ( + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" + ) + exception.args = (msg,) + exception.args[1:] + + def execute(self, raise_on_error=True): + """ + Execute all the commands in the current pipeline + """ + stack = self.command_stack + try: + return self.send_cluster_commands(stack, raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline. + """ + self.command_stack = [] + + self.scripts = set() + + # TODO: Implement + # make sure to reset the connection state in the event that we were + # watching something + # if self.watching and self.connection: + # try: + # # call this manually since our unwatch or + # # immediate_execute_command methods can call reset() + # self.connection.send_command('UNWATCH') + # self.connection.read_response() + # except ConnectionError: + # # disconnect will also remove any previous WATCHes + # self.connection.disconnect() + + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + + # TODO: Implement + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + # if self.connection: + # self.connection_pool.release(self.connection) + # self.connection = None + + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Wrapper for CLUSTERDOWN error handling. + + If the cluster reports it is down it is assumed that: + - connection_pool was disconnected + - connection_pool was reseted + - refereh_table_asap set to True + + It will try the number of times specified by + the config option "self.cluster_error_retry_attempts" + which defaults to 3 unless manually configured. + + If it reaches the number of times, the command will + raises ClusterDownException. + """ + if not stack: + return [] + + for _ in range(0, self.cluster_error_retry_attempts): + try: + return self._send_cluster_commands( + stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except ClusterDownError: + # Try again with the new cluster setup. All other errors + # should be raised. + pass + + # If it fails the configured number of times then raise + # exception back to caller of this method + raise ClusterDownError("CLUSTERDOWN error. Unable to rebuild the cluster") + + def _send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Send a bunch of cluster commands to the redis cluster. + + `allow_redirections` If the pipeline should follow + `ASK` & `MOVED` responses automatically. If set + to false it will raise RedisClusterException. + """ + # the first time sending the commands we send all of + # the commands that were queued up. + # if we have to run through it again, we only retry + # the commands that failed. + attempt = sorted(stack, key=lambda x: x.position) + + # build a list of node objects based on node names we need to + nodes = {} + + # as we move through each command that still needs to be processed, + # we figure out the slot number that command maps to, then from + # the slot determine the node. + for c in attempt: + # refer to our internal node -> slot table that + # tells us where a given + # command should route to. + node = self._determine_nodes(*c.args) + + # now that we know the name of the node + # ( it's just a string in the form of host:port ) + # we can build a list of commands for each node. + node_name = node[0].name + if node_name not in nodes: + redis_node = self.get_redis_connection(node[0]) + connection = get_connection(redis_node, c.args) + nodes[node_name] = NodeCommands( + redis_node.parse_response, redis_node.connection_pool, connection + ) + + nodes[node_name].append(c) + + # send the commands in sequence. + # we write to all the open sockets for each node first, + # before reading anything + # this allows us to flush all the requests out across the + # network essentially in parallel + # so that we can read them all in parallel as they come back. + # we dont' multiplex on the sockets as they come available, + # but that shouldn't make too much difference. + node_commands = nodes.values() + for n in node_commands: + n.write() + + for n in node_commands: + n.read() + + # release all of the redis connections we allocated earlier + # back into the connection pool. + # we used to do this step as part of a try/finally block, + # but it is really dangerous to + # release connections back into the pool if for some + # reason the socket has data still left in it + # from a previous operation. The write and + # read operations already have try/catch around them for + # all known types of errors including connection + # and socket level errors. + # So if we hit an exception, something really bad + # happened and putting any oF + # these connections back into the pool is a very bad idea. + # the socket might have unread buffer still sitting in it, + # and then the next time we read from it we pass the + # buffered result back from a previous command and + # every single request after to that connection will always get + # a mismatched result. + for n in nodes.values(): + n.connection_pool.release(n.connection) + + # if the response isn't an exception it is a + # valid response from the node + # we're all done with that command, YAY! + # if we have more commands to attempt, we've run into problems. + # collect all the commands we are allowed to retry. + # (MOVED, ASK, or connection errors or timeout errors) + attempt = sorted( + ( + c + for c in attempt + if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) + ), + key=lambda x: x.position, + ) + if attempt and allow_redirections: + # RETRY MAGIC HAPPENS HERE! + # send these remaing comamnds one at a time using `execute_command` + # in the main client. This keeps our retry logic + # in one place mostly, + # and allows us to be more confident in correctness of behavior. + # at this point any speed gains from pipelining have been lost + # anyway, so we might as well make the best + # attempt to get the correct behavior. + # + # The client command will handle retries for each + # individual command sequentially as we pass each + # one into `execute_command`. Any exceptions + # that bubble out should only appear once all + # retries have been exhausted. + # + # If a lot of commands have failed, we'll be setting the + # flag to rebuild the slots table from scratch. + # So MOVED errors should correct themselves fairly quickly. + log.exception( + f"An exception occurred during pipeline execution. " + f"args: {attempt[-1].args}, " + f"error: {type(attempt[-1].result).__name__} " + f"{str(attempt[-1].result)}" + ) + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + for c in attempt: + try: + # send each command individually like we + # do in the main client. + c.result = super().execute_command(*c.args, **c.options) + except RedisError as e: + c.result = e + + # turn the response back into a simple flat array that corresponds + # to the sequence of commands issued in the stack in pipeline.execute() + response = [] + for c in sorted(stack, key=lambda x: x.position): + if c.args[0] in self.cluster_response_callbacks: + c.result = self.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) + response.append(c.result) + + if raise_on_error: + self.raise_first_error(stack) + + return response + + def _fail_on_redirect(self, allow_redirections): + """ """ + if not allow_redirections: + raise RedisClusterException( + "ASK & MOVED redirection not allowed in this pipeline" + ) + + def exists(self, *keys): + return self.execute_command("EXISTS", *keys) + + def eval(self): + """ """ + raise RedisClusterException("method eval() is not implemented") + + def multi(self): + """ """ + raise RedisClusterException("method multi() is not implemented") + + def immediate_execute_command(self, *args, **options): + """ """ + raise RedisClusterException( + "method immediate_execute_command() is not implemented" + ) + + def _execute_transaction(self, *args, **kwargs): + """ """ + raise RedisClusterException("method _execute_transaction() is not implemented") + + def load_scripts(self): + """ """ + raise RedisClusterException("method load_scripts() is not implemented") + + def watch(self, *names): + """ """ + raise RedisClusterException("method watch() is not implemented") + + def unwatch(self): + """ """ + raise RedisClusterException("method unwatch() is not implemented") + + def script_load_for_pipeline(self, *args, **kwargs): + """ """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented" + ) + + def delete(self, *names): + """ + "Delete a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not " "implemented in pipeline command" + ) + + return self.execute_command("DEL", names[0]) + + +def block_pipeline_command(func): + """ + Prints error because some pipelined commands should + be blocked when running in cluster-mode + """ + + def inner(*args, **kwargs): + raise RedisClusterException( + f"ERROR: Calling pipelined function {func.__name__} is blocked " + f"when running redis in cluster mode..." + ) + + return inner + + +# Blocked pipeline commands +ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) +ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) +ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname) +ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) +ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname) +ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) +ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) +ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) +ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb) +ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys) +ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget) +ClusterPipeline.move = block_pipeline_command(RedisCluster.move) +ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset) +ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx) +ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge) +ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount) +ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping) +ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish) +ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey) +ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename) +ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) +ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) +ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) +ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) +ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) +ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) +ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) +ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) +ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) +ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) +ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) +ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite) +ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly) + + +class PipelineCommand: + """ """ + + def __init__(self, args, options=None, position=None): + self.args = args + if options is None: + options = {} + self.options = options + self.position = position + self.result = None + self.node = None + self.asking = False + + +class NodeCommands: + """ """ + + def __init__(self, parse_response, connection_pool, connection): + """ """ + self.parse_response = parse_response + self.connection_pool = connection_pool + self.connection = connection + self.commands = [] + + def append(self, c): + """ """ + self.commands.append(c) + + def write(self): + """ + Code borrowed from Redis so it can be fixed + """ + connection = self.connection + commands = self.commands + + # We are going to clobber the commands with the write, so go ahead + # and ensure that nothing is sitting there from a previous run. + for c in commands: + c.result = None + + # build up all commands into a single request to increase network perf + # send all the commands and catch connection and timeout errors. + try: + connection.send_packed_command( + connection.pack_commands([c.args for c in commands]) + ) + except (ConnectionError, TimeoutError) as e: + for c in commands: + c.result = e + + def read(self): + """ """ + connection = self.connection + for c in self.commands: + + # if there is a result on this command, + # it means we ran into an exception + # like a connection error. Trying to parse + # a response on a connection that + # is no longer open will result in a + # connection error raised by redis-py. + # but redis-py doesn't check in parse_response + # that the sock object is + # still set and if you try to + # read from a closed connection, it will + # result in an AttributeError because + # it will do a readline() call on None. + # This can have all kinds of nasty side-effects. + # Treating this case as a connection error + # is fine because it will dump + # the connection object back into the + # pool and on the next write, it will + # explicitly open the connection and all will be well. + if c.result is None: + try: + c.result = self.parse_response(connection, c.args[0], **c.options) + except (ConnectionError, TimeoutError) as e: + for c in self.commands: + c.result = e + return + except RedisError: + c.result = sys.exc_info()[1] diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py new file mode 100644 index 0000000000..89292ab2d3 --- /dev/null +++ b/redis/asyncio/parser.py @@ -0,0 +1,143 @@ +from redis.exceptions import RedisError, ResponseError +from redis.utils import str_if_bytes + + +class CommandsParser: + """ + Parses Redis commands to get command keys. + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with + 'movablekeys', and these commands' keys are determined by the command + 'COMMAND GETKEYS'. + """ + + def __init__(self, redis_connection): + self.initialized = False + self.commands = {} + self.initialize(redis_connection) + + def initialize(self, r): + commands = r.execute_command("COMMAND") + uppercase_commands = [] + for cmd in commands: + if any(x.isupper() for x in cmd): + uppercase_commands.append(cmd) + for cmd in uppercase_commands: + commands[cmd.lower()] = commands.pop(cmd) + self.commands = commands + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + def get_keys(self, redis_conn, *args): + """ + Get the keys from the passed command. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + if len(args) < 2: + # The command has no keys in it + return None + + cmd_name = args[0].lower() + if cmd_name not in self.commands: + # try to split the command name and to take only the main command, + # e.g. 'memory' for 'memory usage' + cmd_name_split = cmd_name.split() + cmd_name = cmd_name_split[0] + if cmd_name in self.commands: + # save the splitted command to args + args = cmd_name_split + list(args[1:]) + else: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands.get(cmd_name) + if "movablekeys" in command["flags"]: + keys = self._get_moveable_keys(redis_conn, *args) + elif "pubsub" in command["flags"]: + keys = self._get_pubsub_keys(*args) + else: + if ( + command["step_count"] == 0 + and command["first_key_pos"] == 0 + and command["last_key_pos"] == 0 + ): + # The command doesn't have keys in it + return None + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) - abs(last_key_pos) + keys_pos = list( + range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) + ) + keys = [args[pos] for pos in keys_pos] + + return keys + + def _get_moveable_keys(self, redis_conn, *args): + """ + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + pieces = [] + cmd_name = args[0] + # The command name should be splitted into separate arguments, + # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] + pieces = pieces + cmd_name.split() + pieces = pieces + list(args[1:]) + try: + keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys + + def _get_pubsub_keys(self, *args): + """ + Get the keys from pubsub command. + Although PubSub commands have predetermined key locations, they are not + supported in the 'COMMAND's output, so the key positions are hardcoded + in this method + """ + if len(args) < 2: + # The command has no keys in it + return None + args = [str_if_bytes(arg) for arg in args] + command = args[0].upper() + keys = None + if command == "PUBSUB": + # the second argument is a part of the command name, e.g. + # ['PUBSUB', 'NUMSUB', 'foo']. + pubsub_type = args[1].upper() + if pubsub_type in ["CHANNELS", "NUMSUB"]: + keys = args[2:] + elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: + # format example: + # SUBSCRIBE channel [channel ...] + keys = list(args[1:]) + elif command == "PUBLISH": + # format example: + # PUBLISH channel message + keys = [args[1]] + return keys diff --git a/redis/cluster.py b/redis/cluster.py index 221df856c1..92469e4f37 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -128,10 +128,7 @@ def fix_server(*args): "unix_socket_path", "username", ) -KWARGS_DISABLED_KEYS = ( - "host", - "port", -) +KWARGS_DISABLED_KEYS = ("host", "port") # Not complete, but covers the major ones # https://redis.io/commands @@ -308,10 +305,7 @@ class RedisCluster(RedisClusterCommands): ], PRIMARIES, ), - list_keys_to_dict( - ["FUNCTION DUMP"], - RANDOM, - ), + list_keys_to_dict(["FUNCTION DUMP"], RANDOM), list_keys_to_dict( [ "CLUSTER COUNTKEYSINSLOT", @@ -384,25 +378,11 @@ class RedisCluster(RedisClusterCommands): } RESULT_CALLBACKS = dict_merge( + list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), list_keys_to_dict( - [ - "PUBSUB NUMSUB", - ], - parse_pubsub_numsub, - ), - list_keys_to_dict( - [ - "PUBSUB NUMPAT", - ], - lambda command, res: sum(list(res.values())), - ), - list_keys_to_dict( - [ - "KEYS", - "PUBSUB CHANNELS", - ], - merge_result, + ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), + list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), list_keys_to_dict( [ "PING", @@ -420,49 +400,67 @@ class RedisCluster(RedisClusterCommands): lambda command, res: all(res.values()) if isinstance(res, dict) else res, ), list_keys_to_dict( - [ - "DBSIZE", - "WAIT", - ], + ["DBSIZE", "WAIT"], lambda command, res: sum(res.values()) if isinstance(res, dict) else res, ), list_keys_to_dict( - [ - "CLIENT UNBLOCK", - ], - lambda command, res: 1 if sum(res.values()) > 0 else 0, - ), - list_keys_to_dict( - [ - "SCAN", - ], - parse_scan_result, - ), - list_keys_to_dict( - [ - "SCRIPT LOAD", - ], - lambda command, res: list(res.values()).pop(), + ["CLIENT UNBLOCK"], lambda command, res: 1 if sum(res.values()) > 0 else 0 ), + list_keys_to_dict(["SCAN"], parse_scan_result), list_keys_to_dict( - [ - "SCRIPT EXISTS", - ], - lambda command, res: [all(k) for k in zip(*res.values())], + ["SCRIPT LOAD"], lambda command, res: list(res.values()).pop() ), list_keys_to_dict( - [ - "SCRIPT FLUSH", - ], - lambda command, res: all(res.values()), + ["SCRIPT EXISTS"], lambda command, res: [all(k) for k in zip(*res.values())] ), + list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), ) - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - ClusterDownError, - ) + ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) def __init__( self, @@ -617,50 +615,6 @@ def disconnect_connection_pools(self): # Client was already disconnected. do nothing pass - @classmethod - def from_url(cls, url, **kwargs): - """ - Return a Redis client object configured from the given URL - - For example:: - - redis://[[username]:[password]]@localhost:6379/0 - rediss://[[username]:[password]]@localhost:6379/0 - unix://[[username]:[password]]@/path/to/socket.sock?db=0 - - Three URL schemes are supported: - - - `redis://` creates a TCP socket connection. See more at: - - - `rediss://` creates a SSL wrapped TCP socket connection. See more at: - - - ``unix://``: creates a Unix Domain Socket connection. - - The username, password, hostname, path and all querystring values - are passed through urllib.parse.unquote in order to replace any - percent-encoded values with their corresponding characters. - - There are several ways to specify a database number. The first value - found will be used: - - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument - of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. - - If none of these options are specified, the default db=0 is used. - - All querystring options are cast to their appropriate Python types. - Boolean arguments can be specified with string values "True"/"False" - or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments - and keyword arguments are passed to the ``ConnectionPool``'s - class initializer. In the case of conflicting arguments, querystring - arguments always win. - - """ - return cls(url=url, **kwargs) - def on_connect(self, connection): """ Initialize the connection, authenticate and select a database and send @@ -1243,11 +1197,7 @@ def _process_result(self, command, res, **kwargs): else: return res - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -1461,9 +1411,7 @@ def create_redis_connections(self, nodes): for node in nodes: if node.redis_connection is None: node.redis_connection = self.create_redis_node( - host=node.host, - port=node.port, - **self.connection_kwargs, + host=node.host, port=node.port, **self.connection_kwargs ) def create_redis_node(self, host, port, **kwargs): diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index e4628dbaa2..ddeafc43e4 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -167,6 +167,157 @@ def unlink(self, *keys): return self._split_command_across_slots("UNLINK", *keys) +class AsyncClusterMultiKeyCommands: + """ + A class containing commands that handle more than one key + """ + + def _partition_keys_by_slot(self, keys): + """ + Split keys into a dictionary that maps a slot to + a list of keys. + """ + slots_to_keys = {} + for key in keys: + k = self.encoder.encode(key) + slot = key_slot(k) + slots_to_keys.setdefault(slot, []).append(key) + + return slots_to_keys + + def mget_nonatomic(self, keys, *args): + """ + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + Returns a list of values ordered identically to ``keys`` + + For more information see https://redis.io/commands/mget + """ + + from redis.client import EMPTY_RESPONSE + + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + + # Concatenate all keys into a list + keys = list_or_args(keys, args) + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + + # Call MGET for every slot and concatenate + # the results + # We must make sure that the keys are returned in order + all_results = {} + for slot_keys in slots_to_keys.values(): + slot_values = self.execute_command("MGET", *slot_keys, **options) + + slot_results = dict(zip(slot_keys, slot_values)) + all_results.update(slot_results) + + # Sort the results + vals_in_order = [all_results[key] for key in keys] + return vals_in_order + + def mset_nonatomic(self, mapping): + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + For more information see https://redis.io/commands/mset + """ + + # Partition the keys by slot + slots_to_pairs = {} + for pair in mapping.items(): + # encode the key + k = self.encoder.encode(pair[0]) + slot = key_slot(k) + slots_to_pairs.setdefault(slot, []).extend(pair) + + # Call MSET for every slot and concatenate + # the results (one result per slot) + res = [] + for pairs in slots_to_pairs.values(): + res.append(self.execute_command("MSET", *pairs)) + + return res + + def _split_command_across_slots(self, command, *keys): + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. + """ + # Partition the keys by slot + slots_to_keys = self._partition_keys_by_slot(keys) + + # Sum up the reply from each command + total = 0 + for slot_keys in slots_to_keys.values(): + total += self.execute_command(command, *slot_keys) + + return total + + def exists(self, *keys): + """ + Returns the number of ``names`` that exist in the + whole cluster. The keys are first split up into slots + and then an EXISTS command is sent for every slot + + For more information see https://redis.io/commands/exists + """ + return self._split_command_across_slots("EXISTS", *keys) + + def delete(self, *keys): + """ + Deletes the given keys in the cluster. + The keys are first split up into slots + and then an DEL command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were deleted. + + For more information see https://redis.io/commands/del + """ + return self._split_command_across_slots("DEL", *keys) + + def touch(self, *keys): + """ + Updates the last access time of given keys across the + cluster. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were touched. + + For more information see https://redis.io/commands/touch + """ + return self._split_command_across_slots("TOUCH", *keys) + + def unlink(self, *keys): + """ + Remove the specified keys in a different thread. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were unlinked. + + For more information see https://redis.io/commands/unlink + """ + return self._split_command_across_slots("UNLINK", *keys) + + class ClusterManagementCommands(ManagementCommands): """ A class for Redis Cluster management commands @@ -200,6 +351,39 @@ def swapdb(self, *args, **kwargs): raise RedisClusterException("SWAPDB is not supported in cluster" " mode") +class AsyncClusterManagementCommands(ManagementCommands): + """ + A class for Redis Cluster management commands + + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ + + def slaveof(self, *args, **kwargs): + """ + Make the server a replica of another instance, or promote it as master. + + For more information see https://redis.io/commands/slaveof + """ + raise RedisClusterException("SLAVEOF is not supported in cluster mode") + + def replicaof(self, *args, **kwargs): + """ + Make the server a replica of another instance, or promote it as master. + + For more information see https://redis.io/commands/replicaof + """ + raise RedisClusterException("REPLICAOF is not supported in cluster" " mode") + + def swapdb(self, *args, **kwargs): + """ + Swaps two Redis databases. + + For more information see https://redis.io/commands/swapdb + """ + raise RedisClusterException("SWAPDB is not supported in cluster" " mode") + + class ClusterDataAccessCommands(DataAccessCommands): """ A class for Redis Cluster Data Access Commands @@ -291,7 +475,403 @@ def scan_iter( } -class RedisClusterCommands( +class AsyncClusterDataAccessCommands(DataAccessCommands): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + def stralgo( + self, + algo, + value1, + value2, + specific_argument="strings", + len=False, + idx=False, + minmatchlen=None, + withmatchlen=False, + **kwargs, + ): + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + + For more information see https://redis.io/commands/stralgo + """ + target_nodes = kwargs.pop("target_nodes", None) + if specific_argument == "strings" and target_nodes is None: + target_nodes = "default-node" + kwargs.update({"target_nodes": target_nodes}) + return super().stralgo( + algo, + value1, + value2, + specific_argument, + len, + idx, + minmatchlen, + withmatchlen, + **kwargs, + ) + + def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> Iterator: + # Do the first query with cursor=0 for all nodes + cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) + yield from data + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + yield from data + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class RedisClusterCommands( + ClusterMultiKeyCommands, + ClusterManagementCommands, + ACLCommands, + PubSubCommands, + ClusterDataAccessCommands, + ScriptCommands, + FunctionCommands, + RedisModuleCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ + + def cluster_myid(self, target_node): + """ + Returns the node’s id. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information check https://redis.io/commands/cluster-myid/ + """ + return self.execute_command("CLUSTER MYID", target_nodes=target_node) + + def cluster_addslots(self, target_node, *slots): + """ + Assign new hash slots to receiving node. Sends to specified node. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslots + """ + return self.execute_command( + "CLUSTER ADDSLOTS", *slots, target_nodes=target_node + ) + + def cluster_addslotsrange(self, target_node, *slots): + """ + Similar to the CLUSTER ADDSLOTS command. + The difference between the two commands is that ADDSLOTS takes a list of slots + to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges + (specified by start and end slots) to assign to the node. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslotsrange + """ + return self.execute_command( + "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node + ) + + def cluster_countkeysinslot(self, slot_id): + """ + Return the number of local keys in the specified hash slot + Send to node based on specified slot_id + + For more information see https://redis.io/commands/cluster-countkeysinslot + """ + return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) + + def cluster_count_failure_report(self, node_id): + """ + Return the number of failure reports active for a given node + Sends to a random node + + For more information see https://redis.io/commands/cluster-count-failure-reports + """ + return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) + + def cluster_delslots(self, *slots): + """ + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there + + Returns a list of the results for each processed slot. + + For more information see https://redis.io/commands/cluster-delslots + """ + return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + + def cluster_delslotsrange(self, *slots): + """ + Similar to the CLUSTER DELSLOTS command. + The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove + from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove + from the node. + + For more information see https://redis.io/commands/cluster-delslotsrange + """ + return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) + + def cluster_failover(self, target_node, option=None): + """ + Forces a slave to perform a manual failover of its master + Sends to specified node + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-failover + """ + if option: + if option.upper() not in ["FORCE", "TAKEOVER"]: + raise RedisError( + f"Invalid option for CLUSTER FAILOVER command: {option}" + ) + else: + return self.execute_command( + "CLUSTER FAILOVER", option, target_nodes=target_node + ) + else: + return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) + + def cluster_info(self, target_nodes=None): + """ + Provides info about Redis Cluster node state. + The command will be sent to a random node in the cluster if no target + node is specified. + + For more information see https://redis.io/commands/cluster-info + """ + return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) + + def cluster_keyslot(self, key): + """ + Returns the hash slot of the specified key + Sends to random node in the cluster + + For more information see https://redis.io/commands/cluster-keyslot + """ + return self.execute_command("CLUSTER KEYSLOT", key) + + def cluster_meet(self, host, port, target_nodes=None): + """ + Force a node cluster to handshake with another node. + Sends to specified node. + + For more information see https://redis.io/commands/cluster-meet + """ + return self.execute_command( + "CLUSTER MEET", host, port, target_nodes=target_nodes + ) + + def cluster_nodes(self): + """ + Get Cluster config for the node. + Sends to random node in the cluster + + For more information see https://redis.io/commands/cluster-nodes + """ + return self.execute_command("CLUSTER NODES") + + def cluster_replicate(self, target_nodes, node_id): + """ + Reconfigure a node as a slave of the specified master node + + For more information see https://redis.io/commands/cluster-replicate + """ + return self.execute_command( + "CLUSTER REPLICATE", node_id, target_nodes=target_nodes + ) + + def cluster_reset(self, soft=True, target_nodes=None): + """ + Reset a Redis Cluster node + + If 'soft' is True then it will send 'SOFT' argument + If 'soft' is False then it will send 'HARD' argument + + For more information see https://redis.io/commands/cluster-reset + """ + return self.execute_command( + "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes + ) + + def cluster_save_config(self, target_nodes=None): + """ + Forces the node to save cluster state on disk + + For more information see https://redis.io/commands/cluster-saveconfig + """ + return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) + + def cluster_get_keys_in_slot(self, slot, num_keys): + """ + Returns the number of keys in the specified cluster slot + + For more information see https://redis.io/commands/cluster-getkeysinslot + """ + return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) + + def cluster_set_config_epoch(self, epoch, target_nodes=None): + """ + Set the configuration epoch in a new node + + For more information see https://redis.io/commands/cluster-set-config-epoch + """ + return self.execute_command( + "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes + ) + + def cluster_setslot(self, target_node, node_id, slot_id, state): + """ + Bind an hash slot to a specific node + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-setslot + """ + if state.upper() in ("IMPORTING", "NODE", "MIGRATING"): + return self.execute_command( + "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node + ) + elif state.upper() == "STABLE": + raise RedisError('For "stable" state please use ' "cluster_setslot_stable") + else: + raise RedisError(f"Invalid slot state: {state}") + + def cluster_setslot_stable(self, slot_id): + """ + Clears migrating / importing state from the slot. + It determines by it self what node the slot is in and sends it there. + + For more information see https://redis.io/commands/cluster-setslot + """ + return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") + + def cluster_replicas(self, node_id, target_nodes=None): + """ + Provides a list of replica nodes replicating from the specified primary + target node. + + For more information see https://redis.io/commands/cluster-replicas + """ + return self.execute_command( + "CLUSTER REPLICAS", node_id, target_nodes=target_nodes + ) + + def cluster_slots(self, target_nodes=None): + """ + Get array of Cluster slot to node mappings + + For more information see https://redis.io/commands/cluster-slots + """ + return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) + + def cluster_links(self, target_node): + """ + Each node in a Redis Cluster maintains a pair of long-lived TCP link with each + peer in the cluster: One for sending outbound messages towards the peer and one + for receiving inbound messages from the peer. + + This command outputs information of all such peer links as an array. + + For more information see https://redis.io/commands/cluster-links + """ + return self.execute_command("CLUSTER LINKS", target_nodes=target_node) + + def readonly(self, target_nodes=None): + """ + Enables read queries. + The command will be sent to the default cluster node if target_nodes is + not specified. + + For more information see https://redis.io/commands/readonly + """ + if target_nodes == "replicas" or target_nodes == "all": + # read_from_replicas will only be enabled if the READONLY command + # is sent to all replicas + self.read_from_replicas = True + return self.execute_command("READONLY", target_nodes=target_nodes) + + def readwrite(self, target_nodes=None): + """ + Disables read queries. + The command will be sent to the default cluster node if target_nodes is + not specified. + + For more information see https://redis.io/commands/readwrite + """ + # Reset read from replicas flag + self.read_from_replicas = False + return self.execute_command("READWRITE", target_nodes=target_nodes) + + +class AsyncRedisClusterCommands( ClusterMultiKeyCommands, ClusterManagementCommands, ACLCommands, diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py new file mode 100644 index 0000000000..2a2bc883bf --- /dev/null +++ b/tests/test_asyncio/test_cluster.py @@ -0,0 +1,2729 @@ +import binascii +import datetime +import warnings +from time import sleep +from unittest.mock import DEFAULT, Mock, call, patch + +import pytest + +from redis import Redis +from redis.cluster import ( + PRIMARY, + REDIS_CLUSTER_HASH_SLOTS, + REPLICA, + ClusterNode, + NodesManager, + RedisCluster, + get_node_name, +) +from redis.commands import CommandsParser +from redis.connection import Connection +from redis.crc import key_slot +from redis.exceptions import ( + AskError, + ClusterDownError, + ConnectionError, + DataError, + MovedError, + NoPermissionError, + RedisClusterException, + RedisError, + ResponseError, +) +from redis.utils import str_if_bytes +from tests.test_pubsub import wait_for_message + +from .conftest import ( + _get_client, + skip_if_redis_enterprise, + skip_if_server_version_lt, + skip_unless_arch_bits, + wait_for_command, +) + +default_host = "127.0.0.1" +default_port = 7000 +default_cluster_slots = [ + [0, 8191, ["127.0.0.1", 7000, "node_0"], ["127.0.0.1", 7003, "node_3"]], + [8192, 16383, ["127.0.0.1", 7001, "node_1"], ["127.0.0.1", 7002, "node_2"]], +] + + +@pytest.fixture() +def slowlog(request, r): + """ + Set the slowlog threshold to 0, and the + max length to 128. This will force every + command into the slowlog and allow us + to test it + """ + # Save old values + current_config = r.config_get(target_nodes=r.get_primaries()[0]) + old_slower_than_value = current_config["slowlog-log-slower-than"] + old_max_legnth_value = current_config["slowlog-max-len"] + + # Function to restore the old values + def cleanup(): + r.config_set("slowlog-log-slower-than", old_slower_than_value) + r.config_set("slowlog-max-len", old_max_legnth_value) + + request.addfinalizer(cleanup) + + # Set the new values + r.config_set("slowlog-log-slower-than", 0) + r.config_set("slowlog-max-len", 128) + + +def get_mocked_redis_client(func=None, *args, **kwargs): + """ + Return a stable RedisCluster object that have deterministic + nodes and slots setup to remove the problem of different IP addresses + on different installations and machines. + """ + cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) + coverage_res = kwargs.pop("coverage_result", "yes") + cluster_enabled = kwargs.pop("cluster_enabled", True) + with patch.object(Redis, "execute_command") as execute_command_mock: + + def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": + mock_cluster_slots = cluster_slots + return mock_cluster_slots + elif _args[0] == "COMMAND": + return {"get": [], "set": []} + elif _args[0] == "INFO": + return {"cluster_enabled": cluster_enabled} + elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": coverage_res} + elif func is not None: + return func(*args, **kwargs) + else: + return execute_command_mock(*_args, **_kwargs) + + execute_command_mock.side_effect = execute_command + + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + return RedisCluster(*args, **kwargs) + + +def mock_node_resp(node, response): + connection = Mock() + connection.read_response.return_value = response + node.redis_connection.connection = connection + return node + + +def mock_node_resp_func(node, func): + connection = Mock() + connection.read_response.side_effect = func + node.redis_connection.connection = connection + return node + + +def mock_all_nodes_resp(rc, response): + for node in rc.get_nodes(): + mock_node_resp(node, response) + return rc + + +def find_node_ip_based_on_port(cluster_client, port): + for node in cluster_client.get_nodes(): + if node.port == port: + return node.host + + +def moved_redirection_helper(request, failover=False): + """ + Test that the client handles MOVED response after a failover. + Redirection after a failover means that the redirection address is of a + replica that was promoted to a primary. + + At first call it should return a MOVED ResponseError that will point + the client to the next server it should talk to. + + Verify that: + 1. it tries to talk to the redirected node + 2. it updates the slot's primary to the redirected node + + For a failover, also verify: + 3. the redirected node's server type updated to 'primary' + 4. the server type of the previous slot owner updated to 'replica' + """ + rc = _get_client(RedisCluster, request, flushdb=False) + slot = 12182 + redirect_node = None + # Get the current primary that holds this slot + prev_primary = rc.nodes_manager.get_node_from_slot(slot) + if failover: + if len(rc.nodes_manager.slots_cache[slot]) < 2: + warnings.warn("Skipping this test since it requires to have a " "replica") + return + redirect_node = rc.nodes_manager.slots_cache[slot][1] + else: + # Use one of the primaries to be the redirected node + redirect_node = rc.get_primaries()[0] + r_host = redirect_node.host + r_port = redirect_node.port + with patch.object(Redis, "parse_response") as parse_response: + + def moved_redirect_effect(connection, *args, **options): + def ok_response(connection, *args, **options): + assert connection.host == r_host + assert connection.port == r_port + + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise MovedError(f"{slot} {r_host}:{r_port}") + + parse_response.side_effect = moved_redirect_effect + assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" + slot_primary = rc.nodes_manager.slots_cache[slot][0] + assert slot_primary == redirect_node + if failover: + assert rc.get_node(host=r_host, port=r_port).server_type == PRIMARY + assert prev_primary.server_type == REPLICA + + +@pytest.mark.onlycluster +class TestRedisClusterObj: + """ + Tests for the RedisCluster class + """ + + def test_host_port_startup_node(self): + """ + Test that it is possible to use host & port arguments as startup node + args + """ + cluster = get_mocked_redis_client(host=default_host, port=default_port) + assert cluster.get_node(host=default_host, port=default_port) is not None + + def test_startup_nodes(self): + """ + Test that it is possible to use startup_nodes + argument to init the cluster + """ + port_1 = 7000 + port_2 = 7001 + startup_nodes = [ + ClusterNode(default_host, port_1), + ClusterNode(default_host, port_2), + ] + cluster = get_mocked_redis_client(startup_nodes=startup_nodes) + assert ( + cluster.get_node(host=default_host, port=port_1) is not None + and cluster.get_node(host=default_host, port=port_2) is not None + ) + + def test_empty_startup_nodes(self): + """ + Test that exception is raised when empty providing empty startup_nodes + """ + with pytest.raises(RedisClusterException) as ex: + RedisCluster(startup_nodes=[]) + + assert str(ex.value).startswith( + "RedisCluster requires at least one node to discover the " "cluster" + ), str_if_bytes(ex.value) + + def test_from_url(self, r): + redis_url = f"redis://{default_host}:{default_port}/0" + with patch.object(RedisCluster, "from_url") as from_url: + + def from_url_mocked(_url, **_kwargs): + return get_mocked_redis_client(url=_url, **_kwargs) + + from_url.side_effect = from_url_mocked + cluster = RedisCluster.from_url(redis_url) + assert cluster.get_node(host=default_host, port=default_port) is not None + + def test_execute_command_errors(self, r): + """ + Test that if no key is provided then exception should be raised. + """ + with pytest.raises(RedisClusterException) as ex: + r.execute_command("GET") + assert str(ex.value).startswith( + "No way to dispatch this command to " "Redis Cluster. Missing key." + ) + + def test_execute_command_node_flag_primaries(self, r): + """ + Test command execution with nodes flag PRIMARIES + """ + primaries = r.get_primaries() + replicas = r.get_replicas() + mock_all_nodes_resp(r, "PONG") + assert r.ping(target_nodes=RedisCluster.PRIMARIES) is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_replicas(self, r): + """ + Test command execution with nodes flag REPLICAS + """ + replicas = r.get_replicas() + if not replicas: + r = get_mocked_redis_client(default_host, default_port) + primaries = r.get_primaries() + mock_all_nodes_resp(r, "PONG") + assert r.ping(target_nodes=RedisCluster.REPLICAS) is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_all_nodes(self, r): + """ + Test command execution with nodes flag ALL_NODES + """ + mock_all_nodes_resp(r, "PONG") + assert r.ping(target_nodes=RedisCluster.ALL_NODES) is True + for node in r.get_nodes(): + conn = node.redis_connection.connection + assert conn.read_response.called is True + + def test_execute_command_node_flag_random(self, r): + """ + Test command execution with nodes flag RANDOM + """ + mock_all_nodes_resp(r, "PONG") + assert r.ping(target_nodes=RedisCluster.RANDOM) is True + called_count = 0 + for node in r.get_nodes(): + conn = node.redis_connection.connection + if conn.read_response.called is True: + called_count += 1 + assert called_count == 1 + + def test_execute_command_default_node(self, r): + """ + Test command execution without node flag is being executed on the + default node + """ + def_node = r.get_default_node() + mock_node_resp(def_node, "PONG") + assert r.ping() is True + conn = def_node.redis_connection.connection + assert conn.read_response.called + + def test_ask_redirection(self, r): + """ + Test that the server handles ASK response. + + At first call it should return a ASK ResponseError that will point + the client to the next server it should talk to. + + Important thing to verify is that it tries to talk to the second node. + """ + redirect_node = r.get_nodes()[0] + with patch.object(Redis, "parse_response") as parse_response: + + def ask_redirect_effect(connection, *args, **options): + def ok_response(connection, *args, **options): + assert connection.host == redirect_node.host + assert connection.port == redirect_node.port + + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise AskError(f"12182 {redirect_node.host}:{redirect_node.port}") + + parse_response.side_effect = ask_redirect_effect + + assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" + + def test_moved_redirection(self, request): + """ + Test that the client handles MOVED response. + """ + moved_redirection_helper(request, failover=False) + + def test_moved_redirection_after_failover(self, request): + """ + Test that the client handles MOVED response after a failover. + """ + moved_redirection_helper(request, failover=True) + + def test_refresh_using_specific_nodes(self, request): + """ + Test making calls on specific nodes when the cluster has failed over to + another node + """ + node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) + node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) + with patch.object(Redis, "parse_response") as parse_response: + with patch.object(NodesManager, "initialize", autospec=True) as initialize: + with patch.multiple( + Connection, send_command=DEFAULT, connect=DEFAULT, can_read=DEFAULT + ) as mocks: + # simulate 7006 as a failed node + def parse_response_mock(connection, command_name, **options): + if connection.port == 7006: + parse_response.failed_calls += 1 + raise ClusterDownError( + "CLUSTERDOWN The cluster is " + "down. Use CLUSTER INFO for " + "more information" + ) + elif connection.port == 7007: + parse_response.successful_calls += 1 + + def initialize_mock(self): + # start with all slots mapped to 7006 + self.nodes_cache = {node_7006.name: node_7006} + self.default_node = node_7006 + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7006] + + # After the first connection fails, a reinitialize + # should follow the cluster to 7007 + def map_7007(self): + self.nodes_cache = {node_7007.name: node_7007} + self.default_node = node_7007 + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7007] + + # Change initialize side effect for the second call + initialize.side_effect = map_7007 + + parse_response.side_effect = parse_response_mock + parse_response.successful_calls = 0 + parse_response.failed_calls = 0 + initialize.side_effect = initialize_mock + mocks["can_read"].return_value = False + mocks["send_command"].return_value = "MOCK_OK" + mocks["connect"].return_value = None + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = _get_client(RedisCluster, request, flushdb=False) + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7006.name) is not None + + rc.get("foo") + + # Cluster should now point to 7007, and there should be + # one failed and one successful call + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7007.name) is not None + assert rc.get_node(node_name=node_7006.name) is None + assert parse_response.failed_calls == 1 + assert parse_response.successful_calls == 1 + + def test_reading_from_replicas_in_round_robin(self): + with patch.multiple( + Connection, + send_command=DEFAULT, + read_response=DEFAULT, + _connect=DEFAULT, + can_read=DEFAULT, + on_connect=DEFAULT, + ) as mocks: + with patch.object(Redis, "parse_response") as parse_response: + + def parse_response_mock_first(connection, *args, **options): + # Primary + assert connection.port == 7001 + parse_response.side_effect = parse_response_mock_second + return "MOCK_OK" + + def parse_response_mock_second(connection, *args, **options): + # Replica + assert connection.port == 7002 + parse_response.side_effect = parse_response_mock_third + return "MOCK_OK" + + def parse_response_mock_third(connection, *args, **options): + # Primary + assert connection.port == 7001 + return "MOCK_OK" + + # We don't need to create a real cluster connection but we + # do want RedisCluster.on_connect function to get called, + # so we'll mock some of the Connection's functions to allow it + parse_response.side_effect = parse_response_mock_first + mocks["send_command"].return_value = True + mocks["read_response"].return_value = "OK" + mocks["_connect"].return_value = True + mocks["can_read"].return_value = False + mocks["on_connect"].return_value = True + + # Create a cluster with reading from replications + read_cluster = get_mocked_redis_client( + host=default_host, port=default_port, read_from_replicas=True + ) + assert read_cluster.read_from_replicas is True + # Check that we read from the slot's nodes in a round robin + # matter. + # 'foo' belongs to slot 12182 and the slot's nodes are: + # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)] + read_cluster.get("foo") + read_cluster.get("foo") + read_cluster.get("foo") + mocks["send_command"].assert_has_calls([call("READONLY")]) + + def test_keyslot(self, r): + """ + Test that method will compute correct key in all supported cases + """ + assert r.keyslot("foo") == 12182 + assert r.keyslot("{foo}bar") == 12182 + assert r.keyslot("{foo}") == 12182 + assert r.keyslot(1337) == 4314 + + assert r.keyslot(125) == r.keyslot(b"125") + assert r.keyslot(125) == r.keyslot("\x31\x32\x35") + assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96") + assert r.keyslot(1337.1234) == r.keyslot("1337.1234") + assert r.keyslot(1337) == r.keyslot("1337") + assert r.keyslot(b"abc") == r.keyslot("abc") + + def test_get_node_name(self): + assert ( + get_node_name(default_host, default_port) + == f"{default_host}:{default_port}" + ) + + def test_all_nodes(self, r): + """ + Set a list of nodes and it should be possible to iterate over all + """ + nodes = [node for node in r.nodes_manager.nodes_cache.values()] + + for i, node in enumerate(r.get_nodes()): + assert node in nodes + + def test_all_nodes_masters(self, r): + """ + Set a list of nodes with random primaries/replicas config and it shold + be possible to iterate over all of them. + """ + nodes = [ + node + for node in r.nodes_manager.nodes_cache.values() + if node.server_type == PRIMARY + ] + + for node in r.get_primaries(): + assert node in nodes + + @pytest.mark.parametrize("error", RedisCluster.ERRORS_ALLOW_RETRY) + def test_cluster_down_overreaches_retry_attempts(self, error): + """ + When error that allows retry is thrown, test that we retry executing + the command as many times as configured in cluster_error_retry_attempts + and then raise the exception + """ + with patch.object(RedisCluster, "_execute_command") as execute_command: + + def raise_error(target_node, *args, **kwargs): + execute_command.failed_calls += 1 + raise error("mocked error") + + execute_command.side_effect = raise_error + + rc = get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(error): + rc.get("bar") + assert execute_command.failed_calls == rc.cluster_error_retry_attempts + + def test_user_on_connect_function(self, request): + """ + Test support in passing on_connect function by the user + """ + + def on_connect(connection): + assert connection is not None + + mock = Mock(side_effect=on_connect) + + _get_client(RedisCluster, request, redis_connect_func=mock) + assert mock.called is True + + def test_set_default_node_success(self, r): + """ + test successful replacement of the default cluster node + """ + default_node = r.get_default_node() + # get a different node + new_def_node = None + for node in r.get_nodes(): + if node != default_node: + new_def_node = node + break + assert r.set_default_node(new_def_node) is True + assert r.get_default_node() == new_def_node + + def test_set_default_node_failure(self, r): + """ + test failed replacement of the default cluster node + """ + default_node = r.get_default_node() + new_def_node = ClusterNode("1.1.1.1", 1111) + assert r.set_default_node(None) is False + assert r.set_default_node(new_def_node) is False + assert r.get_default_node() == default_node + + def test_get_node_from_key(self, r): + """ + Test that get_node_from_key function returns the correct node + """ + key = "bar" + slot = r.keyslot(key) + slot_nodes = r.nodes_manager.slots_cache.get(slot) + primary = slot_nodes[0] + assert r.get_node_from_key(key, replica=False) == primary + replica = r.get_node_from_key(key, replica=True) + if replica is not None: + assert replica.server_type == REPLICA + assert replica in slot_nodes + + @skip_if_redis_enterprise() + def test_not_require_full_coverage_cluster_down_error(self, r): + """ + When require_full_coverage is set to False (default client config) and not + all slots are covered, if one of the nodes has 'cluster-require_full_coverage' + config set to 'yes' some key-based commands should throw ClusterDownError + """ + node = r.get_node_from_key("foo") + missing_slot = r.keyslot("foo") + assert r.set("foo", "bar") is True + try: + assert all(r.cluster_delslots(missing_slot)) + with pytest.raises(ClusterDownError): + r.exists("foo") + finally: + try: + # Add back the missing slot + assert r.cluster_addslots(node, missing_slot) is True + # Make sure we are not getting ClusterDownError anymore + assert r.exists("foo") == 1 + except ResponseError as e: + if f"Slot {missing_slot} is already busy" in str(e): + # It can happen if the test failed to delete this slot + pass + else: + raise e + + +@pytest.mark.onlycluster +class TestClusterRedisCommands: + """ + Tests for RedisCluster unique commands + """ + + def test_case_insensitive_command_names(self, r): + assert ( + r.cluster_response_callbacks["cluster addslots"] + == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] + ) + + def test_get_and_set(self, r): + # get and set can't be tested independently of each other + assert r.get("a") is None + byte_string = b"value" + integer = 5 + unicode_string = chr(3456) + "abcd" + chr(3421) + assert r.set("byte_string", byte_string) + assert r.set("integer", 5) + assert r.set("unicode_string", unicode_string) + assert r.get("byte_string") == byte_string + assert r.get("integer") == str(integer).encode() + assert r.get("unicode_string").decode("utf-8") == unicode_string + + def test_mget_nonatomic(self, r): + assert r.mget_nonatomic([]) == [] + assert r.mget_nonatomic(["a", "b"]) == [None, None] + r["a"] = "1" + r["b"] = "2" + r["c"] = "3" + + assert r.mget_nonatomic("a", "other", "b", "c") == [b"1", None, b"2", b"3"] + + def test_mset_nonatomic(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + assert r.mset_nonatomic(d) + for k, v in d.items(): + assert r[k] == v + + def test_config_set(self, r): + assert r.config_set("slowlog-log-slower-than", 0) + + def test_cluster_config_resetstat(self, r): + r.ping(target_nodes="all") + all_info = r.info(target_nodes="all") + prior_commands_processed = -1 + for node_info in all_info.values(): + prior_commands_processed = node_info["total_commands_processed"] + assert prior_commands_processed >= 1 + r.config_resetstat(target_nodes="all") + all_info = r.info(target_nodes="all") + for node_info in all_info.values(): + reset_commands_processed = node_info["total_commands_processed"] + assert reset_commands_processed < prior_commands_processed + + def test_client_setname(self, r): + node = r.get_random_node() + r.client_setname("redis_py_test", target_nodes=node) + client_name = r.client_getname(target_nodes=node) + assert client_name == "redis_py_test" + + def test_exists(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + r.mset_nonatomic(d) + assert r.exists(*d.keys()) == len(d) + + def test_delete(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + r.mset_nonatomic(d) + assert r.delete(*d.keys()) == len(d) + assert r.delete(*d.keys()) == 0 + + def test_touch(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + r.mset_nonatomic(d) + assert r.touch(*d.keys()) == len(d) + + def test_unlink(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + r.mset_nonatomic(d) + assert r.unlink(*d.keys()) == len(d) + # Unlink is non-blocking so we sleep before + # verifying the deletion + sleep(0.1) + assert r.unlink(*d.keys()) == 0 + + def test_pubsub_channels_merge_results(self, r): + nodes = r.get_nodes() + channels = [] + pubsub_nodes = [] + i = 0 + for node in nodes: + channel = f"foo{i}" + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + pubsub_nodes.append(p) + p.subscribe(channel) + b_channel = channel.encode("utf-8") + channels.append(b_channel) + # Assert that each node returns only the channel it subscribed to + sub_channels = node.redis_connection.pubsub_channels() + if not sub_channels: + # Try again after a short sleep + sleep(0.3) + sub_channels = node.redis_connection.pubsub_channels() + assert sub_channels == [b_channel] + i += 1 + # Assert that the cluster's pubsub_channels function returns ALL of + # the cluster's channels + result = r.pubsub_channels(target_nodes="all") + result.sort() + assert result == channels + + def test_pubsub_numsub_merge_results(self, r): + nodes = r.get_nodes() + pubsub_nodes = [] + channel = "foo" + b_channel = channel.encode("utf-8") + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + pubsub_nodes.append(p) + p.subscribe(channel) + # Assert that each node returns that only one client is subscribed + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + if sub_chann_num == [(b_channel, 0)]: + sleep(0.3) + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + assert sub_chann_num == [(b_channel, 1)] + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numsub(channel, target_nodes="all") == [(b_channel, len(nodes))] + + def test_pubsub_numpat_merge_results(self, r): + nodes = r.get_nodes() + pubsub_nodes = [] + pattern = "foo*" + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + pubsub_nodes.append(p) + p.psubscribe(pattern) + # Assert that each node returns that only one client is subscribed + sub_num_pat = node.redis_connection.pubsub_numpat() + if sub_num_pat == 0: + sleep(0.3) + sub_num_pat = node.redis_connection.pubsub_numpat() + assert sub_num_pat == 1 + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numpat(target_nodes="all") == len(nodes) + + @skip_if_server_version_lt("2.8.0") + def test_cluster_pubsub_channels(self, r): + p = r.pubsub() + p.subscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert wait_for_message(p, timeout=0.5)["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all( + [channel in r.pubsub_channels(target_nodes="all") for channel in expected] + ) + + @skip_if_server_version_lt("2.8.0") + def test_cluster_pubsub_numsub(self, r): + p1 = r.pubsub() + p1.subscribe("foo", "bar", "baz") + for i in range(3): + assert wait_for_message(p1, timeout=0.5)["type"] == "subscribe" + p2 = r.pubsub() + p2.subscribe("bar", "baz") + for i in range(2): + assert wait_for_message(p2, timeout=0.5)["type"] == "subscribe" + p3 = r.pubsub() + p3.subscribe("baz") + assert wait_for_message(p3, timeout=0.5)["type"] == "subscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_numsub("foo", "bar", "baz", target_nodes="all") == channels + + @skip_if_redis_enterprise() + def test_cluster_myid(self, r): + node = r.get_random_node() + myid = r.cluster_myid(node) + assert len(myid) == 40 + + @skip_if_redis_enterprise() + def test_cluster_slots(self, r): + mock_all_nodes_resp(r, default_cluster_slots) + cluster_slots = r.cluster_slots() + assert isinstance(cluster_slots, dict) + assert len(default_cluster_slots) == len(cluster_slots) + assert cluster_slots.get((0, 8191)) is not None + assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) + + @skip_if_redis_enterprise() + def test_cluster_addslots(self, r): + node = r.get_random_node() + mock_node_resp(node, "OK") + assert r.cluster_addslots(node, 1, 2, 3) is True + + @skip_if_server_version_lt("7.0.0") + @skip_if_redis_enterprise() + def test_cluster_addslotsrange(self, r): + node = r.get_random_node() + mock_node_resp(node, "OK") + assert r.cluster_addslotsrange(node, 1, 5) + + @skip_if_redis_enterprise() + def test_cluster_countkeysinslot(self, r): + node = r.nodes_manager.get_node_from_slot(1) + mock_node_resp(node, 2) + assert r.cluster_countkeysinslot(1) == 2 + + def test_cluster_count_failure_report(self, r): + mock_all_nodes_resp(r, 0) + assert r.cluster_count_failure_report("node_0") == 0 + + @skip_if_redis_enterprise() + def test_cluster_delslots(self): + cluster_slots = [ + [0, 8191, ["127.0.0.1", 7000, "node_0"]], + [8192, 16383, ["127.0.0.1", 7001, "node_1"]], + ] + r = get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=cluster_slots + ) + mock_all_nodes_resp(r, "OK") + node0 = r.get_node(default_host, 7000) + node1 = r.get_node(default_host, 7001) + assert r.cluster_delslots(0, 8192) == [True, True] + assert node0.redis_connection.connection.read_response.called + assert node1.redis_connection.connection.read_response.called + + @skip_if_server_version_lt("7.0.0") + @skip_if_redis_enterprise() + def test_cluster_delslotsrange(self, r): + node = r.get_random_node() + mock_node_resp(node, "OK") + r.cluster_addslots(node, 1, 2, 3, 4, 5) + assert r.cluster_delslotsrange(1, 5) + + @skip_if_redis_enterprise() + def test_cluster_failover(self, r): + node = r.get_random_node() + mock_node_resp(node, "OK") + assert r.cluster_failover(node) is True + assert r.cluster_failover(node, "FORCE") is True + assert r.cluster_failover(node, "TAKEOVER") is True + with pytest.raises(RedisError): + r.cluster_failover(node, "FORCT") + + @skip_if_redis_enterprise() + def test_cluster_info(self, r): + info = r.cluster_info() + assert isinstance(info, dict) + assert info["cluster_state"] == "ok" + + @skip_if_redis_enterprise() + def test_cluster_keyslot(self, r): + mock_all_nodes_resp(r, 12182) + assert r.cluster_keyslot("foo") == 12182 + + @skip_if_redis_enterprise() + def test_cluster_meet(self, r): + node = r.get_default_node() + mock_node_resp(node, "OK") + assert r.cluster_meet("127.0.0.1", 6379) is True + + @skip_if_redis_enterprise() + def test_cluster_nodes(self, r): + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + mock_all_nodes_resp(r, response) + nodes = r.cluster_nodes() + assert len(nodes) == 7 + assert nodes.get("172.17.0.7:7006") is not None + assert ( + nodes.get("172.17.0.7:7006").get("node_id") + == "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + ) + + @skip_if_redis_enterprise() + def test_cluster_nodes_importing_migrating(self, r): + response = ( + "488ead2fcce24d8c0f158f9172cb1f4a9e040fe5 127.0.0.1:16381@26381 " + "master - 0 1648975557664 3 connected 10923-16383\n" + "8ae2e70812db80776f739a72374e57fc4ae6f89d 127.0.0.1:16380@26380 " + "master - 0 1648975555000 2 connected 1 5461-10922 [" + "2-<-ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16]\n" + "ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16 127.0.0.1:16379@26379 " + "myself,master - 0 1648975556000 1 connected 0 2-5460 [" + "2->-8ae2e70812db80776f739a72374e57fc4ae6f89d]\n" + ) + mock_all_nodes_resp(r, response) + nodes = r.cluster_nodes() + assert len(nodes) == 3 + node_16379 = nodes.get("127.0.0.1:16379") + node_16380 = nodes.get("127.0.0.1:16380") + node_16381 = nodes.get("127.0.0.1:16381") + assert node_16379.get("migrations") == [ + { + "slot": "2", + "node_id": "8ae2e70812db80776f739a72374e57fc4ae6f89d", + "state": "migrating", + } + ] + assert node_16379.get("slots") == [["0"], ["2", "5460"]] + assert node_16380.get("migrations") == [ + { + "slot": "2", + "node_id": "ed8007ccfa2d91a7b76f8e6fba7ba7e257034a16", + "state": "importing", + } + ] + assert node_16380.get("slots") == [["1"], ["5461", "10922"]] + assert node_16381.get("slots") == [["10923", "16383"]] + assert node_16381.get("migrations") == [] + + @skip_if_redis_enterprise() + def test_cluster_replicate(self, r): + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, "OK") + assert r.cluster_replicate(node, "c8253bae761cb61857d") is True + results = r.cluster_replicate(all_replicas, "c8253bae761cb61857d") + if isinstance(results, dict): + for res in results.values(): + assert res is True + else: + assert results is True + + @skip_if_redis_enterprise() + def test_cluster_reset(self, r): + mock_all_nodes_resp(r, "OK") + assert r.cluster_reset() is True + assert r.cluster_reset(False) is True + all_results = r.cluster_reset(False, target_nodes="all") + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + def test_cluster_save_config(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, "OK") + assert r.cluster_save_config(node) is True + all_results = r.cluster_save_config(all_nodes) + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + def test_cluster_get_keys_in_slot(self, r): + response = [b"{foo}1", b"{foo}2"] + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, response) + keys = r.cluster_get_keys_in_slot(12182, 4) + assert keys == response + + @skip_if_redis_enterprise() + def test_cluster_set_config_epoch(self, r): + mock_all_nodes_resp(r, "OK") + assert r.cluster_set_config_epoch(3) is True + all_results = r.cluster_set_config_epoch(3, target_nodes="all") + for res in all_results.values(): + assert res is True + + @skip_if_redis_enterprise() + def test_cluster_setslot(self, r): + node = r.get_random_node() + mock_node_resp(node, "OK") + assert r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True + assert r.cluster_setslot(node, "node_0", 1218, "NODE") is True + assert r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True + with pytest.raises(RedisError): + r.cluster_failover(node, "STABLE") + with pytest.raises(RedisError): + r.cluster_failover(node, "STATE") + + def test_cluster_setslot_stable(self, r): + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, "OK") + assert r.cluster_setslot_stable(12182) is True + assert node.redis_connection.connection.read_response.called + + @skip_if_redis_enterprise() + def test_cluster_replicas(self, r): + response = [ + b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " + b"127.0.0.1:6377@16377 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + b"r4xfga22229cf3c652b6fca0d09ff69f3e0d4d " + b"127.0.0.1:6378@16378 slave " + b"52611e796814b78e90ad94be9d769a4f668f9a 0 " + b"1634550063436 4 connected", + ] + mock_all_nodes_resp(r, response) + replicas = r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") + assert replicas.get("127.0.0.1:6377") is not None + assert replicas.get("127.0.0.1:6378") is not None + assert ( + replicas.get("127.0.0.1:6378").get("node_id") + == "r4xfga22229cf3c652b6fca0d09ff69f3e0d4d" + ) + + @skip_if_server_version_lt("7.0.0") + def test_cluster_links(self, r): + node = r.get_random_node() + res = r.cluster_links(node) + links_to = sum(x.count("to") for x in res) + links_for = sum(x.count("from") for x in res) + assert links_to == links_for + print(res) + for i in range(0, len(res) - 1, 2): + assert res[i][3] == res[i + 1][3] + + @skip_if_redis_enterprise() + def test_readonly(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + mock_all_nodes_resp(r, "OK") + assert r.readonly() is True + all_replicas_results = r.readonly(target_nodes="replicas") + for res in all_replicas_results.values(): + assert res is True + for replica in r.get_replicas(): + assert replica.redis_connection.connection.read_response.called + + @skip_if_redis_enterprise() + def test_readwrite(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + mock_all_nodes_resp(r, "OK") + assert r.readwrite() is True + all_replicas_results = r.readwrite(target_nodes="replicas") + for res in all_replicas_results.values(): + assert res is True + for replica in r.get_replicas(): + assert replica.redis_connection.connection.read_response.called + + @skip_if_redis_enterprise() + def test_bgsave(self, r): + assert r.bgsave() + sleep(0.3) + assert r.bgsave(True) + + def test_info(self, r): + # Map keys to same slot + r.set("x{1}", 1) + r.set("y{1}", 2) + r.set("z{1}", 3) + # Get node that handles the slot + slot = r.keyslot("x{1}") + node = r.nodes_manager.get_node_from_slot(slot) + # Run info on that node + info = r.info(target_nodes=node) + assert isinstance(info, dict) + assert info["db0"]["keys"] == 3 + + def _init_slowlog_test(self, r, node): + slowlog_lim = r.config_get("slowlog-log-slower-than", target_nodes=node) + assert r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True + return slowlog_lim["slowlog-log-slower-than"] + + def _teardown_slowlog_test(self, r, node, prev_limit): + assert ( + r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) + is True + ) + + def test_slowlog_get(self, r, slowlog): + unicode_string = chr(3456) + "abcd" + chr(3421) + node = r.get_node_from_key(unicode_string) + slowlog_limit = self._init_slowlog_test(r, node) + assert r.slowlog_reset(target_nodes=node) + r.get(unicode_string) + slowlog = r.slowlog_get(target_nodes=node) + assert isinstance(slowlog, list) + commands = [log["command"] for log in slowlog] + + get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) + assert get_command in commands + assert b"SLOWLOG RESET" in commands + + # the order should be ['GET ', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands, before, between, or after, so just check that + # the two we care about are in the appropriate order. + assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]["start_time"], int) + assert isinstance(slowlog[0]["duration"], int) + # rollback the slowlog limit to its original value + self._teardown_slowlog_test(r, node, slowlog_limit) + + def test_slowlog_get_limit(self, r, slowlog): + assert r.slowlog_reset() + node = r.get_node_from_key("foo") + slowlog_limit = self._init_slowlog_test(r, node) + r.get("foo") + slowlog = r.slowlog_get(1, target_nodes=node) + assert isinstance(slowlog, list) + # only one command, based on the number we passed to slowlog_get() + assert len(slowlog) == 1 + self._teardown_slowlog_test(r, node, slowlog_limit) + + def test_slowlog_length(self, r, slowlog): + r.get("foo") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + slowlog_len = r.slowlog_len(target_nodes=node) + assert isinstance(slowlog_len, int) + + def test_time(self, r): + t = r.time(target_nodes=r.get_primaries()[0]) + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + @skip_if_server_version_lt("4.0.0") + def test_memory_usage(self, r): + r.set("foo", "bar") + assert isinstance(r.memory_usage("foo"), int) + + @skip_if_server_version_lt("4.0.0") + @skip_if_redis_enterprise() + def test_memory_malloc_stats(self, r): + assert r.memory_malloc_stats() + + @skip_if_server_version_lt("4.0.0") + @skip_if_redis_enterprise() + def test_memory_stats(self, r): + # put a key into the current db to make sure that "db." + # has data + r.set("foo", "bar") + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + stats = r.memory_stats(target_nodes=node) + assert isinstance(stats, dict) + for key, value in stats.items(): + if key.startswith("db."): + assert isinstance(value, dict) + + @skip_if_server_version_lt("4.0.0") + def test_memory_help(self, r): + with pytest.raises(NotImplementedError): + r.memory_help() + + @skip_if_server_version_lt("4.0.0") + def test_memory_doctor(self, r): + with pytest.raises(NotImplementedError): + r.memory_doctor() + + @skip_if_redis_enterprise() + def test_lastsave(self, r): + node = r.get_primaries()[0] + assert isinstance(r.lastsave(target_nodes=node), datetime.datetime) + + def test_cluster_echo(self, r): + node = r.get_primaries()[0] + assert r.echo("foo bar", target_nodes=node) == b"foo bar" + + @skip_if_server_version_lt("1.0.0") + def test_debug_segfault(self, r): + with pytest.raises(NotImplementedError): + r.debug_segfault() + + def test_config_resetstat(self, r): + node = r.get_primaries()[0] + r.ping(target_nodes=node) + prior_commands_processed = int( + r.info(target_nodes=node)["total_commands_processed"] + ) + assert prior_commands_processed >= 1 + r.config_resetstat(target_nodes=node) + reset_commands_processed = int( + r.info(target_nodes=node)["total_commands_processed"] + ) + assert reset_commands_processed < prior_commands_processed + + @skip_if_server_version_lt("6.2.0") + def test_client_trackinginfo(self, r): + node = r.get_primaries()[0] + res = r.client_trackinginfo(target_nodes=node) + assert len(res) > 2 + assert "prefixes" in res + + @skip_if_server_version_lt("2.9.50") + def test_client_pause(self, r): + node = r.get_primaries()[0] + assert r.client_pause(1, target_nodes=node) + assert r.client_pause(timeout=1, target_nodes=node) + with pytest.raises(RedisError): + r.client_pause(timeout="not an integer", target_nodes=node) + + @skip_if_server_version_lt("6.2.0") + @skip_if_redis_enterprise() + def test_client_unpause(self, r): + assert r.client_unpause() + + @skip_if_server_version_lt("5.0.0") + def test_client_id(self, r): + node = r.get_primaries()[0] + assert r.client_id(target_nodes=node) > 0 + + @skip_if_server_version_lt("5.0.0") + def test_client_unblock(self, r): + node = r.get_primaries()[0] + myid = r.client_id(target_nodes=node) + assert not r.client_unblock(myid, target_nodes=node) + assert not r.client_unblock(myid, error=True, target_nodes=node) + assert not r.client_unblock(myid, error=False, target_nodes=node) + + @skip_if_server_version_lt("6.0.0") + def test_client_getredir(self, r): + node = r.get_primaries()[0] + assert isinstance(r.client_getredir(target_nodes=node), int) + assert r.client_getredir(target_nodes=node) == -1 + + @skip_if_server_version_lt("6.2.0") + def test_client_info(self, r): + node = r.get_primaries()[0] + info = r.client_info(target_nodes=node) + assert isinstance(info, dict) + assert "addr" in info + + @skip_if_server_version_lt("2.6.9") + def test_client_kill(self, r, r2): + node = r.get_primaries()[0] + r.client_setname("redis-py-c1", target_nodes="all") + r2.client_setname("redis-py-c2", target_nodes="all") + clients = [ + client + for client in r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 2 + clients_by_name = {client.get("name"): client for client in clients} + + client_addr = clients_by_name["redis-py-c2"].get("addr") + assert r.client_kill(client_addr, target_nodes=node) is True + + clients = [ + client + for client in r.client_list(target_nodes=node) + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 1 + assert clients[0].get("name") == "redis-py-c1" + + @skip_if_server_version_lt("2.6.0") + def test_cluster_bitop_not_empty_string(self, r): + r["{foo}a"] = "" + r.bitop("not", "{foo}r", "{foo}a") + assert r.get("{foo}r") is None + + @skip_if_server_version_lt("2.6.0") + def test_cluster_bitop_not(self, r): + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + r["{foo}a"] = test_str + r.bitop("not", "{foo}r", "{foo}a") + assert int(binascii.hexlify(r["{foo}r"]), 16) == correct + + @skip_if_server_version_lt("2.6.0") + def test_cluster_bitop_not_in_place(self, r): + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + r["{foo}a"] = test_str + r.bitop("not", "{foo}a", "{foo}a") + assert int(binascii.hexlify(r["{foo}a"]), 16) == correct + + @skip_if_server_version_lt("2.6.0") + def test_cluster_bitop_single_string(self, r): + test_str = b"\x01\x02\xFF" + r["{foo}a"] = test_str + r.bitop("and", "{foo}res1", "{foo}a") + r.bitop("or", "{foo}res2", "{foo}a") + r.bitop("xor", "{foo}res3", "{foo}a") + assert r["{foo}res1"] == test_str + assert r["{foo}res2"] == test_str + assert r["{foo}res3"] == test_str + + @skip_if_server_version_lt("2.6.0") + def test_cluster_bitop_string_operands(self, r): + r["{foo}a"] = b"\x01\x02\xFF\xFF" + r["{foo}b"] = b"\x01\x02\xFF" + r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") + r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") + r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") + assert int(binascii.hexlify(r["{foo}res1"]), 16) == 0x0102FF00 + assert int(binascii.hexlify(r["{foo}res2"]), 16) == 0x0102FFFF + assert int(binascii.hexlify(r["{foo}res3"]), 16) == 0x000000FF + + @skip_if_server_version_lt("6.2.0") + def test_cluster_copy(self, r): + assert r.copy("{foo}a", "{foo}b") == 0 + r.set("{foo}a", "bar") + assert r.copy("{foo}a", "{foo}b") == 1 + assert r.get("{foo}a") == b"bar" + assert r.get("{foo}b") == b"bar" + + @skip_if_server_version_lt("6.2.0") + def test_cluster_copy_and_replace(self, r): + r.set("{foo}a", "foo1") + r.set("{foo}b", "foo2") + assert r.copy("{foo}a", "{foo}b") == 0 + assert r.copy("{foo}a", "{foo}b", replace=True) == 1 + + @skip_if_server_version_lt("6.2.0") + def test_cluster_lmove(self, r): + r.rpush("{foo}a", "one", "two", "three", "four") + assert r.lmove("{foo}a", "{foo}b") + assert r.lmove("{foo}a", "{foo}b", "right", "left") + + @skip_if_server_version_lt("6.2.0") + def test_cluster_blmove(self, r): + r.rpush("{foo}a", "one", "two", "three", "four") + assert r.blmove("{foo}a", "{foo}b", 5) + assert r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") + + def test_cluster_msetnx(self, r): + d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} + assert r.msetnx(d) + d2 = {"{foo}a": b"x", "{foo}d": b"4"} + assert not r.msetnx(d2) + for k, v in d.items(): + assert r[k] == v + assert r.get("{foo}d") is None + + def test_cluster_rename(self, r): + r["{foo}a"] = "1" + assert r.rename("{foo}a", "{foo}b") + assert r.get("{foo}a") is None + assert r["{foo}b"] == b"1" + + def test_cluster_renamenx(self, r): + r["{foo}a"] = "1" + r["{foo}b"] = "2" + assert not r.renamenx("{foo}a", "{foo}b") + assert r["{foo}a"] == b"1" + assert r["{foo}b"] == b"2" + + # LIST COMMANDS + def test_cluster_blpop(self, r): + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None + r.rpush("{foo}c", "1") + assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + def test_cluster_brpop(self, r): + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None + r.rpush("{foo}c", "1") + assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + def test_cluster_brpoplpush(self, r): + r.rpush("{foo}a", "1", "2") + r.rpush("{foo}b", "3", "4") + assert r.brpoplpush("{foo}a", "{foo}b") == b"2" + assert r.brpoplpush("{foo}a", "{foo}b") == b"1" + assert r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None + assert r.lrange("{foo}a", 0, -1) == [] + assert r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] + + def test_cluster_brpoplpush_empty_string(self, r): + r.rpush("{foo}a", "") + assert r.brpoplpush("{foo}a", "{foo}b") == b"" + + def test_cluster_rpoplpush(self, r): + r.rpush("{foo}a", "a1", "a2", "a3") + r.rpush("{foo}b", "b1", "b2", "b3") + assert r.rpoplpush("{foo}a", "{foo}b") == b"a3" + assert r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] + assert r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] + + def test_cluster_sdiff(self, r): + r.sadd("{foo}a", "1", "2", "3") + assert r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + r.sadd("{foo}b", "2", "3") + assert r.sdiff("{foo}a", "{foo}b") == {b"1"} + + def test_cluster_sdiffstore(self, r): + r.sadd("{foo}a", "1", "2", "3") + assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert r.smembers("{foo}c") == {b"1", b"2", b"3"} + r.sadd("{foo}b", "2", "3") + assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 + assert r.smembers("{foo}c") == {b"1"} + + def test_cluster_sinter(self, r): + r.sadd("{foo}a", "1", "2", "3") + assert r.sinter("{foo}a", "{foo}b") == set() + r.sadd("{foo}b", "2", "3") + assert r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} + + def test_cluster_sinterstore(self, r): + r.sadd("{foo}a", "1", "2", "3") + assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 + assert r.smembers("{foo}c") == set() + r.sadd("{foo}b", "2", "3") + assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 + assert r.smembers("{foo}c") == {b"2", b"3"} + + def test_cluster_smove(self, r): + r.sadd("{foo}a", "a1", "a2") + r.sadd("{foo}b", "b1", "b2") + assert r.smove("{foo}a", "{foo}b", "a1") + assert r.smembers("{foo}a") == {b"a2"} + assert r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} + + def test_cluster_sunion(self, r): + r.sadd("{foo}a", "1", "2") + r.sadd("{foo}b", "2", "3") + assert r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + + def test_cluster_sunionstore(self, r): + r.sadd("{foo}a", "1", "2") + r.sadd("{foo}b", "2", "3") + assert r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert r.smembers("{foo}c") == {b"1", b"2", b"3"} + + @skip_if_server_version_lt("6.2.0") + def test_cluster_zdiff(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + + @skip_if_server_version_lt("6.2.0") + def test_cluster_zdiffstore(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) + assert r.zrange("{foo}out", 0, -1) == [b"a3"] + assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + + @skip_if_server_version_lt("6.2.0") + def test_cluster_zinter(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] + # invalid aggregation + with pytest.raises(DataError): + r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) + # aggregate with SUM + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a3", 5), (b"a1", 6)] + # aggregate with MIN + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a3", 1)] + # with weights + assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] + + def test_cluster_zinterstore_sum(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + + def test_cluster_zinterstore_max(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + + def test_cluster_zinterstore_min(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + == 2 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + + def test_cluster_zinterstore_with_weight(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + + @skip_if_server_version_lt("4.9.0") + def test_cluster_bzpopmax(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2}) + r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None + r.zadd("{foo}c", {"c1": 100}) + assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("4.9.0") + def test_cluster_bzpopmin(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2}) + r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None + r.zadd("{foo}c", {"c1": 100}) + assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + + @skip_if_server_version_lt("6.2.0") + def test_cluster_zrangestore(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + assert r.zrangestore("{foo}b", "{foo}a", 0, 1) + assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + assert r.zrangestore("{foo}b", "{foo}a", 1, 2) + assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] + assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + # reversed order + assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) + assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + # by score + assert r.zrangestore( + "{foo}b", "{foo}a", 2, 1, byscore=True, offset=0, num=1, desc=True + ) + assert r.zrange("{foo}b", 0, -1) == [b"a2"] + # by lex + assert r.zrangestore( + "{foo}b", "{foo}a", "[a2", "(a3", bylex=True, offset=0, num=1 + ) + assert r.zrange("{foo}b", 0, -1) == [b"a2"] + + @skip_if_server_version_lt("6.2.0") + def test_cluster_zunion(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + # sum + assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] + assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + # max + assert r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + # min + assert r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + # with weight + assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + def test_cluster_zunionstore_sum(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + + def test_cluster_zunionstore_max(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + + def test_cluster_zunionstore_min(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + == 4 + ) + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + + def test_cluster_zunionstore_with_weight(self, r): + r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 + assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + @skip_if_server_version_lt("2.8.9") + def test_cluster_pfcount(self, r): + members = {b"1", b"2", b"3"} + r.pfadd("{foo}a", *members) + assert r.pfcount("{foo}a") == len(members) + members_b = {b"2", b"3", b"4"} + r.pfadd("{foo}b", *members_b) + assert r.pfcount("{foo}b") == len(members_b) + assert r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) + + @skip_if_server_version_lt("2.8.9") + def test_cluster_pfmerge(self, r): + mema = {b"1", b"2", b"3"} + memb = {b"2", b"3", b"4"} + memc = {b"5", b"6", b"7"} + r.pfadd("{foo}a", *mema) + r.pfadd("{foo}b", *memb) + r.pfadd("{foo}c", *memc) + r.pfmerge("{foo}d", "{foo}c", "{foo}a") + assert r.pfcount("{foo}d") == 6 + r.pfmerge("{foo}d", "{foo}b") + assert r.pfcount("{foo}d") == 7 + + def test_cluster_sort_store(self, r): + r.rpush("{foo}a", "2", "3", "1") + assert r.sort("{foo}a", store="{foo}sorted_values") == 3 + assert r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] + + # GEO COMMANDS + @skip_if_server_version_lt("6.2.0") + def test_cluster_geosearchstore(self, r): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + r.geoadd("{foo}barcelona", values) + r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + ) + assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("6.2.0") + def test_geosearchstore_dist(self, r): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + r.geoadd("{foo}barcelona", values) + r.geosearchstore( + "{foo}places_barcelona", + "{foo}barcelona", + longitude=2.191, + latitude=41.433, + radius=1000, + storedist=True, + ) + # instead of save the geo score, the distance is saved. + assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + + @skip_if_server_version_lt("3.2.0") + def test_cluster_georadius_store(self, r): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + r.geoadd("{foo}barcelona", values) + r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store="{foo}places_barcelona" + ) + assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + def test_cluster_georadius_store_dist(self, r): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + r.geoadd("{foo}barcelona", values) + r.georadius( + "{foo}barcelona", 2.191, 41.433, 1000, store_dist="{foo}places_barcelona" + ) + # instead of save the geo score, the distance is saved. + assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + + def test_cluster_dbsize(self, r): + d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} + assert r.mset_nonatomic(d) + assert r.dbsize(target_nodes="primaries") == len(d) + + def test_cluster_keys(self, r): + assert r.keys() == [] + keys_with_underscores = {b"test_a", b"test_b"} + keys = keys_with_underscores.union({b"testc"}) + for key in keys: + r[key] = 1 + assert ( + set(r.keys(pattern="test_*", target_nodes="primaries")) + == keys_with_underscores + ) + assert set(r.keys(pattern="test*", target_nodes="primaries")) == keys + + # SCAN COMMANDS + @skip_if_server_version_lt("2.8.0") + def test_cluster_scan(self, r): + r.set("a", 1) + r.set("b", 2) + r.set("c", 3) + + for target_nodes, nodes in zip( + ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] + ): + cursors, keys = r.scan(target_nodes=target_nodes) + assert sorted(keys) == [b"a", b"b", b"c"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + cursors, keys = r.scan(match="a*", target_nodes=target_nodes) + assert sorted(keys) == [b"a"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + @skip_if_server_version_lt("6.0.0") + def test_cluster_scan_type(self, r): + r.sadd("a-set", 1) + r.sadd("b-set", 1) + r.sadd("c-set", 1) + r.hset("a-hash", "foo", 2) + r.lpush("a-list", "aux", 3) + + for target_nodes, nodes in zip( + ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] + ): + cursors, keys = r.scan(_type="SET", target_nodes=target_nodes) + assert sorted(keys) == [b"a-set", b"b-set", b"c-set"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + cursors, keys = r.scan(_type="SET", match="a*", target_nodes=target_nodes) + assert sorted(keys) == [b"a-set"] + assert sorted(cursors.keys()) == sorted(node.name for node in nodes) + assert all(cursor == 0 for cursor in cursors.values()) + + @skip_if_server_version_lt("2.8.0") + def test_cluster_scan_iter(self, r): + keys_all = [] + keys_1 = [] + for i in range(100): + s = str(i) + r.set(s, 1) + keys_all.append(s.encode("utf-8")) + if s.startswith("1"): + keys_1.append(s.encode("utf-8")) + keys_all.sort() + keys_1.sort() + + for target_nodes in ["primaries", "replicas"]: + keys = r.scan_iter(target_nodes=target_nodes) + assert sorted(keys) == keys_all + + keys = r.scan_iter(match="1*", target_nodes=target_nodes) + assert sorted(keys) == keys_1 + + def test_cluster_randomkey(self, r): + node = r.get_node_from_key("{foo}") + assert r.randomkey(target_nodes=node) is None + for key in ("{foo}a", "{foo}b", "{foo}c"): + r[key] = 1 + assert r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") + + @skip_if_server_version_lt("6.0.0") + @skip_if_redis_enterprise() + def test_acl_log(self, r, request): + key = "{cache}:" + node = r.get_node_from_key(key) + username = "redis-py-user" + + def teardown(): + r.acl_deluser(username, target_nodes="primaries") + + request.addfinalizer(teardown) + r.acl_setuser( + username, + enabled=True, + reset=True, + commands=["+get", "+set", "+select", "+cluster", "+command", "+info"], + keys=["{cache}:*"], + nopass=True, + target_nodes="primaries", + ) + r.acl_log_reset(target_nodes=node) + + user_client = _get_client( + RedisCluster, request, flushdb=False, username=username + ) + + # Valid operation and key + assert user_client.set("{cache}:0", 1) + assert user_client.get("{cache}:0") == b"1" + + # Invalid key + with pytest.raises(NoPermissionError): + user_client.get("{cache}violated_cache:0") + + # Invalid operation + with pytest.raises(NoPermissionError): + user_client.hset("{cache}:0", "hkey", "hval") + + assert isinstance(r.acl_log(target_nodes=node), list) + assert len(r.acl_log(target_nodes=node)) == 2 + assert len(r.acl_log(count=1, target_nodes=node)) == 1 + assert isinstance(r.acl_log(target_nodes=node)[0], dict) + assert "client-info" in r.acl_log(count=1, target_nodes=node)[0] + assert r.acl_log_reset(target_nodes=node) + + +@pytest.mark.onlycluster +class TestNodesManager: + """ + Tests for the NodesManager class + """ + + def test_load_balancer(self, r): + n_manager = r.nodes_manager + lb = n_manager.read_load_balancer + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5], + } + primary1_name = n_manager.slots_cache[slot_1][0].name + primary2_name = n_manager.slots_cache[slot_2][0].name + list1_size = len(n_manager.slots_cache[slot_1]) + list2_size = len(n_manager.slots_cache[slot_2]) + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 + assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 1 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + lb.reset() + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + def test_init_slots_cache_not_all_slots_covered(self): + """ + Test that if not all slots are covered it should raise an exception + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], + ] + with pytest.raises(RedisClusterException) as ex: + get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=True, + ) + assert str(ex.value).startswith( + "All slots are not covered after query all startup_nodes." + ) + + def test_init_slots_cache_not_require_full_coverage_success(self): + """ + When require_full_coverage is set to False and not all slots are + covered the cluster client initialization should succeed + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], + ] + + rc = get_mocked_redis_client( + host=default_host, + port=default_port, + cluster_slots=cluster_slots, + require_full_coverage=False, + ) + + assert 5460 not in rc.nodes_manager.slots_cache + + def test_init_slots_cache(self): + """ + Test that slots cache can in initialized and all slots are covered + """ + good_slots_resp = [ + [0, 5460, ["127.0.0.1", 7000], ["127.0.0.2", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.2", 7004]], + [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.2", 7005]], + ] + + rc = get_mocked_redis_client( + host=default_host, port=default_port, cluster_slots=good_slots_resp + ) + n_manager = rc.nodes_manager + assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for slot_info in good_slots_resp: + all_hosts = ["127.0.0.1", "127.0.0.2"] + all_ports = [7000, 7001, 7002, 7003, 7004, 7005] + slot_start = slot_info[0] + slot_end = slot_info[1] + for i in range(slot_start, slot_end + 1): + assert len(n_manager.slots_cache[i]) == len(slot_info[2:]) + assert n_manager.slots_cache[i][0].host in all_hosts + assert n_manager.slots_cache[i][1].host in all_hosts + assert n_manager.slots_cache[i][0].port in all_ports + assert n_manager.slots_cache[i][1].port in all_ports + + assert len(n_manager.nodes_cache) == 6 + + def test_init_slots_cache_cluster_mode_disabled(self): + """ + Test that creating a RedisCluster failes if one of the startup nodes + has cluster mode disabled + """ + with pytest.raises(RedisClusterException) as e: + get_mocked_redis_client( + host=default_host, port=default_port, cluster_enabled=False + ) + assert "Cluster mode is not enabled on this node" in str(e.value) + + def test_empty_startup_nodes(self): + """ + It should not be possible to create a node manager with no nodes + specified + """ + with pytest.raises(RedisClusterException): + NodesManager([]) + + def test_wrong_startup_nodes_type(self): + """ + If something other then a list type itteratable is provided it should + fail + """ + with pytest.raises(RedisClusterException): + NodesManager({}) + + def test_init_slots_cache_slots_collision(self, request): + """ + Test that if 2 nodes do not agree on the same slots setup it should + raise an error. In this test both nodes will say that the first + slots block should be bound to different servers. + """ + with patch.object(NodesManager, "create_redis_node") as create_redis_node: + + def create_mocked_redis_node(host, port, **kwargs): + """ + Helper function to return custom slots cache data from + different redis nodes + """ + if port == 7000: + result = [ + [0, 5460, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], + ] + + elif port == 7001: + result = [ + [0, 5460, ["127.0.0.1", 7001], ["127.0.0.1", 7003]], + [5461, 10922, ["127.0.0.1", 7000], ["127.0.0.1", 7004]], + ] + else: + result = [] + + r_node = Redis(host=host, port=port) + + orig_execute_command = r_node.execute_command + + def execute_command(*args, **kwargs): + if args[0] == "CLUSTER SLOTS": + return result + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} + else: + return orig_execute_command(*args, **kwargs) + + r_node.execute_command = execute_command + return r_node + + create_redis_node.side_effect = create_mocked_redis_node + + with pytest.raises(RedisClusterException) as ex: + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) + RedisCluster(startup_nodes=[node_1, node_2]) + assert str(ex.value).startswith( + "startup_nodes could not agree on a valid slots cache" + ), str(ex.value) + + def test_cluster_one_instance(self): + """ + If the cluster exists of only 1 node then there is some hacks that must + be validated they work. + """ + node = ClusterNode(default_host, default_port) + cluster_slots = [[0, 16383, ["", default_port]]] + rc = get_mocked_redis_client(startup_nodes=[node], cluster_slots=cluster_slots) + + n = rc.nodes_manager + assert len(n.nodes_cache) == 1 + n_node = rc.get_node(node_name=node.name) + assert n_node is not None + assert n_node == node + assert n_node.server_type == PRIMARY + assert len(n.slots_cache) == REDIS_CLUSTER_HASH_SLOTS + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + assert n.slots_cache[i] == [n_node] + + def test_init_with_down_node(self): + """ + If I can't connect to one of the nodes, everything should still work. + But if I can't connect to any of the nodes, exception should be thrown. + """ + with patch.object(NodesManager, "create_redis_node") as create_redis_node: + + def create_mocked_redis_node(host, port, **kwargs): + if port == 7000: + raise ConnectionError("mock connection error for 7000") + + r_node = Redis(host=host, port=port, decode_responses=True) + + def execute_command(*args, **kwargs): + if args[0] == "CLUSTER SLOTS": + return [ + [0, 8191, ["127.0.0.1", 7001, "node_1"]], + [8192, 16383, ["127.0.0.1", 7002, "node_2"]], + ] + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} + + r_node.execute_command = execute_command + + return r_node + + create_redis_node.side_effect = create_mocked_redis_node + + node_1 = ClusterNode("127.0.0.1", 7000) + node_2 = ClusterNode("127.0.0.1", 7001) + + # If all startup nodes fail to connect, connection error should be + # thrown + with pytest.raises(RedisClusterException) as e: + RedisCluster(startup_nodes=[node_1]) + assert "Redis Cluster cannot be connected" in str(e.value) + + with patch.object( + CommandsParser, "initialize", autospec=True + ) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = { + "get": { + "name": "get", + "arity": 2, + "flags": ["readonly", "fast"], + "first_key_pos": 1, + "last_key_pos": 1, + "step_count": 1, + } + } + + cmd_parser_initialize.side_effect = cmd_init_mock + # When at least one startup node is reachable, the cluster + # initialization should succeeds + rc = RedisCluster(startup_nodes=[node_1, node_2]) + assert rc.get_node(host=default_host, port=7001) is not None + assert rc.get_node(host=default_host, port=7002) is not None + + +@pytest.mark.onlycluster +class TestClusterPubSubObject: + """ + Tests for the ClusterPubSub class + """ + + def test_init_pubsub_with_host_and_port(self, r): + """ + Test creation of pubsub instance with passed host and port + """ + node = r.get_default_node() + p = r.pubsub(host=node.host, port=node.port) + assert p.get_pubsub_node() == node + + def test_init_pubsub_with_node(self, r): + """ + Test creation of pubsub instance with passed node + """ + node = r.get_default_node() + p = r.pubsub(node=node) + assert p.get_pubsub_node() == node + + def test_init_pubusub_without_specifying_node(self, r): + """ + Test creation of pubsub instance without specifying a node. The node + should be determined based on the keyslot of the first command + execution. + """ + channel_name = "foo" + node = r.get_node_from_key(channel_name) + p = r.pubsub() + assert p.get_pubsub_node() is None + p.subscribe(channel_name) + assert p.get_pubsub_node() == node + + def test_init_pubsub_with_a_non_existent_node(self, r): + """ + Test creation of pubsub instance with node that doesn't exists in the + cluster. RedisClusterException should be raised. + """ + node = ClusterNode("1.1.1.1", 1111) + with pytest.raises(RedisClusterException): + r.pubsub(node) + + def test_init_pubsub_with_a_non_existent_host_port(self, r): + """ + Test creation of pubsub instance with host and port that don't belong + to a node in the cluster. + RedisClusterException should be raised. + """ + with pytest.raises(RedisClusterException): + r.pubsub(host="1.1.1.1", port=1111) + + def test_init_pubsub_host_or_port(self, r): + """ + Test creation of pubsub instance with host but without port, and vice + versa. DataError should be raised. + """ + with pytest.raises(DataError): + r.pubsub(host="localhost") + + with pytest.raises(DataError): + r.pubsub(port=16379) + + def test_get_redis_connection(self, r): + """ + Test that get_redis_connection() returns the redis connection of the + set pubsub node + """ + node = r.get_default_node() + p = r.pubsub(node=node) + assert p.get_redis_connection() == node.redis_connection + + +@pytest.mark.onlycluster +class TestClusterPipeline: + """ + Tests for the ClusterPipeline class + """ + + def test_blocked_methods(self, r): + """ + Currently some method calls on a Cluster pipeline + is blocked when using in cluster mode. + They maybe implemented in the future. + """ + pipe = r.pipeline() + with pytest.raises(RedisClusterException): + pipe.multi() + + with pytest.raises(RedisClusterException): + pipe.immediate_execute_command() + + with pytest.raises(RedisClusterException): + pipe._execute_transaction(None, None, None) + + with pytest.raises(RedisClusterException): + pipe.load_scripts() + + with pytest.raises(RedisClusterException): + pipe.watch() + + with pytest.raises(RedisClusterException): + pipe.unwatch() + + with pytest.raises(RedisClusterException): + pipe.script_load_for_pipeline(None) + + with pytest.raises(RedisClusterException): + pipe.eval() + + def test_blocked_arguments(self, r): + """ + Currently some arguments is blocked when using in cluster mode. + They maybe implemented in the future. + """ + with pytest.raises(RedisClusterException) as ex: + r.pipeline(transaction=True) + + assert ( + str(ex.value).startswith("transaction is deprecated in cluster mode") + is True + ) + + with pytest.raises(RedisClusterException) as ex: + r.pipeline(shard_hint=True) + + assert ( + str(ex.value).startswith("shard_hint is deprecated in cluster mode") is True + ) + + def test_redis_cluster_pipeline(self, r): + """ + Test that we can use a pipeline with the RedisCluster class + """ + with r.pipeline() as pipe: + pipe.set("foo", "bar") + pipe.get("foo") + assert pipe.execute() == [True, b"bar"] + + def test_mget_disabled(self, r): + """ + Test that mget is disabled for ClusterPipeline + """ + with r.pipeline() as pipe: + with pytest.raises(RedisClusterException): + pipe.mget(["a"]) + + def test_mset_disabled(self, r): + """ + Test that mset is disabled for ClusterPipeline + """ + with r.pipeline() as pipe: + with pytest.raises(RedisClusterException): + pipe.mset({"a": 1, "b": 2}) + + def test_rename_disabled(self, r): + """ + Test that rename is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.rename("a", "b") + + def test_renamenx_disabled(self, r): + """ + Test that renamenx is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.renamenx("a", "b") + + def test_delete_single(self, r): + """ + Test a single delete operation + """ + r["a"] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.delete("a") + assert pipe.execute() == [1] + + def test_multi_delete_unsupported(self, r): + """ + Test that multi delete operation is unsupported + """ + with r.pipeline(transaction=False) as pipe: + r["a"] = 1 + r["b"] = 2 + with pytest.raises(RedisClusterException): + pipe.delete("a", "b") + + def test_brpoplpush_disabled(self, r): + """ + Test that brpoplpush is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.brpoplpush() + + def test_rpoplpush_disabled(self, r): + """ + Test that rpoplpush is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.rpoplpush() + + def test_sort_disabled(self, r): + """ + Test that sort is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sort() + + def test_sdiff_disabled(self, r): + """ + Test that sdiff is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sdiff() + + def test_sdiffstore_disabled(self, r): + """ + Test that sdiffstore is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sdiffstore() + + def test_sinter_disabled(self, r): + """ + Test that sinter is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sinter() + + def test_sinterstore_disabled(self, r): + """ + Test that sinterstore is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sinterstore() + + def test_smove_disabled(self, r): + """ + Test that move is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.smove() + + def test_sunion_disabled(self, r): + """ + Test that sunion is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sunion() + + def test_sunionstore_disabled(self, r): + """ + Test that sunionstore is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sunionstore() + + def test_spfmerge_disabled(self, r): + """ + Test that spfmerge is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.pfmerge() + + def test_multi_key_operation_with_a_single_slot(self, r): + """ + Test multi key operation with a single slot + """ + pipe = r.pipeline(transaction=False) + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") + + res = pipe.execute() + assert res == [True, True, True, b"1", b"2", b"3"] + + def test_multi_key_operation_with_multi_slots(self, r): + """ + Test multi key operation with more than one slot + """ + pipe = r.pipeline(transaction=False) + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.set("bar", 4) + pipe.set("bazz", 5) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") + pipe.get("bar") + pipe.get("bazz") + res = pipe.execute() + assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"] + + def test_connection_error_not_raised(self, r): + """ + Test that the pipeline doesn't raise an error on connection error when + raise_on_error=False + """ + key = "foo" + node = r.get_node_from_key(key, False) + + def raise_connection_error(): + e = ConnectionError("error") + return e + + with r.pipeline() as pipe: + mock_node_resp_func(node, raise_connection_error) + res = pipe.get(key).get(key).execute(raise_on_error=False) + assert node.redis_connection.connection.read_response.called + assert isinstance(res[0], ConnectionError) + + def test_connection_error_raised(self, r): + """ + Test that the pipeline raises an error on connection error when + raise_on_error=True + """ + key = "foo" + node = r.get_node_from_key(key, False) + + def raise_connection_error(): + e = ConnectionError("error") + return e + + with r.pipeline() as pipe: + mock_node_resp_func(node, raise_connection_error) + with pytest.raises(ConnectionError): + pipe.get(key).get(key).execute(raise_on_error=True) + + def test_asking_error(self, r): + """ + Test redirection on ASK error + """ + key = "foo" + first_node = r.get_node_from_key(key, False) + ask_node = None + for node in r.get_nodes(): + if node != first_node: + ask_node = node + break + if ask_node is None: + warnings.warn("skipping this test since the cluster has only one " "node") + return + ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}" + + def raise_ask_error(): + raise AskError(ask_msg) + + with r.pipeline() as pipe: + mock_node_resp_func(first_node, raise_ask_error) + mock_node_resp(ask_node, "MOCK_OK") + res = pipe.get(key).execute() + assert first_node.redis_connection.connection.read_response.called + assert ask_node.redis_connection.connection.read_response.called + assert res == ["MOCK_OK"] + + def test_empty_stack(self, r): + """ + If pipeline is executed with no commands it should + return a empty list. + """ + p = r.pipeline() + result = p.execute() + assert result == [] + + +@pytest.mark.onlycluster +class TestReadOnlyPipeline: + """ + Tests for ClusterPipeline class in readonly mode + """ + + def test_pipeline_readonly(self, r): + """ + On readonly mode, we supports get related stuff only. + """ + r.readonly(target_nodes="all") + r.set("foo71", "a1") # we assume this key is set on 127.0.0.1:7001 + r.zadd("foo88", {"z1": 1}) # we assume this key is set on 127.0.0.1:7002 + r.zadd("foo88", {"z2": 4}) + + with r.pipeline() as readonly_pipe: + readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) + assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] + + def test_moved_redirection_on_slave_with_default(self, r): + """ + On Pipeline, we redirected once and finally get from master with + readonly client when data is completely moved. + """ + key = "bar" + r.set(key, "foo") + # set read_from_replicas to True + r.read_from_replicas = True + primary = r.get_node_from_key(key, False) + replica = r.get_node_from_key(key, True) + with r.pipeline() as readwrite_pipe: + mock_node_resp(primary, "MOCK_FOO") + if replica is not None: + moved_error = f"{r.keyslot(key)} {primary.host}:{primary.port}" + + def raise_moved_error(): + raise MovedError(moved_error) + + mock_node_resp_func(replica, raise_moved_error) + assert readwrite_pipe.reinitialize_counter == 0 + readwrite_pipe.get(key).get(key) + assert readwrite_pipe.execute() == ["MOCK_FOO", "MOCK_FOO"] + if replica is not None: + # the slot has a replica as well, so MovedError should have + # occurred. If MovedError occurs, we should see the + # reinitialize_counter increase. + assert readwrite_pipe.reinitialize_counter == 1 + conn = replica.redis_connection.connection + assert conn.read_response.called is True + + def test_readonly_pipeline_from_readonly_client(self, request): + """ + Test that the pipeline is initialized with readonly mode if the client + has it enabled + """ + # Create a cluster with reading from replications + ro = _get_client(RedisCluster, request, read_from_replicas=True) + key = "bar" + ro.set(key, "foo") + import time + + time.sleep(0.2) + with ro.pipeline() as readonly_pipe: + mock_all_nodes_resp(ro, "MOCK_OK") + assert readonly_pipe.read_from_replicas is True + assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + if len(slot_nodes) > 1: + executed_on_replica = False + for node in slot_nodes: + if node.server_type == REPLICA: + conn = node.redis_connection.connection + executed_on_replica = conn.read_response.called + if executed_on_replica: + break + assert executed_on_replica is True + + +@pytest.mark.onlycluster +class TestClusterMonitor: + def test_wait_command_not_found(self, r): + "Make sure the wait_for_command func works when command is not found" + key = "foo" + node = r.get_node_from_key(key) + with r.monitor(target_node=node) as m: + response = wait_for_command(r, m, "nothing", key=key) + assert response is None + + def test_response_values(self, r): + db = 0 + key = "foo" + node = r.get_node_from_key(key) + with r.monitor(target_node=node) as m: + r.ping(target_nodes=node) + response = wait_for_command(r, m, "PING", key=key) + assert isinstance(response["time"], float) + assert response["db"] == db + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" + + def test_command_with_quoted_key(self, r): + key = "{foo}1" + node = r.get_node_from_key(key) + with r.monitor(node) as m: + r.get('{foo}"bar') + response = wait_for_command(r, m, 'GET {foo}"bar', key=key) + assert response["command"] == 'GET {foo}"bar' + + def test_command_with_binary_data(self, r): + key = "{foo}1" + node = r.get_node_from_key(key) + with r.monitor(target_node=node) as m: + byte_string = b"{foo}bar\x92" + r.get(byte_string) + response = wait_for_command(r, m, "GET {foo}bar\\x92", key=key) + assert response["command"] == "GET {foo}bar\\x92" + + def test_command_with_escaped_data(self, r): + key = "{foo}1" + node = r.get_node_from_key(key) + with r.monitor(target_node=node) as m: + byte_string = b"{foo}bar\\x92" + r.get(byte_string) + response = wait_for_command(r, m, "GET {foo}bar\\\\x92", key=key) + assert response["command"] == "GET {foo}bar\\\\x92" + + def test_flush(self, r): + r.set("x", "1") + r.set("z", "1") + r.flushall() + assert r.get("x") is None + assert r.get("y") is None diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 78220406eb..9bb7480a42 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -53,6 +53,7 @@ async def get_stream_message(client: redis.Redis, stream: str, message_id: str): # RESPONSE CALLBACKS +@pytest.mark.onlynoncluster class TestResponseCallbacks: """Tests for the response callback system""" @@ -241,6 +242,7 @@ def teardown(): assert f"user {username} off sanitize-payload &* -@all" in users @skip_if_server_version_lt(REDIS_6_VERSION) + @pytest.mark.onlynoncluster async def test_acl_log(self, r: redis.Redis, request, event_loop, create_redis): username = "redis-py-user" @@ -348,6 +350,7 @@ async def test_acl_whoami(self, r: redis.Redis): username = await r.acl_whoami() assert isinstance(username, str) + @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): clients = await r.client_list() assert isinstance(clients[0], dict) @@ -362,10 +365,12 @@ async def test_client_list_type(self, r: redis.Redis): assert isinstance(clients, list) @skip_if_server_version_lt("5.0.0") + @pytest.mark.onlynoncluster async def test_client_id(self, r: redis.Redis): assert await r.client_id() > 0 @skip_if_server_version_lt("5.0.0") + @pytest.mark.onlynoncluster async def test_client_unblock(self, r: redis.Redis): myid = await r.client_id() assert not await r.client_unblock(myid) @@ -373,15 +378,18 @@ async def test_client_unblock(self, r: redis.Redis): assert not await r.client_unblock(myid, error=False) @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_getname(self, r: redis.Redis): assert await r.client_getname() is None @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") assert await r.client_getname() == "redis_py_test" @skip_if_server_version_lt("2.6.9") + @pytest.mark.onlynoncluster async def test_client_kill(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -420,6 +428,7 @@ async def test_client_kill_filter_invalid_params(self, r: redis.Redis): await r.client_kill_filter(_type="caster") # type: ignore @skip_if_server_version_lt("2.8.12") + @pytest.mark.onlynoncluster async def test_client_kill_filter_by_id(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -445,6 +454,7 @@ async def test_client_kill_filter_by_id(self, r: redis.Redis, r2): assert clients[0].get("name") == "redis-py-c1" @skip_if_server_version_lt("2.8.12") + @pytest.mark.onlynoncluster async def test_client_kill_filter_by_addr(self, r: redis.Redis, r2): await r.client_setname("redis-py-c1") await r2.client_setname("redis-py-c2") @@ -477,6 +487,7 @@ async def test_client_list_after_client_setname(self, r: redis.Redis): assert "redis_py_test" in [c["name"] for c in clients] @skip_if_server_version_lt("2.9.50") + @pytest.mark.onlynoncluster async def test_client_pause(self, r: redis.Redis): assert await r.client_pause(1) assert await r.client_pause(timeout=1) @@ -488,6 +499,7 @@ async def test_config_get(self, r: redis.Redis): assert "maxmemory" in data assert data["maxmemory"].isdigit() + @pytest.mark.onlynoncluster async def test_config_resetstat(self, r: redis.Redis): await r.ping() prior_commands_processed = int((await r.info())["total_commands_processed"]) @@ -505,14 +517,17 @@ async def test_config_set(self, r: redis.Redis): finally: assert await r.config_set("dbfilename", rdbname) + @pytest.mark.onlynoncluster async def test_dbsize(self, r: redis.Redis): await r.set("a", "foo") await r.set("b", "bar") assert await r.dbsize() == 2 + @pytest.mark.onlynoncluster async def test_echo(self, r: redis.Redis): assert await r.echo("foo bar") == b"foo bar" + @pytest.mark.onlynoncluster async def test_info(self, r: redis.Redis): await r.set("a", "foo") await r.set("b", "bar") @@ -520,6 +535,7 @@ async def test_info(self, r: redis.Redis): assert isinstance(info, dict) assert info["db9"]["keys"] == 2 + @pytest.mark.onlynoncluster async def test_lastsave(self, r: redis.Redis): assert isinstance(await r.lastsave(), datetime.datetime) @@ -533,6 +549,7 @@ async def test_object(self, r: redis.Redis): async def test_ping(self, r: redis.Redis): assert await r.ping() + @pytest.mark.onlynoncluster async def test_slowlog_get(self, r: redis.Redis, slowlog): assert await r.slowlog_reset() unicode_string = chr(3456) + "abcd" + chr(3421) @@ -554,6 +571,7 @@ async def test_slowlog_get(self, r: redis.Redis, slowlog): assert isinstance(slowlog[0]["start_time"], int) assert isinstance(slowlog[0]["duration"], int) + @pytest.mark.onlynoncluster async def test_slowlog_get_limit(self, r: redis.Redis, slowlog): assert await r.slowlog_reset() await r.get("foo") @@ -562,6 +580,7 @@ async def test_slowlog_get_limit(self, r: redis.Redis, slowlog): # only one command, based on the number we passed to slowlog_get() assert len(slowlog) == 1 + @pytest.mark.onlynoncluster async def test_slowlog_length(self, r: redis.Redis, slowlog): await r.get("foo") assert isinstance(await r.slowlog_len(), int) @@ -600,12 +619,14 @@ async def test_bitcount(self, r: redis.Redis): assert await r.bitcount("a", 1, 1) == 1 @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not_empty_string(self, r: redis.Redis): await r.set("a", "") await r.bitop("not", "r", "a") assert await r.get("r") is None @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not(self, r: redis.Redis): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF @@ -614,6 +635,7 @@ async def test_bitop_not(self, r: redis.Redis): assert int(binascii.hexlify(await r.get("r")), 16) == correct @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_not_in_place(self, r: redis.Redis): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF @@ -622,6 +644,7 @@ async def test_bitop_not_in_place(self, r: redis.Redis): assert int(binascii.hexlify(await r.get("a")), 16) == correct @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_single_string(self, r: redis.Redis): test_str = b"\x01\x02\xFF" await r.set("a", test_str) @@ -633,6 +656,7 @@ async def test_bitop_single_string(self, r: redis.Redis): assert await r.get("res3") == test_str @skip_if_server_version_lt("2.6.0") + @pytest.mark.onlynoncluster async def test_bitop_string_operands(self, r: redis.Redis): await r.set("a", b"\x01\x02\xFF\xFF") await r.set("b", b"\x01\x02\xFF") @@ -643,6 +667,7 @@ async def test_bitop_string_operands(self, r: redis.Redis): assert int(binascii.hexlify(await r.get("res2")), 16) == 0x0102FFFF assert int(binascii.hexlify(await r.get("res3")), 16) == 0x000000FF + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") async def test_bitpos(self, r: redis.Redis): key = "key:bitpos" @@ -838,6 +863,7 @@ async def test_incrbyfloat(self, r: redis.Redis): assert await r.incrbyfloat("a", 1.1) == 2.1 assert float(await r.get("a")) == float(2.1) + @pytest.mark.onlynoncluster async def test_keys(self, r: redis.Redis): assert await r.keys() == [] keys_with_underscores = {b"test_a", b"test_b"} @@ -847,6 +873,7 @@ async def test_keys(self, r: redis.Redis): assert set(await r.keys(pattern="test_*")) == keys_with_underscores assert set(await r.keys(pattern="test*")) == keys + @pytest.mark.onlynoncluster async def test_mget(self, r: redis.Redis): assert await r.mget([]) == [] assert await r.mget(["a", "b"]) == [None, None] @@ -855,12 +882,14 @@ async def test_mget(self, r: redis.Redis): await r.set("c", "3") assert await r.mget("a", "other", "b", "c") == [b"1", None, b"2", b"3"] + @pytest.mark.onlynoncluster async def test_mset(self, r: redis.Redis): d = {"a": b"1", "b": b"2", "c": b"3"} assert await r.mset(d) for k, v in d.items(): assert await r.get(k) == v + @pytest.mark.onlynoncluster async def test_msetnx(self, r: redis.Redis): d = {"a": b"1", "b": b"2", "c": b"3"} assert await r.msetnx(d) @@ -926,18 +955,21 @@ async def test_pttl_no_key(self, r: redis.Redis): """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None for key in ("a", "b", "c"): await r.set(key, 1) assert await r.randomkey() in (b"a", b"b", b"c") + @pytest.mark.onlynoncluster async def test_rename(self, r: redis.Redis): await r.set("a", "1") assert await r.rename("a", "b") assert await r.get("a") is None assert await r.get("b") == b"1" + @pytest.mark.onlynoncluster async def test_renamenx(self, r: redis.Redis): await r.set("a", "1") await r.set("b", "2") @@ -1055,6 +1087,7 @@ async def test_type(self, r: redis.Redis): assert await r.type("a") == b"zset" # LIST COMMANDS + @pytest.mark.onlynoncluster async def test_blpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1066,6 +1099,7 @@ async def test_blpop(self, r: redis.Redis): await r.rpush("c", "1") assert await r.blpop("c", timeout=1) == (b"c", b"1") + @pytest.mark.onlynoncluster async def test_brpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1077,6 +1111,7 @@ async def test_brpop(self, r: redis.Redis): await r.rpush("c", "1") assert await r.brpop("c", timeout=1) == (b"c", b"1") + @pytest.mark.onlynoncluster async def test_brpoplpush(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") @@ -1086,6 +1121,7 @@ async def test_brpoplpush(self, r: redis.Redis): assert await r.lrange("a", 0, -1) == [] assert await r.lrange("b", 0, -1) == [b"1", b"2", b"3", b"4"] + @pytest.mark.onlynoncluster async def test_brpoplpush_empty_string(self, r: redis.Redis): await r.rpush("a", "") assert await r.brpoplpush("a", "b") == b"" @@ -1163,6 +1199,7 @@ async def test_rpop(self, r: redis.Redis): assert await r.rpop("a") == b"1" assert await r.rpop("a") is None + @pytest.mark.onlynoncluster async def test_rpoplpush(self, r: redis.Redis): await r.rpush("a", "a1", "a2", "a3") await r.rpush("b", "b1", "b2", "b3") @@ -1217,6 +1254,7 @@ async def test_rpushx(self, r: redis.Redis): # SCAN COMMANDS @skip_if_server_version_lt("2.8.0") + @pytest.mark.onlynoncluster async def test_scan(self, r: redis.Redis): await r.set("a", 1) await r.set("b", 2) @@ -1228,6 +1266,7 @@ async def test_scan(self, r: redis.Redis): assert set(keys) == {b"a"} @skip_if_server_version_lt(REDIS_6_VERSION) + @pytest.mark.onlynoncluster async def test_scan_type(self, r: redis.Redis): await r.sadd("a-set", 1) await r.hset("a-hash", "foo", 2) @@ -1236,6 +1275,7 @@ async def test_scan_type(self, r: redis.Redis): assert set(keys) == {b"a-set"} @skip_if_server_version_lt("2.8.0") + @pytest.mark.onlynoncluster async def test_scan_iter(self, r: redis.Redis): await r.set("a", 1) await r.set("b", 2) @@ -1306,12 +1346,14 @@ async def test_scard(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.scard("a") == 3 + @pytest.mark.onlynoncluster async def test_sdiff(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sdiff("a", "b") == {b"1", b"2", b"3"} await r.sadd("b", "2", "3") assert await r.sdiff("a", "b") == {b"1"} + @pytest.mark.onlynoncluster async def test_sdiffstore(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sdiffstore("c", "a", "b") == 3 @@ -1320,12 +1362,14 @@ async def test_sdiffstore(self, r: redis.Redis): assert await r.sdiffstore("c", "a", "b") == 1 assert await r.smembers("c") == {b"1"} + @pytest.mark.onlynoncluster async def test_sinter(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sinter("a", "b") == set() await r.sadd("b", "2", "3") assert await r.sinter("a", "b") == {b"2", b"3"} + @pytest.mark.onlynoncluster async def test_sinterstore(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.sinterstore("c", "a", "b") == 0 @@ -1345,6 +1389,7 @@ async def test_smembers(self, r: redis.Redis): await r.sadd("a", "1", "2", "3") assert await r.smembers("a") == {b"1", b"2", b"3"} + @pytest.mark.onlynoncluster async def test_smove(self, r: redis.Redis): await r.sadd("a", "a1", "a2") await r.sadd("b", "b1", "b2") @@ -1390,11 +1435,13 @@ async def test_srem(self, r: redis.Redis): assert await r.srem("a", "2", "4") == 2 assert await r.smembers("a") == {b"1", b"3"} + @pytest.mark.onlynoncluster async def test_sunion(self, r: redis.Redis): await r.sadd("a", "1", "2") await r.sadd("b", "2", "3") assert await r.sunion("a", "b") == {b"1", b"2", b"3"} + @pytest.mark.onlynoncluster async def test_sunionstore(self, r: redis.Redis): await r.sadd("a", "1", "2") await r.sadd("b", "2", "3") @@ -1479,6 +1526,7 @@ async def test_zlexcount(self, r: redis.Redis): assert await r.zlexcount("a", "-", "+") == 7 assert await r.zlexcount("a", "[b", "[f") == 5 + @pytest.mark.onlynoncluster async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1486,6 +1534,7 @@ async def test_zinterstore_sum(self, r: redis.Redis): assert await r.zinterstore("d", ["a", "b", "c"]) == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1493,6 +1542,7 @@ async def test_zinterstore_max(self, r: redis.Redis): assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) @@ -1500,6 +1550,7 @@ async def test_zinterstore_min(self, r: redis.Redis): assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1524,6 +1575,7 @@ async def test_zpopmin(self, r: redis.Redis): assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] @skip_if_server_version_lt("4.9.0") + @pytest.mark.onlynoncluster async def test_bzpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) @@ -1536,6 +1588,7 @@ async def test_bzpopmax(self, r: redis.Redis): assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) @skip_if_server_version_lt("4.9.0") + @pytest.mark.onlynoncluster async def test_bzpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) @@ -1703,6 +1756,7 @@ async def test_zscore(self, r: redis.Redis): assert await r.zscore("a", "a2") == 2.0 assert await r.zscore("a", "a4") is None + @pytest.mark.onlynoncluster async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1715,6 +1769,7 @@ async def test_zunionstore_sum(self, r: redis.Redis): (b"a1", 9), ] + @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1727,6 +1782,7 @@ async def test_zunionstore_max(self, r: redis.Redis): (b"a1", 6), ] + @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) @@ -1739,6 +1795,7 @@ async def test_zunionstore_min(self, r: redis.Redis): (b"a4", 4), ] + @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) @@ -1760,6 +1817,7 @@ async def test_pfadd(self, r: redis.Redis): assert await r.pfcount("a") == len(members) @skip_if_server_version_lt("2.8.9") + @pytest.mark.onlynoncluster async def test_pfcount(self, r: redis.Redis): members = {b"1", b"2", b"3"} await r.pfadd("a", *members) @@ -1770,6 +1828,7 @@ async def test_pfcount(self, r: redis.Redis): assert await r.pfcount("a", "b") == len(members_b.union(members)) @skip_if_server_version_lt("2.8.9") + @pytest.mark.onlynoncluster async def test_pfmerge(self, r: redis.Redis): mema = {b"1", b"2", b"3"} memb = {b"2", b"3", b"4"} @@ -1900,6 +1959,7 @@ async def test_sort_limited(self, r: redis.Redis): await r.rpush("a", "3", "2", "1", "4") assert await r.sort("a", start=1, num=2) == [b"2", b"3"] + @pytest.mark.onlynoncluster async def test_sort_by(self, r: redis.Redis): await r.set("score:1", 8) await r.set("score:2", 3) @@ -1907,6 +1967,7 @@ async def test_sort_by(self, r: redis.Redis): await r.rpush("a", "3", "2", "1") assert await r.sort("a", by="score:*") == [b"2", b"3", b"1"] + @pytest.mark.onlynoncluster async def test_sort_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1914,6 +1975,7 @@ async def test_sort_get(self, r: redis.Redis): await r.rpush("a", "2", "3", "1") assert await r.sort("a", get="user:*") == [b"u1", b"u2", b"u3"] + @pytest.mark.onlynoncluster async def test_sort_get_multi(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1928,6 +1990,7 @@ async def test_sort_get_multi(self, r: redis.Redis): b"3", ] + @pytest.mark.onlynoncluster async def test_sort_get_groups_two(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1939,6 +2002,7 @@ async def test_sort_get_groups_two(self, r: redis.Redis): (b"u3", b"3"), ] + @pytest.mark.onlynoncluster async def test_sort_groups_string_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1947,6 +2011,7 @@ async def test_sort_groups_string_get(self, r: redis.Redis): with pytest.raises(exceptions.DataError): await r.sort("a", get="user:*", groups=True) + @pytest.mark.onlynoncluster async def test_sort_groups_just_one_get(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1963,6 +2028,7 @@ async def test_sort_groups_no_get(self, r: redis.Redis): with pytest.raises(exceptions.DataError): await r.sort("a", groups=True) + @pytest.mark.onlynoncluster async def test_sort_groups_three_gets(self, r: redis.Redis): await r.set("user:1", "u1") await r.set("user:2", "u2") @@ -1985,11 +2051,13 @@ async def test_sort_alpha(self, r: redis.Redis): await r.rpush("a", "e", "c", "b", "d", "a") assert await r.sort("a", alpha=True) == [b"a", b"b", b"c", b"d", b"e"] + @pytest.mark.onlynoncluster async def test_sort_store(self, r: redis.Redis): await r.rpush("a", "2", "3", "1") assert await r.sort("a", store="sorted_values") == 3 assert await r.lrange("sorted_values", 0, -1) == [b"1", b"2", b"3"] + @pytest.mark.onlynoncluster async def test_sort_all_options(self, r: redis.Redis): await r.set("user:1:username", "zeus") await r.set("user:2:username", "titan") @@ -2033,70 +2101,88 @@ async def test_sort_issue_924(self, r: redis.Redis): await r.execute_command("SADD", "issue#924", 1) await r.execute_command("SORT", "issue#924") + @pytest.mark.onlynoncluster async def test_cluster_addslots(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("ADDSLOTS", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_count_failure_reports(self, mock_cluster_resp_int): assert isinstance( await mock_cluster_resp_int.cluster("COUNT-FAILURE-REPORTS", "node"), int ) + @pytest.mark.onlynoncluster async def test_cluster_countkeysinslot(self, mock_cluster_resp_int): assert isinstance( await mock_cluster_resp_int.cluster("COUNTKEYSINSLOT", 2), int ) + @pytest.mark.onlynoncluster async def test_cluster_delslots(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("DELSLOTS", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_failover(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("FAILOVER", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_forget(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("FORGET", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_info(self, mock_cluster_resp_info): assert isinstance(await mock_cluster_resp_info.cluster("info"), dict) + @pytest.mark.onlynoncluster async def test_cluster_keyslot(self, mock_cluster_resp_int): assert isinstance(await mock_cluster_resp_int.cluster("keyslot", "asdf"), int) + @pytest.mark.onlynoncluster async def test_cluster_meet(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("meet", "ip", "port", 1) is True + @pytest.mark.onlynoncluster async def test_cluster_nodes(self, mock_cluster_resp_nodes): assert isinstance(await mock_cluster_resp_nodes.cluster("nodes"), dict) + @pytest.mark.onlynoncluster async def test_cluster_replicate(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("replicate", "nodeid") is True + @pytest.mark.onlynoncluster async def test_cluster_reset(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("reset", "hard") is True + @pytest.mark.onlynoncluster async def test_cluster_saveconfig(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.cluster("saveconfig") is True + @pytest.mark.onlynoncluster async def test_cluster_setslot(self, mock_cluster_resp_ok): assert ( await mock_cluster_resp_ok.cluster("setslot", 1, "IMPORTING", "nodeid") is True ) + @pytest.mark.onlynoncluster async def test_cluster_slaves(self, mock_cluster_resp_slaves): assert isinstance( await mock_cluster_resp_slaves.cluster("slaves", "nodeid"), dict ) @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readwrite(self, r: redis.Redis): assert await r.readwrite() @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readonly_invalid_cluster_state(self, r: redis.Redis): with pytest.raises(exceptions.RedisError): await r.readonly() @skip_if_server_version_lt("3.0.0") + @pytest.mark.onlynoncluster async def test_readonly(self, mock_cluster_resp_ok): assert await mock_cluster_resp_ok.readonly() is True @@ -2313,6 +2399,7 @@ async def test_georadius_sort(self, r: redis.Redis): ] @skip_if_server_version_lt("3.2.0") + @pytest.mark.onlynoncluster async def test_georadius_store(self, r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, @@ -2326,6 +2413,7 @@ async def test_georadius_store(self, r: redis.Redis): @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") + @pytest.mark.onlynoncluster async def test_georadius_store_dist(self, r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, @@ -2721,25 +2809,11 @@ async def test_xread(self, r: redis.Redis): # xread starting at 0 returns both messages assert await r.xread(streams={stream: 0}) == expected - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] # xread starting at 0 and count=1 returns only the first message assert await r.xread(streams={stream: 0}, count=1) == expected - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m2), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] # xread starting at m1 returns only the second message assert await r.xread(streams={stream: m1}) == expected @@ -2770,14 +2844,7 @@ async def test_xreadgroup(self, r: redis.Redis): await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - ], - ] - ] + expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] # xread with count=1 returns only the first message assert ( await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) @@ -2815,15 +2882,7 @@ async def test_xreadgroup(self, r: redis.Redis): await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [ - [ - stream.encode(), - [ - (m1, {}), - (m2, {}), - ], - ] - ] + expected = [[stream.encode(), [(m1, {}), (m2, {})]]] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected @@ -2956,11 +3015,13 @@ async def test_memory_usage(self, r: redis.Redis): assert isinstance(await r.memory_usage("foo"), int) @skip_if_server_version_lt("4.0.0") + @pytest.mark.onlynoncluster async def test_module_list(self, r: redis.Redis): assert isinstance(await r.module_list(), list) assert not await r.module_list() +@pytest.mark.onlynoncluster class TestBinarySave: async def test_binary_get_set(self, r: redis.Redis): assert await r.set(" foo bar ", "123") diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index f9dfefd5cc..b5a5e77f1f 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -301,33 +301,23 @@ class TestConnectionPoolURLParsing: def test_hostname(self): pool = redis.ConnectionPool.from_url("redis://my.host") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "my.host", - } + assert pool.connection_kwargs == {"host": "my.host"} def test_quoted_hostname(self): pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "my / host +=+", - } + assert pool.connection_kwargs == {"host": "my / host +=+"} def test_port(self): pool = redis.ConnectionPool.from_url("redis://localhost:6380") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "port": 6380, - } + assert pool.connection_kwargs == {"host": "localhost", "port": 6380} @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "username": "myuser", - } + assert pool.connection_kwargs == {"host": "localhost", "username": "myuser"} @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): @@ -343,10 +333,7 @@ def test_quoted_username(self): def test_password(self): pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost") assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "password": "mypassword", - } + assert pool.connection_kwargs == {"host": "localhost", "password": "mypassword"} def test_quoted_password(self): pool = redis.ConnectionPool.from_url( @@ -371,26 +358,17 @@ def test_username_and_password(self): def test_db_as_argument(self): pool = redis.ConnectionPool.from_url("redis://localhost", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 1, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 1} def test_db_in_path(self): pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 2, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 2} def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) assert pool.connection_class == redis.Connection - assert pool.connection_kwargs == { - "host": "localhost", - "db": 3, - } + assert pool.connection_kwargs == {"host": "localhost", "db": 3} def test_extra_typed_querystring_options(self): pool = redis.ConnectionPool.from_url( @@ -450,9 +428,7 @@ def test_calling_from_subclass_returns_correct_instance(self): def test_client_creates_connection_pool(self): r = redis.Redis.from_url("redis://myhost") assert r.connection_pool.connection_class == redis.Connection - assert r.connection_pool.connection_kwargs == { - "host": "myhost", - } + assert r.connection_pool.connection_kwargs == {"host": "myhost"} def test_invalid_scheme_raises_error(self): with pytest.raises(ValueError) as cm: @@ -468,18 +444,13 @@ class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - } + assert pool.connection_kwargs == {"path": "/socket"} @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "username": "myuser", - } + assert pool.connection_kwargs == {"path": "/socket", "username": "myuser"} @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): @@ -495,10 +466,7 @@ def test_quoted_username(self): def test_password(self): pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "password": "mypassword", - } + assert pool.connection_kwargs == {"path": "/socket", "password": "mypassword"} def test_quoted_password(self): pool = redis.ConnectionPool.from_url( @@ -523,18 +491,12 @@ def test_quoted_path(self): def test_db_as_argument(self): pool = redis.ConnectionPool.from_url("unix:///socket", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 1, - } + assert pool.connection_kwargs == {"path": "/socket", "db": 1} def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) assert pool.connection_class == redis.UnixDomainSocketConnection - assert pool.connection_kwargs == { - "path": "/socket", - "db": 2, - } + assert pool.connection_kwargs == {"path": "/socket", "db": 2} def test_client_name_in_querystring(self): pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") @@ -551,9 +513,7 @@ class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") assert pool.connection_class == redis.SSLConnection - assert pool.connection_kwargs == { - "host": "my.host", - } + assert pool.connection_kwargs == {"host": "my.host"} def test_cert_reqs_options(self): import ssl From 0b077066f31b4aba3156044621ac3ac353937ea8 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Mon, 11 Apr 2022 01:02:53 +0530 Subject: [PATCH 02/23] Async Cluster Tests: Async/Await --- tests/test_asyncio/test_cluster.py | 1949 ++++++++++------------------ 1 file changed, 701 insertions(+), 1248 deletions(-) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 2a2bc883bf..f6634e8782 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -31,14 +31,12 @@ ResponseError, ) from redis.utils import str_if_bytes -from tests.test_pubsub import wait_for_message from .conftest import ( _get_client, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, - wait_for_command, ) default_host = "127.0.0.1" @@ -50,7 +48,7 @@ @pytest.fixture() -def slowlog(request, r): +async def slowlog(request, r): """ Set the slowlog threshold to 0, and the max length to 128. This will force every @@ -58,23 +56,21 @@ def slowlog(request, r): to test it """ # Save old values - current_config = r.config_get(target_nodes=r.get_primaries()[0]) + current_config = await r.config_get(target_nodes=r.get_primaries()[0]) old_slower_than_value = current_config["slowlog-log-slower-than"] - old_max_legnth_value = current_config["slowlog-max-len"] + old_max_length_value = current_config["slowlog-max-len"] - # Function to restore the old values - def cleanup(): - r.config_set("slowlog-log-slower-than", old_slower_than_value) - r.config_set("slowlog-max-len", old_max_legnth_value) + # Set the new values + await r.config_set("slowlog-log-slower-than", 0) + await r.config_set("slowlog-max-len", 128) - request.addfinalizer(cleanup) + yield - # Set the new values - r.config_set("slowlog-log-slower-than", 0) - r.config_set("slowlog-max-len", 128) + await r.config_set("slowlog-log-slower-than", old_slower_than_value) + await r.config_set("slowlog-max-len", old_max_length_value) -def get_mocked_redis_client(func=None, *args, **kwargs): +async def get_mocked_redis_client(func=None, *args, **kwargs): """ Return a stable RedisCluster object that have deterministic nodes and slots setup to remove the problem of different IP addresses @@ -85,7 +81,7 @@ def get_mocked_redis_client(func=None, *args, **kwargs): cluster_enabled = kwargs.pop("cluster_enabled", True) with patch.object(Redis, "execute_command") as execute_command_mock: - def execute_command(*_args, **_kwargs): + async def execute_command(*_args, **_kwargs): if _args[0] == "CLUSTER SLOTS": mock_cluster_slots = cluster_slots return mock_cluster_slots @@ -149,7 +145,7 @@ def find_node_ip_based_on_port(cluster_client, port): return node.host -def moved_redirection_helper(request, failover=False): +async def moved_redirection_helper(request, failover=False): """ Test that the client handles MOVED response after a failover. Redirection after a failover means that the redirection address is of a @@ -194,7 +190,7 @@ def ok_response(connection, *args, **options): raise MovedError(f"{slot} {r_host}:{r_port}") parse_response.side_effect = moved_redirect_effect - assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" + assert await rc.execute_command("SET", "foo", "bar") == "MOCK_OK" slot_primary = rc.nodes_manager.slots_cache[slot][0] assert slot_primary == redirect_node if failover: @@ -208,15 +204,15 @@ class TestRedisClusterObj: Tests for the RedisCluster class """ - def test_host_port_startup_node(self): + async def test_host_port_startup_node(self): """ Test that it is possible to use host & port arguments as startup node args """ - cluster = get_mocked_redis_client(host=default_host, port=default_port) + cluster = await get_mocked_redis_client(host=default_host, port=default_port) assert cluster.get_node(host=default_host, port=default_port) is not None - def test_startup_nodes(self): + async def test_startup_nodes(self): """ Test that it is possible to use startup_nodes argument to init the cluster @@ -227,13 +223,13 @@ def test_startup_nodes(self): ClusterNode(default_host, port_1), ClusterNode(default_host, port_2), ] - cluster = get_mocked_redis_client(startup_nodes=startup_nodes) + cluster = await get_mocked_redis_client(startup_nodes=startup_nodes) assert ( cluster.get_node(host=default_host, port=port_1) is not None and cluster.get_node(host=default_host, port=port_2) is not None ) - def test_empty_startup_nodes(self): + async def test_empty_startup_nodes(self): """ Test that exception is raised when empty providing empty startup_nodes """ @@ -244,35 +240,35 @@ def test_empty_startup_nodes(self): "RedisCluster requires at least one node to discover the " "cluster" ), str_if_bytes(ex.value) - def test_from_url(self, r): + async def test_from_url(self, r): redis_url = f"redis://{default_host}:{default_port}/0" with patch.object(RedisCluster, "from_url") as from_url: - def from_url_mocked(_url, **_kwargs): - return get_mocked_redis_client(url=_url, **_kwargs) + async def from_url_mocked(_url, **_kwargs): + return await get_mocked_redis_client(url=_url, **_kwargs) from_url.side_effect = from_url_mocked - cluster = RedisCluster.from_url(redis_url) + cluster = await RedisCluster.from_url(redis_url) assert cluster.get_node(host=default_host, port=default_port) is not None - def test_execute_command_errors(self, r): + async def test_execute_command_errors(self, r): """ Test that if no key is provided then exception should be raised. """ with pytest.raises(RedisClusterException) as ex: - r.execute_command("GET") + await r.execute_command("GET") assert str(ex.value).startswith( "No way to dispatch this command to " "Redis Cluster. Missing key." ) - def test_execute_command_node_flag_primaries(self, r): + async def test_execute_command_node_flag_primaries(self, r): """ Test command execution with nodes flag PRIMARIES """ primaries = r.get_primaries() replicas = r.get_replicas() mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.PRIMARIES) is True + assert await r.ping(target_nodes=RedisCluster.PRIMARIES) is True for primary in primaries: conn = primary.redis_connection.connection assert conn.read_response.called is True @@ -280,16 +276,16 @@ def test_execute_command_node_flag_primaries(self, r): conn = replica.redis_connection.connection assert conn.read_response.called is not True - def test_execute_command_node_flag_replicas(self, r): + async def test_execute_command_node_flag_replicas(self, r): """ Test command execution with nodes flag REPLICAS """ replicas = r.get_replicas() if not replicas: - r = get_mocked_redis_client(default_host, default_port) + r = await get_mocked_redis_client(default_host, default_port) primaries = r.get_primaries() mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.REPLICAS) is True + assert await r.ping(target_nodes=RedisCluster.REPLICAS) is True for replica in replicas: conn = replica.redis_connection.connection assert conn.read_response.called is True @@ -297,22 +293,22 @@ def test_execute_command_node_flag_replicas(self, r): conn = primary.redis_connection.connection assert conn.read_response.called is not True - def test_execute_command_node_flag_all_nodes(self, r): + async def test_execute_command_node_flag_all_nodes(self, r): """ Test command execution with nodes flag ALL_NODES """ mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.ALL_NODES) is True + assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True for node in r.get_nodes(): conn = node.redis_connection.connection assert conn.read_response.called is True - def test_execute_command_node_flag_random(self, r): + async def test_execute_command_node_flag_random(self, r): """ Test command execution with nodes flag RANDOM """ mock_all_nodes_resp(r, "PONG") - assert r.ping(target_nodes=RedisCluster.RANDOM) is True + assert await r.ping(target_nodes=RedisCluster.RANDOM) is True called_count = 0 for node in r.get_nodes(): conn = node.redis_connection.connection @@ -320,18 +316,18 @@ def test_execute_command_node_flag_random(self, r): called_count += 1 assert called_count == 1 - def test_execute_command_default_node(self, r): + async def test_execute_command_default_node(self, r): """ Test command execution without node flag is being executed on the default node """ def_node = r.get_default_node() mock_node_resp(def_node, "PONG") - assert r.ping() is True + assert await r.ping() is True conn = def_node.redis_connection.connection assert conn.read_response.called - def test_ask_redirection(self, r): + async def test_ask_redirection(self, r): """ Test that the server handles ASK response. @@ -355,21 +351,21 @@ def ok_response(connection, *args, **options): parse_response.side_effect = ask_redirect_effect - assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" + assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" - def test_moved_redirection(self, request): + async def test_moved_redirection(self, request): """ Test that the client handles MOVED response. """ - moved_redirection_helper(request, failover=False) + await moved_redirection_helper(request, failover=False) - def test_moved_redirection_after_failover(self, request): + async def test_moved_redirection_after_failover(self, request): """ Test that the client handles MOVED response after a failover. """ - moved_redirection_helper(request, failover=True) + await moved_redirection_helper(request, failover=True) - def test_refresh_using_specific_nodes(self, request): + async def test_refresh_using_specific_nodes(self, request): """ Test making calls on specific nodes when the cluster has failed over to another node @@ -444,7 +440,7 @@ def cmd_init_mock(self, r): assert len(rc.get_nodes()) == 1 assert rc.get_node(node_name=node_7006.name) is not None - rc.get("foo") + await rc.get("foo") # Cluster should now point to 7007, and there should be # one failed and one successful call @@ -454,7 +450,7 @@ def cmd_init_mock(self, r): assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 - def test_reading_from_replicas_in_round_robin(self): + async def test_reading_from_replicas_in_round_robin(self): with patch.multiple( Connection, send_command=DEFAULT, @@ -493,7 +489,7 @@ def parse_response_mock_third(connection, *args, **options): mocks["on_connect"].return_value = True # Create a cluster with reading from replications - read_cluster = get_mocked_redis_client( + read_cluster = await get_mocked_redis_client( host=default_host, port=default_port, read_from_replicas=True ) assert read_cluster.read_from_replicas is True @@ -501,12 +497,12 @@ def parse_response_mock_third(connection, *args, **options): # matter. # 'foo' belongs to slot 12182 and the slot's nodes are: # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)] - read_cluster.get("foo") - read_cluster.get("foo") - read_cluster.get("foo") + await read_cluster.get("foo") + await read_cluster.get("foo") + await read_cluster.get("foo") mocks["send_command"].assert_has_calls([call("READONLY")]) - def test_keyslot(self, r): + async def test_keyslot(self, r): """ Test that method will compute correct key in all supported cases """ @@ -523,13 +519,13 @@ def test_keyslot(self, r): assert r.keyslot(1337) == r.keyslot("1337") assert r.keyslot(b"abc") == r.keyslot("abc") - def test_get_node_name(self): + async def test_get_node_name(self): assert ( get_node_name(default_host, default_port) == f"{default_host}:{default_port}" ) - def test_all_nodes(self, r): + async def test_all_nodes(self, r): """ Set a list of nodes and it should be possible to iterate over all """ @@ -538,7 +534,7 @@ def test_all_nodes(self, r): for i, node in enumerate(r.get_nodes()): assert node in nodes - def test_all_nodes_masters(self, r): + async def test_all_nodes_masters(self, r): """ Set a list of nodes with random primaries/replicas config and it shold be possible to iterate over all of them. @@ -553,7 +549,7 @@ def test_all_nodes_masters(self, r): assert node in nodes @pytest.mark.parametrize("error", RedisCluster.ERRORS_ALLOW_RETRY) - def test_cluster_down_overreaches_retry_attempts(self, error): + async def test_cluster_down_overreaches_retry_attempts(self, error): """ When error that allows retry is thrown, test that we retry executing the command as many times as configured in cluster_error_retry_attempts @@ -567,13 +563,13 @@ def raise_error(target_node, *args, **kwargs): execute_command.side_effect = raise_error - rc = get_mocked_redis_client(host=default_host, port=default_port) + rc = await get_mocked_redis_client(host=default_host, port=default_port) with pytest.raises(error): - rc.get("bar") + await rc.get("bar") assert execute_command.failed_calls == rc.cluster_error_retry_attempts - def test_user_on_connect_function(self, request): + async def test_user_on_connect_function(self, request): """ Test support in passing on_connect function by the user """ @@ -586,7 +582,7 @@ def on_connect(connection): _get_client(RedisCluster, request, redis_connect_func=mock) assert mock.called is True - def test_set_default_node_success(self, r): + async def test_set_default_node_success(self, r): """ test successful replacement of the default cluster node """ @@ -600,7 +596,7 @@ def test_set_default_node_success(self, r): assert r.set_default_node(new_def_node) is True assert r.get_default_node() == new_def_node - def test_set_default_node_failure(self, r): + async def test_set_default_node_failure(self, r): """ test failed replacement of the default cluster node """ @@ -610,7 +606,7 @@ def test_set_default_node_failure(self, r): assert r.set_default_node(new_def_node) is False assert r.get_default_node() == default_node - def test_get_node_from_key(self, r): + async def test_get_node_from_key(self, r): """ Test that get_node_from_key function returns the correct node """ @@ -625,7 +621,7 @@ def test_get_node_from_key(self, r): assert replica in slot_nodes @skip_if_redis_enterprise() - def test_not_require_full_coverage_cluster_down_error(self, r): + async def test_not_require_full_coverage_cluster_down_error(self, r): """ When require_full_coverage is set to False (default client config) and not all slots are covered, if one of the nodes has 'cluster-require_full_coverage' @@ -633,17 +629,17 @@ def test_not_require_full_coverage_cluster_down_error(self, r): """ node = r.get_node_from_key("foo") missing_slot = r.keyslot("foo") - assert r.set("foo", "bar") is True + assert await r.set("foo", "bar") is True try: - assert all(r.cluster_delslots(missing_slot)) + assert all(await r.cluster_delslots(missing_slot)) with pytest.raises(ClusterDownError): - r.exists("foo") + await r.exists("foo") finally: try: # Add back the missing slot - assert r.cluster_addslots(node, missing_slot) is True + assert await r.cluster_addslots(node, missing_slot) is True # Make sure we are not getting ClusterDownError anymore - assert r.exists("foo") == 1 + assert await r.exists("foo") == 1 except ResponseError as e: if f"Slot {missing_slot} is already busy" in str(e): # It can happen if the test failed to delete this slot @@ -658,275 +654,183 @@ class TestClusterRedisCommands: Tests for RedisCluster unique commands """ - def test_case_insensitive_command_names(self, r): + async def test_case_insensitive_command_names(self, r): assert ( r.cluster_response_callbacks["cluster addslots"] == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] ) - def test_get_and_set(self, r): + async def test_get_and_set(self, r): # get and set can't be tested independently of each other - assert r.get("a") is None + assert await r.get("a") is None byte_string = b"value" integer = 5 unicode_string = chr(3456) + "abcd" + chr(3421) - assert r.set("byte_string", byte_string) - assert r.set("integer", 5) - assert r.set("unicode_string", unicode_string) - assert r.get("byte_string") == byte_string - assert r.get("integer") == str(integer).encode() - assert r.get("unicode_string").decode("utf-8") == unicode_string - - def test_mget_nonatomic(self, r): - assert r.mget_nonatomic([]) == [] - assert r.mget_nonatomic(["a", "b"]) == [None, None] - r["a"] = "1" - r["b"] = "2" - r["c"] = "3" - - assert r.mget_nonatomic("a", "other", "b", "c") == [b"1", None, b"2", b"3"] - - def test_mset_nonatomic(self, r): + assert await r.set("byte_string", byte_string) + assert await r.set("integer", 5) + assert await r.set("unicode_string", unicode_string) + assert await r.get("byte_string") == byte_string + assert await r.get("integer") == str(integer).encode() + assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + + async def test_mget_nonatomic(self, r): + assert await r.mget_nonatomic([]) == [] + assert await r.mget_nonatomic(["a", "b"]) == [None, None] + await r.set("a", "1") + await r.set("b", "2") + await r.set("c", "3") + + assert await r.mget_nonatomic("a", "other", "b", "c") == [ + b"1", + None, + b"2", + b"3", + ] + + async def test_mset_nonatomic(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - assert r.mset_nonatomic(d) + assert await r.mset_nonatomic(d) for k, v in d.items(): - assert r[k] == v + assert await r.get(k) == v - def test_config_set(self, r): - assert r.config_set("slowlog-log-slower-than", 0) + async def test_config_set(self, r): + assert await r.config_set("slowlog-log-slower-than", 0) - def test_cluster_config_resetstat(self, r): - r.ping(target_nodes="all") - all_info = r.info(target_nodes="all") + async def test_cluster_config_resetstat(self, r): + await r.ping(target_nodes="all") + all_info = await r.info(target_nodes="all") prior_commands_processed = -1 for node_info in all_info.values(): prior_commands_processed = node_info["total_commands_processed"] assert prior_commands_processed >= 1 - r.config_resetstat(target_nodes="all") - all_info = r.info(target_nodes="all") + await r.config_resetstat(target_nodes="all") + all_info = await r.info(target_nodes="all") for node_info in all_info.values(): reset_commands_processed = node_info["total_commands_processed"] assert reset_commands_processed < prior_commands_processed - def test_client_setname(self, r): + async def test_client_setname(self, r): node = r.get_random_node() - r.client_setname("redis_py_test", target_nodes=node) - client_name = r.client_getname(target_nodes=node) + await r.client_setname("redis_py_test", target_nodes=node) + client_name = await r.client_getname(target_nodes=node) assert client_name == "redis_py_test" - def test_exists(self, r): + async def test_exists(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.exists(*d.keys()) == len(d) + await r.mset_nonatomic(d) + assert await r.exists(*d.keys()) == len(d) - def test_delete(self, r): + async def test_delete(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.delete(*d.keys()) == len(d) - assert r.delete(*d.keys()) == 0 + await r.mset_nonatomic(d) + assert await r.delete(*d.keys()) == len(d) + assert await r.delete(*d.keys()) == 0 - def test_touch(self, r): + async def test_touch(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.touch(*d.keys()) == len(d) + await r.mset_nonatomic(d) + assert await r.touch(*d.keys()) == len(d) - def test_unlink(self, r): + async def test_unlink(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - r.mset_nonatomic(d) - assert r.unlink(*d.keys()) == len(d) + await r.mset_nonatomic(d) + assert await r.unlink(*d.keys()) == len(d) # Unlink is non-blocking so we sleep before # verifying the deletion sleep(0.1) - assert r.unlink(*d.keys()) == 0 - - def test_pubsub_channels_merge_results(self, r): - nodes = r.get_nodes() - channels = [] - pubsub_nodes = [] - i = 0 - for node in nodes: - channel = f"foo{i}" - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.subscribe(channel) - b_channel = channel.encode("utf-8") - channels.append(b_channel) - # Assert that each node returns only the channel it subscribed to - sub_channels = node.redis_connection.pubsub_channels() - if not sub_channels: - # Try again after a short sleep - sleep(0.3) - sub_channels = node.redis_connection.pubsub_channels() - assert sub_channels == [b_channel] - i += 1 - # Assert that the cluster's pubsub_channels function returns ALL of - # the cluster's channels - result = r.pubsub_channels(target_nodes="all") - result.sort() - assert result == channels - - def test_pubsub_numsub_merge_results(self, r): - nodes = r.get_nodes() - pubsub_nodes = [] - channel = "foo" - b_channel = channel.encode("utf-8") - for node in nodes: - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.subscribe(channel) - # Assert that each node returns that only one client is subscribed - sub_chann_num = node.redis_connection.pubsub_numsub(channel) - if sub_chann_num == [(b_channel, 0)]: - sleep(0.3) - sub_chann_num = node.redis_connection.pubsub_numsub(channel) - assert sub_chann_num == [(b_channel, 1)] - # Assert that the cluster's pubsub_numsub function returns ALL clients - # subscribed to this channel in the entire cluster - assert r.pubsub_numsub(channel, target_nodes="all") == [(b_channel, len(nodes))] - - def test_pubsub_numpat_merge_results(self, r): - nodes = r.get_nodes() - pubsub_nodes = [] - pattern = "foo*" - for node in nodes: - # We will create different pubsub clients where each one is - # connected to a different node - p = r.pubsub(node) - pubsub_nodes.append(p) - p.psubscribe(pattern) - # Assert that each node returns that only one client is subscribed - sub_num_pat = node.redis_connection.pubsub_numpat() - if sub_num_pat == 0: - sleep(0.3) - sub_num_pat = node.redis_connection.pubsub_numpat() - assert sub_num_pat == 1 - # Assert that the cluster's pubsub_numsub function returns ALL clients - # subscribed to this channel in the entire cluster - assert r.pubsub_numpat(target_nodes="all") == len(nodes) - - @skip_if_server_version_lt("2.8.0") - def test_cluster_pubsub_channels(self, r): - p = r.pubsub() - p.subscribe("foo", "bar", "baz", "quux") - for i in range(4): - assert wait_for_message(p, timeout=0.5)["type"] == "subscribe" - expected = [b"bar", b"baz", b"foo", b"quux"] - assert all( - [channel in r.pubsub_channels(target_nodes="all") for channel in expected] - ) - - @skip_if_server_version_lt("2.8.0") - def test_cluster_pubsub_numsub(self, r): - p1 = r.pubsub() - p1.subscribe("foo", "bar", "baz") - for i in range(3): - assert wait_for_message(p1, timeout=0.5)["type"] == "subscribe" - p2 = r.pubsub() - p2.subscribe("bar", "baz") - for i in range(2): - assert wait_for_message(p2, timeout=0.5)["type"] == "subscribe" - p3 = r.pubsub() - p3.subscribe("baz") - assert wait_for_message(p3, timeout=0.5)["type"] == "subscribe" - - channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] - assert r.pubsub_numsub("foo", "bar", "baz", target_nodes="all") == channels + assert await r.unlink(*d.keys()) == 0 @skip_if_redis_enterprise() - def test_cluster_myid(self, r): + async def test_cluster_myid(self, r): node = r.get_random_node() - myid = r.cluster_myid(node) + myid = await r.cluster_myid(node) assert len(myid) == 40 @skip_if_redis_enterprise() - def test_cluster_slots(self, r): + async def test_cluster_slots(self, r): mock_all_nodes_resp(r, default_cluster_slots) - cluster_slots = r.cluster_slots() + cluster_slots = await r.cluster_slots() assert isinstance(cluster_slots, dict) assert len(default_cluster_slots) == len(cluster_slots) assert cluster_slots.get((0, 8191)) is not None assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) @skip_if_redis_enterprise() - def test_cluster_addslots(self, r): + async def test_cluster_addslots(self, r): node = r.get_random_node() mock_node_resp(node, "OK") - assert r.cluster_addslots(node, 1, 2, 3) is True + assert await r.cluster_addslots(node, 1, 2, 3) is True @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - def test_cluster_addslotsrange(self, r): + async def test_cluster_addslotsrange(self, r): node = r.get_random_node() mock_node_resp(node, "OK") - assert r.cluster_addslotsrange(node, 1, 5) + assert await r.cluster_addslotsrange(node, 1, 5) @skip_if_redis_enterprise() - def test_cluster_countkeysinslot(self, r): + async def test_cluster_countkeysinslot(self, r): node = r.nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) - assert r.cluster_countkeysinslot(1) == 2 + assert await r.cluster_countkeysinslot(1) == 2 - def test_cluster_count_failure_report(self, r): + async def test_cluster_count_failure_report(self, r): mock_all_nodes_resp(r, 0) - assert r.cluster_count_failure_report("node_0") == 0 + assert await r.cluster_count_failure_report("node_0") == 0 @skip_if_redis_enterprise() - def test_cluster_delslots(self): + async def test_cluster_delslots(self): cluster_slots = [ [0, 8191, ["127.0.0.1", 7000, "node_0"]], [8192, 16383, ["127.0.0.1", 7001, "node_1"]], ] - r = get_mocked_redis_client( + r = await get_mocked_redis_client( host=default_host, port=default_port, cluster_slots=cluster_slots ) mock_all_nodes_resp(r, "OK") node0 = r.get_node(default_host, 7000) node1 = r.get_node(default_host, 7001) - assert r.cluster_delslots(0, 8192) == [True, True] + assert await r.cluster_delslots(0, 8192) == [True, True] assert node0.redis_connection.connection.read_response.called assert node1.redis_connection.connection.read_response.called @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - def test_cluster_delslotsrange(self, r): + async def test_cluster_delslotsrange(self, r): node = r.get_random_node() mock_node_resp(node, "OK") - r.cluster_addslots(node, 1, 2, 3, 4, 5) - assert r.cluster_delslotsrange(1, 5) + await r.cluster_addslots(node, 1, 2, 3, 4, 5) + assert await r.cluster_delslotsrange(1, 5) @skip_if_redis_enterprise() - def test_cluster_failover(self, r): + async def test_cluster_failover(self, r): node = r.get_random_node() mock_node_resp(node, "OK") - assert r.cluster_failover(node) is True - assert r.cluster_failover(node, "FORCE") is True - assert r.cluster_failover(node, "TAKEOVER") is True + assert await r.cluster_failover(node) is True + assert await r.cluster_failover(node, "FORCE") is True + assert await r.cluster_failover(node, "TAKEOVER") is True with pytest.raises(RedisError): - r.cluster_failover(node, "FORCT") + await r.cluster_failover(node, "FORCT") @skip_if_redis_enterprise() - def test_cluster_info(self, r): - info = r.cluster_info() + async def test_cluster_info(self, r): + info = await r.cluster_info() assert isinstance(info, dict) assert info["cluster_state"] == "ok" @skip_if_redis_enterprise() - def test_cluster_keyslot(self, r): + async def test_cluster_keyslot(self, r): mock_all_nodes_resp(r, 12182) - assert r.cluster_keyslot("foo") == 12182 + assert await r.cluster_keyslot("foo") == 12182 @skip_if_redis_enterprise() - def test_cluster_meet(self, r): + async def test_cluster_meet(self, r): node = r.get_default_node() mock_node_resp(node, "OK") - assert r.cluster_meet("127.0.0.1", 6379) is True + assert await r.cluster_meet("127.0.0.1", 6379) is True @skip_if_redis_enterprise() - def test_cluster_nodes(self, r): + async def test_cluster_nodes(self, r): response = ( "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " @@ -946,7 +850,7 @@ def test_cluster_nodes(self, r): "master,fail - 1447829446956 1447829444948 1 disconnected\n" ) mock_all_nodes_resp(r, response) - nodes = r.cluster_nodes() + nodes = await r.cluster_nodes() assert len(nodes) == 7 assert nodes.get("172.17.0.7:7006") is not None assert ( @@ -955,7 +859,7 @@ def test_cluster_nodes(self, r): ) @skip_if_redis_enterprise() - def test_cluster_nodes_importing_migrating(self, r): + async def test_cluster_nodes_importing_migrating(self, r): response = ( "488ead2fcce24d8c0f158f9172cb1f4a9e040fe5 127.0.0.1:16381@26381 " "master - 0 1648975557664 3 connected 10923-16383\n" @@ -967,7 +871,7 @@ def test_cluster_nodes_importing_migrating(self, r): "2->-8ae2e70812db80776f739a72374e57fc4ae6f89d]\n" ) mock_all_nodes_resp(r, response) - nodes = r.cluster_nodes() + nodes = await r.cluster_nodes() assert len(nodes) == 3 node_16379 = nodes.get("127.0.0.1:16379") node_16380 = nodes.get("127.0.0.1:16380") @@ -992,12 +896,12 @@ def test_cluster_nodes_importing_migrating(self, r): assert node_16381.get("migrations") == [] @skip_if_redis_enterprise() - def test_cluster_replicate(self, r): + async def test_cluster_replicate(self, r): node = r.get_random_node() all_replicas = r.get_replicas() mock_all_nodes_resp(r, "OK") - assert r.cluster_replicate(node, "c8253bae761cb61857d") is True - results = r.cluster_replicate(all_replicas, "c8253bae761cb61857d") + assert await r.cluster_replicate(node, "c8253bae761cb61857d") is True + results = await r.cluster_replicate(all_replicas, "c8253bae761cb61857d") if isinstance(results, dict): for res in results.values(): assert res is True @@ -1005,60 +909,60 @@ def test_cluster_replicate(self, r): assert results is True @skip_if_redis_enterprise() - def test_cluster_reset(self, r): + async def test_cluster_reset(self, r): mock_all_nodes_resp(r, "OK") - assert r.cluster_reset() is True - assert r.cluster_reset(False) is True - all_results = r.cluster_reset(False, target_nodes="all") + assert await r.cluster_reset() is True + assert await r.cluster_reset(False) is True + all_results = await r.cluster_reset(False, target_nodes="all") for res in all_results.values(): assert res is True @skip_if_redis_enterprise() - def test_cluster_save_config(self, r): + async def test_cluster_save_config(self, r): node = r.get_random_node() all_nodes = r.get_nodes() mock_all_nodes_resp(r, "OK") - assert r.cluster_save_config(node) is True - all_results = r.cluster_save_config(all_nodes) + assert await r.cluster_save_config(node) is True + all_results = await r.cluster_save_config(all_nodes) for res in all_results.values(): assert res is True @skip_if_redis_enterprise() - def test_cluster_get_keys_in_slot(self, r): + async def test_cluster_get_keys_in_slot(self, r): response = [b"{foo}1", b"{foo}2"] node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) - keys = r.cluster_get_keys_in_slot(12182, 4) + keys = await r.cluster_get_keys_in_slot(12182, 4) assert keys == response @skip_if_redis_enterprise() - def test_cluster_set_config_epoch(self, r): + async def test_cluster_set_config_epoch(self, r): mock_all_nodes_resp(r, "OK") - assert r.cluster_set_config_epoch(3) is True - all_results = r.cluster_set_config_epoch(3, target_nodes="all") + assert await r.cluster_set_config_epoch(3) is True + all_results = await r.cluster_set_config_epoch(3, target_nodes="all") for res in all_results.values(): assert res is True @skip_if_redis_enterprise() - def test_cluster_setslot(self, r): + async def test_cluster_setslot(self, r): node = r.get_random_node() mock_node_resp(node, "OK") - assert r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True - assert r.cluster_setslot(node, "node_0", 1218, "NODE") is True - assert r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True + assert await r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True + assert await r.cluster_setslot(node, "node_0", 1218, "NODE") is True + assert await r.cluster_setslot(node, "node_0", 1218, "MIGRATING") is True with pytest.raises(RedisError): - r.cluster_failover(node, "STABLE") + await r.cluster_failover(node, "STABLE") with pytest.raises(RedisError): - r.cluster_failover(node, "STATE") + await r.cluster_failover(node, "STATE") - def test_cluster_setslot_stable(self, r): + async def test_cluster_setslot_stable(self, r): node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") - assert r.cluster_setslot_stable(12182) is True + assert await r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @skip_if_redis_enterprise() - def test_cluster_replicas(self, r): + async def test_cluster_replicas(self, r): response = [ b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " b"127.0.0.1:6377@16377 slave " @@ -1070,7 +974,7 @@ def test_cluster_replicas(self, r): b"1634550063436 4 connected", ] mock_all_nodes_resp(r, response) - replicas = r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") + replicas = await r.cluster_replicas("52611e796814b78e90ad94be9d769a4f668f9a") assert replicas.get("127.0.0.1:6377") is not None assert replicas.get("127.0.0.1:6378") is not None assert ( @@ -1079,75 +983,76 @@ def test_cluster_replicas(self, r): ) @skip_if_server_version_lt("7.0.0") - def test_cluster_links(self, r): + async def test_cluster_links(self, r): node = r.get_random_node() - res = r.cluster_links(node) + res = await r.cluster_links(node) links_to = sum(x.count("to") for x in res) links_for = sum(x.count("from") for x in res) assert links_to == links_for - print(res) for i in range(0, len(res) - 1, 2): assert res[i][3] == res[i + 1][3] @skip_if_redis_enterprise() - def test_readonly(self): - r = get_mocked_redis_client(host=default_host, port=default_port) + async def test_readonly(self): + r = await get_mocked_redis_client(host=default_host, port=default_port) mock_all_nodes_resp(r, "OK") - assert r.readonly() is True - all_replicas_results = r.readonly(target_nodes="replicas") + assert await r.readonly() is True + all_replicas_results = await r.readonly(target_nodes="replicas") for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called @skip_if_redis_enterprise() - def test_readwrite(self): - r = get_mocked_redis_client(host=default_host, port=default_port) + async def test_readwrite(self): + r = await get_mocked_redis_client(host=default_host, port=default_port) mock_all_nodes_resp(r, "OK") - assert r.readwrite() is True - all_replicas_results = r.readwrite(target_nodes="replicas") + assert await r.readwrite() is True + all_replicas_results = await r.readwrite(target_nodes="replicas") for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called @skip_if_redis_enterprise() - def test_bgsave(self, r): - assert r.bgsave() + async def test_bgsave(self, r): + assert await r.bgsave() sleep(0.3) - assert r.bgsave(True) + assert await r.bgsave(True) - def test_info(self, r): + async def test_info(self, r): # Map keys to same slot - r.set("x{1}", 1) - r.set("y{1}", 2) - r.set("z{1}", 3) + await r.set("x{1}", 1) + await r.set("y{1}", 2) + await r.set("z{1}", 3) # Get node that handles the slot slot = r.keyslot("x{1}") node = r.nodes_manager.get_node_from_slot(slot) # Run info on that node - info = r.info(target_nodes=node) + info = await r.info(target_nodes=node) assert isinstance(info, dict) assert info["db0"]["keys"] == 3 - def _init_slowlog_test(self, r, node): - slowlog_lim = r.config_get("slowlog-log-slower-than", target_nodes=node) - assert r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True + async def _init_slowlog_test(self, r, node): + slowlog_lim = await r.config_get("slowlog-log-slower-than", target_nodes=node) + assert ( + await r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True + ) return slowlog_lim["slowlog-log-slower-than"] - def _teardown_slowlog_test(self, r, node, prev_limit): + async def _teardown_slowlog_test(self, r, node, prev_limit): assert ( - r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) + await r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) is True ) - def test_slowlog_get(self, r, slowlog): + async def test_slowlog_get(self, r, slowlog): unicode_string = chr(3456) + "abcd" + chr(3421) node = r.get_node_from_key(unicode_string) - slowlog_limit = self._init_slowlog_test(r, node) - assert r.slowlog_reset(target_nodes=node) - r.get(unicode_string) - slowlog = r.slowlog_get(target_nodes=node) + slowlog_limit = await self._init_slowlog_test(r, node) + assert await r.slowlog_reset(target_nodes=node) + await r.get(unicode_string) + slowlog = await r.slowlog_get(target_nodes=node) assert isinstance(slowlog, list) commands = [log["command"] for log in slowlog] @@ -1165,543 +1070,605 @@ def test_slowlog_get(self, r, slowlog): assert isinstance(slowlog[0]["start_time"], int) assert isinstance(slowlog[0]["duration"], int) # rollback the slowlog limit to its original value - self._teardown_slowlog_test(r, node, slowlog_limit) + await self._teardown_slowlog_test(r, node, slowlog_limit) - def test_slowlog_get_limit(self, r, slowlog): - assert r.slowlog_reset() + async def test_slowlog_get_limit(self, r, slowlog): + assert await r.slowlog_reset() node = r.get_node_from_key("foo") - slowlog_limit = self._init_slowlog_test(r, node) - r.get("foo") - slowlog = r.slowlog_get(1, target_nodes=node) + slowlog_limit = await self._init_slowlog_test(r, node) + await r.get("foo") + slowlog = await r.slowlog_get(1, target_nodes=node) assert isinstance(slowlog, list) # only one command, based on the number we passed to slowlog_get() assert len(slowlog) == 1 - self._teardown_slowlog_test(r, node, slowlog_limit) + await self._teardown_slowlog_test(r, node, slowlog_limit) - def test_slowlog_length(self, r, slowlog): - r.get("foo") + async def test_slowlog_length(self, r, slowlog): + await r.get("foo") node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) - slowlog_len = r.slowlog_len(target_nodes=node) + slowlog_len = await r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) - def test_time(self, r): - t = r.time(target_nodes=r.get_primaries()[0]) + async def test_time(self, r): + t = await r.time(target_nodes=r.get_primaries()[0]) assert len(t) == 2 assert isinstance(t[0], int) assert isinstance(t[1], int) @skip_if_server_version_lt("4.0.0") - def test_memory_usage(self, r): - r.set("foo", "bar") - assert isinstance(r.memory_usage("foo"), int) + async def test_memory_usage(self, r): + await r.set("foo", "bar") + assert isinstance(await r.memory_usage("foo"), int) @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise() - def test_memory_malloc_stats(self, r): - assert r.memory_malloc_stats() + async def test_memory_malloc_stats(self, r): + assert await r.memory_malloc_stats() @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise() - def test_memory_stats(self, r): + async def test_memory_stats(self, r): # put a key into the current db to make sure that "db." # has data - r.set("foo", "bar") + await r.set("foo", "bar") node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) - stats = r.memory_stats(target_nodes=node) + stats = await r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): if key.startswith("db."): assert isinstance(value, dict) @skip_if_server_version_lt("4.0.0") - def test_memory_help(self, r): + async def test_memory_help(self, r): with pytest.raises(NotImplementedError): - r.memory_help() + await r.memory_help() @skip_if_server_version_lt("4.0.0") - def test_memory_doctor(self, r): + async def test_memory_doctor(self, r): with pytest.raises(NotImplementedError): - r.memory_doctor() + await r.memory_doctor() @skip_if_redis_enterprise() - def test_lastsave(self, r): + async def test_lastsave(self, r): node = r.get_primaries()[0] - assert isinstance(r.lastsave(target_nodes=node), datetime.datetime) + assert isinstance(await r.lastsave(target_nodes=node), datetime.datetime) - def test_cluster_echo(self, r): + async def test_cluster_echo(self, r): node = r.get_primaries()[0] - assert r.echo("foo bar", target_nodes=node) == b"foo bar" + assert await r.echo("foo bar", target_nodes=node) == b"foo bar" @skip_if_server_version_lt("1.0.0") - def test_debug_segfault(self, r): + async def test_debug_segfault(self, r): with pytest.raises(NotImplementedError): - r.debug_segfault() + await r.debug_segfault() - def test_config_resetstat(self, r): + async def test_config_resetstat(self, r): node = r.get_primaries()[0] - r.ping(target_nodes=node) + await r.ping(target_nodes=node) prior_commands_processed = int( - r.info(target_nodes=node)["total_commands_processed"] + (await r.info(target_nodes=node))["total_commands_processed"] ) assert prior_commands_processed >= 1 - r.config_resetstat(target_nodes=node) + await r.config_resetstat(target_nodes=node) reset_commands_processed = int( - r.info(target_nodes=node)["total_commands_processed"] + (await r.info(target_nodes=node))["total_commands_processed"] ) assert reset_commands_processed < prior_commands_processed @skip_if_server_version_lt("6.2.0") - def test_client_trackinginfo(self, r): + async def test_client_trackinginfo(self, r): node = r.get_primaries()[0] - res = r.client_trackinginfo(target_nodes=node) + res = await r.client_trackinginfo(target_nodes=node) assert len(res) > 2 assert "prefixes" in res @skip_if_server_version_lt("2.9.50") - def test_client_pause(self, r): + async def test_client_pause(self, r): node = r.get_primaries()[0] - assert r.client_pause(1, target_nodes=node) - assert r.client_pause(timeout=1, target_nodes=node) + assert await r.client_pause(1, target_nodes=node) + assert await r.client_pause(timeout=1, target_nodes=node) with pytest.raises(RedisError): - r.client_pause(timeout="not an integer", target_nodes=node) + await r.client_pause(timeout="not an integer", target_nodes=node) @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() - def test_client_unpause(self, r): - assert r.client_unpause() + async def test_client_unpause(self, r): + assert await r.client_unpause() @skip_if_server_version_lt("5.0.0") - def test_client_id(self, r): + async def test_client_id(self, r): node = r.get_primaries()[0] - assert r.client_id(target_nodes=node) > 0 + assert await r.client_id(target_nodes=node) > 0 @skip_if_server_version_lt("5.0.0") - def test_client_unblock(self, r): + async def test_client_unblock(self, r): node = r.get_primaries()[0] - myid = r.client_id(target_nodes=node) - assert not r.client_unblock(myid, target_nodes=node) - assert not r.client_unblock(myid, error=True, target_nodes=node) - assert not r.client_unblock(myid, error=False, target_nodes=node) + myid = await r.client_id(target_nodes=node) + assert not await r.client_unblock(myid, target_nodes=node) + assert not await r.client_unblock(myid, error=True, target_nodes=node) + assert not await r.client_unblock(myid, error=False, target_nodes=node) @skip_if_server_version_lt("6.0.0") - def test_client_getredir(self, r): + async def test_client_getredir(self, r): node = r.get_primaries()[0] - assert isinstance(r.client_getredir(target_nodes=node), int) - assert r.client_getredir(target_nodes=node) == -1 + assert isinstance(await r.client_getredir(target_nodes=node), int) + assert await r.client_getredir(target_nodes=node) == -1 @skip_if_server_version_lt("6.2.0") - def test_client_info(self, r): + async def test_client_info(self, r): node = r.get_primaries()[0] - info = r.client_info(target_nodes=node) + info = await r.client_info(target_nodes=node) assert isinstance(info, dict) assert "addr" in info @skip_if_server_version_lt("2.6.9") - def test_client_kill(self, r, r2): + async def test_client_kill(self, r, r2): node = r.get_primaries()[0] - r.client_setname("redis-py-c1", target_nodes="all") - r2.client_setname("redis-py-c2", target_nodes="all") + await r.client_setname("redis-py-c1", target_nodes="all") + await r2.client_setname("redis-py-c2", target_nodes="all") clients = [ client - for client in r.client_list(target_nodes=node) + for client in await r.client_list(target_nodes=node) if client.get("name") in ["redis-py-c1", "redis-py-c2"] ] assert len(clients) == 2 clients_by_name = {client.get("name"): client for client in clients} client_addr = clients_by_name["redis-py-c2"].get("addr") - assert r.client_kill(client_addr, target_nodes=node) is True + assert await r.client_kill(client_addr, target_nodes=node) is True clients = [ client - for client in r.client_list(target_nodes=node) + for client in await r.client_list(target_nodes=node) if client.get("name") in ["redis-py-c1", "redis-py-c2"] ] assert len(clients) == 1 assert clients[0].get("name") == "redis-py-c1" @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not_empty_string(self, r): - r["{foo}a"] = "" - r.bitop("not", "{foo}r", "{foo}a") - assert r.get("{foo}r") is None + async def test_cluster_bitop_not_empty_string(self, r): + await r.set("{foo}a", "") + await r.bitop("not", "{foo}r", "{foo}a") + assert await r.get("{foo}r") is None @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not(self, r): + async def test_cluster_bitop_not(self, r): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r["{foo}a"] = test_str - r.bitop("not", "{foo}r", "{foo}a") - assert int(binascii.hexlify(r["{foo}r"]), 16) == correct + await r.set("{foo}a", test_str) + await r.bitop("not", "{foo}r", "{foo}a") + assert int(binascii.hexlify(await r.get("{foo}r")), 16) == correct @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_not_in_place(self, r): + async def test_cluster_bitop_not_in_place(self, r): test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF - r["{foo}a"] = test_str - r.bitop("not", "{foo}a", "{foo}a") - assert int(binascii.hexlify(r["{foo}a"]), 16) == correct + await r.set("{foo}a", test_str) + await r.bitop("not", "{foo}a", "{foo}a") + assert int(binascii.hexlify(await r.get("{foo}a")), 16) == correct @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_single_string(self, r): + async def test_cluster_bitop_single_string(self, r): test_str = b"\x01\x02\xFF" - r["{foo}a"] = test_str - r.bitop("and", "{foo}res1", "{foo}a") - r.bitop("or", "{foo}res2", "{foo}a") - r.bitop("xor", "{foo}res3", "{foo}a") - assert r["{foo}res1"] == test_str - assert r["{foo}res2"] == test_str - assert r["{foo}res3"] == test_str + await r.set("{foo}a", test_str) + await r.bitop("and", "{foo}res1", "{foo}a") + await r.bitop("or", "{foo}res2", "{foo}a") + await r.bitop("xor", "{foo}res3", "{foo}a") + assert await r.get("{foo}res1") == test_str + assert await r.get("{foo}res2") == test_str + assert await r.get("{foo}res3") == test_str @skip_if_server_version_lt("2.6.0") - def test_cluster_bitop_string_operands(self, r): - r["{foo}a"] = b"\x01\x02\xFF\xFF" - r["{foo}b"] = b"\x01\x02\xFF" - r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") - r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") - r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") - assert int(binascii.hexlify(r["{foo}res1"]), 16) == 0x0102FF00 - assert int(binascii.hexlify(r["{foo}res2"]), 16) == 0x0102FFFF - assert int(binascii.hexlify(r["{foo}res3"]), 16) == 0x000000FF + async def test_cluster_bitop_string_operands(self, r): + await r.set("{foo}a", b"\x01\x02\xFF\xFF") + await r.set("{foo}b", b"\x01\x02\xFF") + await r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") + await r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") + await r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") + assert int(binascii.hexlify(await r.get("{foo}res1")), 16) == 0x0102FF00 + assert int(binascii.hexlify(await r.get("{foo}res2")), 16) == 0x0102FFFF + assert int(binascii.hexlify(await r.get("{foo}res3")), 16) == 0x000000FF @skip_if_server_version_lt("6.2.0") - def test_cluster_copy(self, r): - assert r.copy("{foo}a", "{foo}b") == 0 - r.set("{foo}a", "bar") - assert r.copy("{foo}a", "{foo}b") == 1 - assert r.get("{foo}a") == b"bar" - assert r.get("{foo}b") == b"bar" + async def test_cluster_copy(self, r): + assert await r.copy("{foo}a", "{foo}b") == 0 + await r.set("{foo}a", "bar") + assert await r.copy("{foo}a", "{foo}b") == 1 + assert await r.get("{foo}a") == b"bar" + assert await r.get("{foo}b") == b"bar" @skip_if_server_version_lt("6.2.0") - def test_cluster_copy_and_replace(self, r): - r.set("{foo}a", "foo1") - r.set("{foo}b", "foo2") - assert r.copy("{foo}a", "{foo}b") == 0 - assert r.copy("{foo}a", "{foo}b", replace=True) == 1 + async def test_cluster_copy_and_replace(self, r): + await r.set("{foo}a", "foo1") + await r.set("{foo}b", "foo2") + assert await r.copy("{foo}a", "{foo}b") == 0 + assert await r.copy("{foo}a", "{foo}b", replace=True) == 1 @skip_if_server_version_lt("6.2.0") - def test_cluster_lmove(self, r): - r.rpush("{foo}a", "one", "two", "three", "four") - assert r.lmove("{foo}a", "{foo}b") - assert r.lmove("{foo}a", "{foo}b", "right", "left") + async def test_cluster_lmove(self, r): + await r.rpush("{foo}a", "one", "two", "three", "four") + assert await r.lmove("{foo}a", "{foo}b") + assert await r.lmove("{foo}a", "{foo}b", "right", "left") @skip_if_server_version_lt("6.2.0") - def test_cluster_blmove(self, r): - r.rpush("{foo}a", "one", "two", "three", "four") - assert r.blmove("{foo}a", "{foo}b", 5) - assert r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") + async def test_cluster_blmove(self, r): + await r.rpush("{foo}a", "one", "two", "three", "four") + assert await r.blmove("{foo}a", "{foo}b", 5) + assert await r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") - def test_cluster_msetnx(self, r): + async def test_cluster_msetnx(self, r): d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} - assert r.msetnx(d) + assert await r.msetnx(d) d2 = {"{foo}a": b"x", "{foo}d": b"4"} - assert not r.msetnx(d2) + assert not await r.msetnx(d2) for k, v in d.items(): - assert r[k] == v - assert r.get("{foo}d") is None - - def test_cluster_rename(self, r): - r["{foo}a"] = "1" - assert r.rename("{foo}a", "{foo}b") - assert r.get("{foo}a") is None - assert r["{foo}b"] == b"1" - - def test_cluster_renamenx(self, r): - r["{foo}a"] = "1" - r["{foo}b"] = "2" - assert not r.renamenx("{foo}a", "{foo}b") - assert r["{foo}a"] == b"1" - assert r["{foo}b"] == b"2" + assert await r.get(k) == v + assert await r.get("{foo}d") is None + + async def test_cluster_rename(self, r): + await r.set("{foo}a", "1") + assert await r.rename("{foo}a", "{foo}b") + assert await r.get("{foo}a") is None + assert await r.get("{foo}b") == b"1" + + async def test_cluster_renamenx(self, r): + await r.set("{foo}a", "1") + await r.set("{foo}b", "2") + assert not await r.renamenx("{foo}a", "{foo}b") + assert await r.get("{foo}a") == b"1" + assert await r.get("{foo}b") == b"2" # LIST COMMANDS - def test_cluster_blpop(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None - r.rpush("{foo}c", "1") - assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - - def test_cluster_brpop(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None - r.rpush("{foo}c", "1") - assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - - def test_cluster_brpoplpush(self, r): - r.rpush("{foo}a", "1", "2") - r.rpush("{foo}b", "3", "4") - assert r.brpoplpush("{foo}a", "{foo}b") == b"2" - assert r.brpoplpush("{foo}a", "{foo}b") == b"1" - assert r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None - assert r.lrange("{foo}a", 0, -1) == [] - assert r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] - - def test_cluster_brpoplpush_empty_string(self, r): - r.rpush("{foo}a", "") - assert r.brpoplpush("{foo}a", "{foo}b") == b"" - - def test_cluster_rpoplpush(self, r): - r.rpush("{foo}a", "a1", "a2", "a3") - r.rpush("{foo}b", "b1", "b2", "b3") - assert r.rpoplpush("{foo}a", "{foo}b") == b"a3" - assert r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] - assert r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] - - def test_cluster_sdiff(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} - r.sadd("{foo}b", "2", "3") - assert r.sdiff("{foo}a", "{foo}b") == {b"1"} - - def test_cluster_sdiffstore(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 - assert r.smembers("{foo}c") == {b"1", b"2", b"3"} - r.sadd("{foo}b", "2", "3") - assert r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 - assert r.smembers("{foo}c") == {b"1"} - - def test_cluster_sinter(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sinter("{foo}a", "{foo}b") == set() - r.sadd("{foo}b", "2", "3") - assert r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} - - def test_cluster_sinterstore(self, r): - r.sadd("{foo}a", "1", "2", "3") - assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 - assert r.smembers("{foo}c") == set() - r.sadd("{foo}b", "2", "3") - assert r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 - assert r.smembers("{foo}c") == {b"2", b"3"} - - def test_cluster_smove(self, r): - r.sadd("{foo}a", "a1", "a2") - r.sadd("{foo}b", "b1", "b2") - assert r.smove("{foo}a", "{foo}b", "a1") - assert r.smembers("{foo}a") == {b"a2"} - assert r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} - - def test_cluster_sunion(self, r): - r.sadd("{foo}a", "1", "2") - r.sadd("{foo}b", "2", "3") - assert r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} - - def test_cluster_sunionstore(self, r): - r.sadd("{foo}a", "1", "2") - r.sadd("{foo}b", "2", "3") - assert r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 - assert r.smembers("{foo}c") == {b"1", b"2", b"3"} + async def test_cluster_blpop(self, r): + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) is None + await r.rpush("{foo}c", "1") + assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + async def test_cluster_brpop(self, r): + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) is None + await r.rpush("{foo}c", "1") + assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + + async def test_cluster_brpoplpush(self, r): + await r.rpush("{foo}a", "1", "2") + await r.rpush("{foo}b", "3", "4") + assert await r.brpoplpush("{foo}a", "{foo}b") == b"2" + assert await r.brpoplpush("{foo}a", "{foo}b") == b"1" + assert await r.brpoplpush("{foo}a", "{foo}b", timeout=1) is None + assert await r.lrange("{foo}a", 0, -1) == [] + assert await r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] + + async def test_cluster_brpoplpush_empty_string(self, r): + await r.rpush("{foo}a", "") + assert await r.brpoplpush("{foo}a", "{foo}b") == b"" + + async def test_cluster_rpoplpush(self, r): + await r.rpush("{foo}a", "a1", "a2", "a3") + await r.rpush("{foo}b", "b1", "b2", "b3") + assert await r.rpoplpush("{foo}a", "{foo}b") == b"a3" + assert await r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] + assert await r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] + + async def test_cluster_sdiff(self, r): + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + await r.sadd("{foo}b", "2", "3") + assert await r.sdiff("{foo}a", "{foo}b") == {b"1"} + + async def test_cluster_sdiffstore(self, r): + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} + await r.sadd("{foo}b", "2", "3") + assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 + assert await r.smembers("{foo}c") == {b"1"} + + async def test_cluster_sinter(self, r): + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sinter("{foo}a", "{foo}b") == set() + await r.sadd("{foo}b", "2", "3") + assert await r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} + + async def test_cluster_sinterstore(self, r): + await r.sadd("{foo}a", "1", "2", "3") + assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 + assert await r.smembers("{foo}c") == set() + await r.sadd("{foo}b", "2", "3") + assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 + assert await r.smembers("{foo}c") == {b"2", b"3"} + + async def test_cluster_smove(self, r): + await r.sadd("{foo}a", "a1", "a2") + await r.sadd("{foo}b", "b1", "b2") + assert await r.smove("{foo}a", "{foo}b", "a1") + assert await r.smembers("{foo}a") == {b"a2"} + assert await r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} + + async def test_cluster_sunion(self, r): + await r.sadd("{foo}a", "1", "2") + await r.sadd("{foo}b", "2", "3") + assert await r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} + + async def test_cluster_sunionstore(self, r): + await r.sadd("{foo}a", "1", "2") + await r.sadd("{foo}b", "2", "3") + assert await r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 + assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} @skip_if_server_version_lt("6.2.0") - def test_cluster_zdiff(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 1, "a2": 2}) - assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + async def test_cluster_zdiff(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] + assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] @skip_if_server_version_lt("6.2.0") - def test_cluster_zdiffstore(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 1, "a2": 2}) - assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) - assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + async def test_cluster_zdiffstore(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) + assert await r.zrange("{foo}out", 0, -1) == [b"a3"] + assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] @skip_if_server_version_lt("6.2.0") - def test_cluster_zinter(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] + async def test_cluster_zinter(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"]) == [b"a3", b"a1"] # invalid aggregation with pytest.raises(DataError): - r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) + await r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True + ) # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ (b"a3", 8), (b"a1", 9), ] # aggregate with MAX - assert r.zinter( + assert await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True ) == [(b"a3", 5), (b"a1", 6)] # aggregate with MIN - assert r.zinter( + assert await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True ) == [(b"a1", 1), (b"a3", 1)] # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), + assert await r.zinter( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [(b"a3", 20), (b"a1", 23)] + + async def test_cluster_zinterstore_sum(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 8), + (b"a1", 9), ] - def test_cluster_zinterstore_sum(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] - - def test_cluster_zinterstore_max(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + async def test_cluster_zinterstore_max(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert ( - r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + await r.zinterstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX" + ) == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 5), + (b"a1", 6), + ] - def test_cluster_zinterstore_min(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + async def test_cluster_zinterstore_min(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert ( - r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + await r.zinterstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN" + ) == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a3", 3), + ] - def test_cluster_zinterstore_with_weight(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + async def test_cluster_zinterstore_with_weight(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] @skip_if_server_version_lt("4.9.0") - def test_cluster_bzpopmax(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2}) - r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None - r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + async def test_cluster_bzpopmax(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2}) + await r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b2", + 20, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b1", + 10, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a2", + 2, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a1", + 1, + ) + assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None + await r.zadd("{foo}c", {"c1": 100}) + assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) @skip_if_server_version_lt("4.9.0") - def test_cluster_bzpopmin(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2}) - r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None - r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + async def test_cluster_bzpopmin(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2}) + await r.zadd("{foo}b", {"b1": 10, "b2": 20}) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b1", + 10, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}b", + b"b2", + 20, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a1", + 1, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( + b"{foo}a", + b"a2", + 2, + ) + assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None + await r.zadd("{foo}c", {"c1": 100}) + assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) @skip_if_server_version_lt("6.2.0") - def test_cluster_zrangestore(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zrangestore("{foo}b", "{foo}a", 0, 1) - assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] - assert r.zrangestore("{foo}b", "{foo}a", 1, 2) - assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + async def test_cluster_zrangestore(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrangestore("{foo}b", "{foo}a", 0, 1) + assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) + assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] + assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a3", 3), + ] # reversed order - assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) - assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] + assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) + assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] # by score - assert r.zrangestore( + assert await r.zrangestore( "{foo}b", "{foo}a", 2, 1, byscore=True, offset=0, num=1, desc=True ) - assert r.zrange("{foo}b", 0, -1) == [b"a2"] + assert await r.zrange("{foo}b", 0, -1) == [b"a2"] # by lex - assert r.zrangestore( + assert await r.zrangestore( "{foo}b", "{foo}a", "[a2", "(a3", bylex=True, offset=0, num=1 ) - assert r.zrange("{foo}b", 0, -1) == [b"a2"] + assert await r.zrange("{foo}b", 0, -1) == [b"a2"] @skip_if_server_version_lt("6.2.0") - def test_cluster_zunion(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + async def test_cluster_zunion(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [ + b"a2", + b"a4", + b"a3", + b"a1", + ] + assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ (b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9), ] # max - assert r.zunion( + assert await r.zunion( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] # min - assert r.zunion( + assert await r.zunion( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] - - def test_cluster_zunionstore_sum(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + assert await r.zunion( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + + async def test_cluster_zunionstore_sum(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ (b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9), ] - def test_cluster_zunionstore_max(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + async def test_cluster_zunionstore_max(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert ( - r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") + await r.zunionstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX" + ) == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ (b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6), ] - def test_cluster_zunionstore_min(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + async def test_cluster_zunionstore_min(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert ( - r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") + await r.zunionstore( + "{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN" + ) == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ (b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4), ] - def test_cluster_zunionstore_with_weight(self, r): - r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) - r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) - r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) - assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ + async def test_cluster_zunionstore_with_weight(self, r): + await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) + assert ( + await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 + ) + assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ (b"a2", 5), (b"a4", 12), (b"a3", 20), @@ -1709,63 +1676,63 @@ def test_cluster_zunionstore_with_weight(self, r): ] @skip_if_server_version_lt("2.8.9") - def test_cluster_pfcount(self, r): + async def test_cluster_pfcount(self, r): members = {b"1", b"2", b"3"} - r.pfadd("{foo}a", *members) - assert r.pfcount("{foo}a") == len(members) + await r.pfadd("{foo}a", *members) + assert await r.pfcount("{foo}a") == len(members) members_b = {b"2", b"3", b"4"} - r.pfadd("{foo}b", *members_b) - assert r.pfcount("{foo}b") == len(members_b) - assert r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) + await r.pfadd("{foo}b", *members_b) + assert await r.pfcount("{foo}b") == len(members_b) + assert await r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) @skip_if_server_version_lt("2.8.9") - def test_cluster_pfmerge(self, r): + async def test_cluster_pfmerge(self, r): mema = {b"1", b"2", b"3"} memb = {b"2", b"3", b"4"} memc = {b"5", b"6", b"7"} - r.pfadd("{foo}a", *mema) - r.pfadd("{foo}b", *memb) - r.pfadd("{foo}c", *memc) - r.pfmerge("{foo}d", "{foo}c", "{foo}a") - assert r.pfcount("{foo}d") == 6 - r.pfmerge("{foo}d", "{foo}b") - assert r.pfcount("{foo}d") == 7 - - def test_cluster_sort_store(self, r): - r.rpush("{foo}a", "2", "3", "1") - assert r.sort("{foo}a", store="{foo}sorted_values") == 3 - assert r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] + await r.pfadd("{foo}a", *mema) + await r.pfadd("{foo}b", *memb) + await r.pfadd("{foo}c", *memc) + await r.pfmerge("{foo}d", "{foo}c", "{foo}a") + assert await r.pfcount("{foo}d") == 6 + await r.pfmerge("{foo}d", "{foo}b") + assert await r.pfcount("{foo}d") == 7 + + async def test_cluster_sort_store(self, r): + await r.rpush("{foo}a", "2", "3", "1") + assert await r.sort("{foo}a", store="{foo}sorted_values") == 3 + assert await r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] # GEO COMMANDS @skip_if_server_version_lt("6.2.0") - def test_cluster_geosearchstore(self, r): + async def test_cluster_geosearchstore(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, "place2", ) - r.geoadd("{foo}barcelona", values) - r.geosearchstore( + await r.geoadd("{foo}barcelona", values) + await r.geosearchstore( "{foo}places_barcelona", "{foo}barcelona", longitude=2.191, latitude=41.433, radius=1000, ) - assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + assert await r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] @skip_unless_arch_bits(64) @skip_if_server_version_lt("6.2.0") - def test_geosearchstore_dist(self, r): + async def test_geosearchstore_dist(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, "place2", ) - r.geoadd("{foo}barcelona", values) - r.geosearchstore( + await r.geoadd("{foo}barcelona", values) + await r.geosearchstore( "{foo}places_barcelona", "{foo}barcelona", longitude=2.191, @@ -1774,103 +1741,105 @@ def test_geosearchstore_dist(self, r): storedist=True, ) # instead of save the geo score, the distance is saved. - assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 @skip_if_server_version_lt("3.2.0") - def test_cluster_georadius_store(self, r): + async def test_cluster_georadius_store(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, "place2", ) - r.geoadd("{foo}barcelona", values) - r.georadius( + await r.geoadd("{foo}barcelona", values) + await r.georadius( "{foo}barcelona", 2.191, 41.433, 1000, store="{foo}places_barcelona" ) - assert r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] + assert await r.zrange("{foo}places_barcelona", 0, -1) == [b"place1"] @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") - def test_cluster_georadius_store_dist(self, r): + async def test_cluster_georadius_store_dist(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, "place2", ) - r.geoadd("{foo}barcelona", values) - r.georadius( + await r.geoadd("{foo}barcelona", values) + await r.georadius( "{foo}barcelona", 2.191, 41.433, 1000, store_dist="{foo}places_barcelona" ) # instead of save the geo score, the distance is saved. - assert r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 + assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 - def test_cluster_dbsize(self, r): + async def test_cluster_dbsize(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} - assert r.mset_nonatomic(d) - assert r.dbsize(target_nodes="primaries") == len(d) + assert await r.mset_nonatomic(d) + assert await r.dbsize(target_nodes="primaries") == len(d) - def test_cluster_keys(self, r): - assert r.keys() == [] + async def test_cluster_keys(self, r): + assert await r.keys() == [] keys_with_underscores = {b"test_a", b"test_b"} keys = keys_with_underscores.union({b"testc"}) for key in keys: - r[key] = 1 + await r.set(key, 1) assert ( - set(r.keys(pattern="test_*", target_nodes="primaries")) + set(await r.keys(pattern="test_*", target_nodes="primaries")) == keys_with_underscores ) - assert set(r.keys(pattern="test*", target_nodes="primaries")) == keys + assert set(await r.keys(pattern="test*", target_nodes="primaries")) == keys # SCAN COMMANDS @skip_if_server_version_lt("2.8.0") - def test_cluster_scan(self, r): - r.set("a", 1) - r.set("b", 2) - r.set("c", 3) + async def test_cluster_scan(self, r): + await r.set("a", 1) + await r.set("b", 2) + await r.set("c", 3) for target_nodes, nodes in zip( ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] ): - cursors, keys = r.scan(target_nodes=target_nodes) + cursors, keys = await r.scan(target_nodes=target_nodes) assert sorted(keys) == [b"a", b"b", b"c"] assert sorted(cursors.keys()) == sorted(node.name for node in nodes) assert all(cursor == 0 for cursor in cursors.values()) - cursors, keys = r.scan(match="a*", target_nodes=target_nodes) + cursors, keys = await r.scan(match="a*", target_nodes=target_nodes) assert sorted(keys) == [b"a"] assert sorted(cursors.keys()) == sorted(node.name for node in nodes) assert all(cursor == 0 for cursor in cursors.values()) @skip_if_server_version_lt("6.0.0") - def test_cluster_scan_type(self, r): - r.sadd("a-set", 1) - r.sadd("b-set", 1) - r.sadd("c-set", 1) - r.hset("a-hash", "foo", 2) - r.lpush("a-list", "aux", 3) + async def test_cluster_scan_type(self, r): + await r.sadd("a-set", 1) + await r.sadd("b-set", 1) + await r.sadd("c-set", 1) + await r.hset("a-hash", "foo", 2) + await r.lpush("a-list", "aux", 3) for target_nodes, nodes in zip( ["primaries", "replicas"], [r.get_primaries(), r.get_replicas()] ): - cursors, keys = r.scan(_type="SET", target_nodes=target_nodes) + cursors, keys = await r.scan(_type="SET", target_nodes=target_nodes) assert sorted(keys) == [b"a-set", b"b-set", b"c-set"] assert sorted(cursors.keys()) == sorted(node.name for node in nodes) assert all(cursor == 0 for cursor in cursors.values()) - cursors, keys = r.scan(_type="SET", match="a*", target_nodes=target_nodes) + cursors, keys = await r.scan( + _type="SET", match="a*", target_nodes=target_nodes + ) assert sorted(keys) == [b"a-set"] assert sorted(cursors.keys()) == sorted(node.name for node in nodes) assert all(cursor == 0 for cursor in cursors.values()) @skip_if_server_version_lt("2.8.0") - def test_cluster_scan_iter(self, r): + async def test_cluster_scan_iter(self, r): keys_all = [] keys_1 = [] for i in range(100): s = str(i) - r.set(s, 1) + await r.set(s, 1) keys_all.append(s.encode("utf-8")) if s.startswith("1"): keys_1.append(s.encode("utf-8")) @@ -1878,31 +1847,33 @@ def test_cluster_scan_iter(self, r): keys_1.sort() for target_nodes in ["primaries", "replicas"]: - keys = r.scan_iter(target_nodes=target_nodes) + keys = [key async for key in r.scan_iter(target_nodes=target_nodes)] assert sorted(keys) == keys_all - keys = r.scan_iter(match="1*", target_nodes=target_nodes) + keys = [ + key async for key in r.scan_iter(match="1*", target_nodes=target_nodes) + ] assert sorted(keys) == keys_1 - def test_cluster_randomkey(self, r): + async def test_cluster_randomkey(self, r): node = r.get_node_from_key("{foo}") - assert r.randomkey(target_nodes=node) is None + assert await r.randomkey(target_nodes=node) is None for key in ("{foo}a", "{foo}b", "{foo}c"): - r[key] = 1 - assert r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") + await r.set(key, 1) + assert await r.randomkey(target_nodes=node) in (b"{foo}a", b"{foo}b", b"{foo}c") @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() - def test_acl_log(self, r, request): + async def test_acl_log(self, r, request): key = "{cache}:" node = r.get_node_from_key(key) username = "redis-py-user" - def teardown(): - r.acl_deluser(username, target_nodes="primaries") + async def teardown(): + await r.acl_deluser(username, target_nodes="primaries") request.addfinalizer(teardown) - r.acl_setuser( + await r.acl_setuser( username, enabled=True, reset=True, @@ -1911,30 +1882,30 @@ def teardown(): nopass=True, target_nodes="primaries", ) - r.acl_log_reset(target_nodes=node) + await r.acl_log_reset(target_nodes=node) user_client = _get_client( RedisCluster, request, flushdb=False, username=username ) # Valid operation and key - assert user_client.set("{cache}:0", 1) - assert user_client.get("{cache}:0") == b"1" + assert await user_client.set("{cache}:0", 1) + assert await user_client.get("{cache}:0") == b"1" # Invalid key with pytest.raises(NoPermissionError): - user_client.get("{cache}violated_cache:0") + await user_client.get("{cache}violated_cache:0") # Invalid operation with pytest.raises(NoPermissionError): - user_client.hset("{cache}:0", "hkey", "hval") + await user_client.hset("{cache}:0", "hkey", "hval") - assert isinstance(r.acl_log(target_nodes=node), list) - assert len(r.acl_log(target_nodes=node)) == 2 - assert len(r.acl_log(count=1, target_nodes=node)) == 1 - assert isinstance(r.acl_log(target_nodes=node)[0], dict) - assert "client-info" in r.acl_log(count=1, target_nodes=node)[0] - assert r.acl_log_reset(target_nodes=node) + assert isinstance(await r.acl_log(target_nodes=node), list) + assert len(await r.acl_log(target_nodes=node)) == 2 + assert len(await r.acl_log(count=1, target_nodes=node)) == 1 + assert isinstance((await r.acl_log(target_nodes=node))[0], dict) + assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] + assert await r.acl_log_reset(target_nodes=node) @pytest.mark.onlycluster @@ -1943,7 +1914,7 @@ class TestNodesManager: Tests for the NodesManager class """ - def test_load_balancer(self, r): + async def test_load_balancer(self, r): n_manager = r.nodes_manager lb = n_manager.read_load_balancer slot_1 = 1257 @@ -1975,7 +1946,7 @@ def test_load_balancer(self, r): assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 - def test_init_slots_cache_not_all_slots_covered(self): + async def test_init_slots_cache_not_all_slots_covered(self): """ Test that if not all slots are covered it should raise an exception """ @@ -1986,7 +1957,7 @@ def test_init_slots_cache_not_all_slots_covered(self): [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] with pytest.raises(RedisClusterException) as ex: - get_mocked_redis_client( + await get_mocked_redis_client( host=default_host, port=default_port, cluster_slots=cluster_slots, @@ -1996,7 +1967,7 @@ def test_init_slots_cache_not_all_slots_covered(self): "All slots are not covered after query all startup_nodes." ) - def test_init_slots_cache_not_require_full_coverage_success(self): + async def test_init_slots_cache_not_require_full_coverage_success(self): """ When require_full_coverage is set to False and not all slots are covered the cluster client initialization should succeed @@ -2008,7 +1979,7 @@ def test_init_slots_cache_not_require_full_coverage_success(self): [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] - rc = get_mocked_redis_client( + rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_slots=cluster_slots, @@ -2017,7 +1988,7 @@ def test_init_slots_cache_not_require_full_coverage_success(self): assert 5460 not in rc.nodes_manager.slots_cache - def test_init_slots_cache(self): + async def test_init_slots_cache(self): """ Test that slots cache can in initialized and all slots are covered """ @@ -2027,7 +1998,7 @@ def test_init_slots_cache(self): [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.2", 7005]], ] - rc = get_mocked_redis_client( + rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_slots=good_slots_resp ) n_manager = rc.nodes_manager @@ -2046,18 +2017,18 @@ def test_init_slots_cache(self): assert len(n_manager.nodes_cache) == 6 - def test_init_slots_cache_cluster_mode_disabled(self): + async def test_init_slots_cache_cluster_mode_disabled(self): """ Test that creating a RedisCluster failes if one of the startup nodes has cluster mode disabled """ with pytest.raises(RedisClusterException) as e: - get_mocked_redis_client( + await get_mocked_redis_client( host=default_host, port=default_port, cluster_enabled=False ) assert "Cluster mode is not enabled on this node" in str(e.value) - def test_empty_startup_nodes(self): + async def test_empty_startup_nodes(self): """ It should not be possible to create a node manager with no nodes specified @@ -2065,7 +2036,7 @@ def test_empty_startup_nodes(self): with pytest.raises(RedisClusterException): NodesManager([]) - def test_wrong_startup_nodes_type(self): + async def test_wrong_startup_nodes_type(self): """ If something other then a list type itteratable is provided it should fail @@ -2073,7 +2044,7 @@ def test_wrong_startup_nodes_type(self): with pytest.raises(RedisClusterException): NodesManager({}) - def test_init_slots_cache_slots_collision(self, request): + async def test_init_slots_cache_slots_collision(self, request): """ Test that if 2 nodes do not agree on the same slots setup it should raise an error. In this test both nodes will say that the first @@ -2127,14 +2098,16 @@ def execute_command(*args, **kwargs): "startup_nodes could not agree on a valid slots cache" ), str(ex.value) - def test_cluster_one_instance(self): + async def test_cluster_one_instance(self): """ If the cluster exists of only 1 node then there is some hacks that must be validated they work. """ node = ClusterNode(default_host, default_port) cluster_slots = [[0, 16383, ["", default_port]]] - rc = get_mocked_redis_client(startup_nodes=[node], cluster_slots=cluster_slots) + rc = await get_mocked_redis_client( + startup_nodes=[node], cluster_slots=cluster_slots + ) n = rc.nodes_manager assert len(n.nodes_cache) == 1 @@ -2146,7 +2119,7 @@ def test_cluster_one_instance(self): for i in range(0, REDIS_CLUSTER_HASH_SLOTS): assert n.slots_cache[i] == [n_node] - def test_init_with_down_node(self): + async def test_init_with_down_node(self): """ If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. @@ -2207,523 +2180,3 @@ def cmd_init_mock(self, r): rc = RedisCluster(startup_nodes=[node_1, node_2]) assert rc.get_node(host=default_host, port=7001) is not None assert rc.get_node(host=default_host, port=7002) is not None - - -@pytest.mark.onlycluster -class TestClusterPubSubObject: - """ - Tests for the ClusterPubSub class - """ - - def test_init_pubsub_with_host_and_port(self, r): - """ - Test creation of pubsub instance with passed host and port - """ - node = r.get_default_node() - p = r.pubsub(host=node.host, port=node.port) - assert p.get_pubsub_node() == node - - def test_init_pubsub_with_node(self, r): - """ - Test creation of pubsub instance with passed node - """ - node = r.get_default_node() - p = r.pubsub(node=node) - assert p.get_pubsub_node() == node - - def test_init_pubusub_without_specifying_node(self, r): - """ - Test creation of pubsub instance without specifying a node. The node - should be determined based on the keyslot of the first command - execution. - """ - channel_name = "foo" - node = r.get_node_from_key(channel_name) - p = r.pubsub() - assert p.get_pubsub_node() is None - p.subscribe(channel_name) - assert p.get_pubsub_node() == node - - def test_init_pubsub_with_a_non_existent_node(self, r): - """ - Test creation of pubsub instance with node that doesn't exists in the - cluster. RedisClusterException should be raised. - """ - node = ClusterNode("1.1.1.1", 1111) - with pytest.raises(RedisClusterException): - r.pubsub(node) - - def test_init_pubsub_with_a_non_existent_host_port(self, r): - """ - Test creation of pubsub instance with host and port that don't belong - to a node in the cluster. - RedisClusterException should be raised. - """ - with pytest.raises(RedisClusterException): - r.pubsub(host="1.1.1.1", port=1111) - - def test_init_pubsub_host_or_port(self, r): - """ - Test creation of pubsub instance with host but without port, and vice - versa. DataError should be raised. - """ - with pytest.raises(DataError): - r.pubsub(host="localhost") - - with pytest.raises(DataError): - r.pubsub(port=16379) - - def test_get_redis_connection(self, r): - """ - Test that get_redis_connection() returns the redis connection of the - set pubsub node - """ - node = r.get_default_node() - p = r.pubsub(node=node) - assert p.get_redis_connection() == node.redis_connection - - -@pytest.mark.onlycluster -class TestClusterPipeline: - """ - Tests for the ClusterPipeline class - """ - - def test_blocked_methods(self, r): - """ - Currently some method calls on a Cluster pipeline - is blocked when using in cluster mode. - They maybe implemented in the future. - """ - pipe = r.pipeline() - with pytest.raises(RedisClusterException): - pipe.multi() - - with pytest.raises(RedisClusterException): - pipe.immediate_execute_command() - - with pytest.raises(RedisClusterException): - pipe._execute_transaction(None, None, None) - - with pytest.raises(RedisClusterException): - pipe.load_scripts() - - with pytest.raises(RedisClusterException): - pipe.watch() - - with pytest.raises(RedisClusterException): - pipe.unwatch() - - with pytest.raises(RedisClusterException): - pipe.script_load_for_pipeline(None) - - with pytest.raises(RedisClusterException): - pipe.eval() - - def test_blocked_arguments(self, r): - """ - Currently some arguments is blocked when using in cluster mode. - They maybe implemented in the future. - """ - with pytest.raises(RedisClusterException) as ex: - r.pipeline(transaction=True) - - assert ( - str(ex.value).startswith("transaction is deprecated in cluster mode") - is True - ) - - with pytest.raises(RedisClusterException) as ex: - r.pipeline(shard_hint=True) - - assert ( - str(ex.value).startswith("shard_hint is deprecated in cluster mode") is True - ) - - def test_redis_cluster_pipeline(self, r): - """ - Test that we can use a pipeline with the RedisCluster class - """ - with r.pipeline() as pipe: - pipe.set("foo", "bar") - pipe.get("foo") - assert pipe.execute() == [True, b"bar"] - - def test_mget_disabled(self, r): - """ - Test that mget is disabled for ClusterPipeline - """ - with r.pipeline() as pipe: - with pytest.raises(RedisClusterException): - pipe.mget(["a"]) - - def test_mset_disabled(self, r): - """ - Test that mset is disabled for ClusterPipeline - """ - with r.pipeline() as pipe: - with pytest.raises(RedisClusterException): - pipe.mset({"a": 1, "b": 2}) - - def test_rename_disabled(self, r): - """ - Test that rename is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.rename("a", "b") - - def test_renamenx_disabled(self, r): - """ - Test that renamenx is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.renamenx("a", "b") - - def test_delete_single(self, r): - """ - Test a single delete operation - """ - r["a"] = 1 - with r.pipeline(transaction=False) as pipe: - pipe.delete("a") - assert pipe.execute() == [1] - - def test_multi_delete_unsupported(self, r): - """ - Test that multi delete operation is unsupported - """ - with r.pipeline(transaction=False) as pipe: - r["a"] = 1 - r["b"] = 2 - with pytest.raises(RedisClusterException): - pipe.delete("a", "b") - - def test_brpoplpush_disabled(self, r): - """ - Test that brpoplpush is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.brpoplpush() - - def test_rpoplpush_disabled(self, r): - """ - Test that rpoplpush is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.rpoplpush() - - def test_sort_disabled(self, r): - """ - Test that sort is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sort() - - def test_sdiff_disabled(self, r): - """ - Test that sdiff is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sdiff() - - def test_sdiffstore_disabled(self, r): - """ - Test that sdiffstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sdiffstore() - - def test_sinter_disabled(self, r): - """ - Test that sinter is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sinter() - - def test_sinterstore_disabled(self, r): - """ - Test that sinterstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sinterstore() - - def test_smove_disabled(self, r): - """ - Test that move is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.smove() - - def test_sunion_disabled(self, r): - """ - Test that sunion is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sunion() - - def test_sunionstore_disabled(self, r): - """ - Test that sunionstore is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.sunionstore() - - def test_spfmerge_disabled(self, r): - """ - Test that spfmerge is disabled for ClusterPipeline - """ - with r.pipeline(transaction=False) as pipe: - with pytest.raises(RedisClusterException): - pipe.pfmerge() - - def test_multi_key_operation_with_a_single_slot(self, r): - """ - Test multi key operation with a single slot - """ - pipe = r.pipeline(transaction=False) - pipe.set("a{foo}", 1) - pipe.set("b{foo}", 2) - pipe.set("c{foo}", 3) - pipe.get("a{foo}") - pipe.get("b{foo}") - pipe.get("c{foo}") - - res = pipe.execute() - assert res == [True, True, True, b"1", b"2", b"3"] - - def test_multi_key_operation_with_multi_slots(self, r): - """ - Test multi key operation with more than one slot - """ - pipe = r.pipeline(transaction=False) - pipe.set("a{foo}", 1) - pipe.set("b{foo}", 2) - pipe.set("c{foo}", 3) - pipe.set("bar", 4) - pipe.set("bazz", 5) - pipe.get("a{foo}") - pipe.get("b{foo}") - pipe.get("c{foo}") - pipe.get("bar") - pipe.get("bazz") - res = pipe.execute() - assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"] - - def test_connection_error_not_raised(self, r): - """ - Test that the pipeline doesn't raise an error on connection error when - raise_on_error=False - """ - key = "foo" - node = r.get_node_from_key(key, False) - - def raise_connection_error(): - e = ConnectionError("error") - return e - - with r.pipeline() as pipe: - mock_node_resp_func(node, raise_connection_error) - res = pipe.get(key).get(key).execute(raise_on_error=False) - assert node.redis_connection.connection.read_response.called - assert isinstance(res[0], ConnectionError) - - def test_connection_error_raised(self, r): - """ - Test that the pipeline raises an error on connection error when - raise_on_error=True - """ - key = "foo" - node = r.get_node_from_key(key, False) - - def raise_connection_error(): - e = ConnectionError("error") - return e - - with r.pipeline() as pipe: - mock_node_resp_func(node, raise_connection_error) - with pytest.raises(ConnectionError): - pipe.get(key).get(key).execute(raise_on_error=True) - - def test_asking_error(self, r): - """ - Test redirection on ASK error - """ - key = "foo" - first_node = r.get_node_from_key(key, False) - ask_node = None - for node in r.get_nodes(): - if node != first_node: - ask_node = node - break - if ask_node is None: - warnings.warn("skipping this test since the cluster has only one " "node") - return - ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}" - - def raise_ask_error(): - raise AskError(ask_msg) - - with r.pipeline() as pipe: - mock_node_resp_func(first_node, raise_ask_error) - mock_node_resp(ask_node, "MOCK_OK") - res = pipe.get(key).execute() - assert first_node.redis_connection.connection.read_response.called - assert ask_node.redis_connection.connection.read_response.called - assert res == ["MOCK_OK"] - - def test_empty_stack(self, r): - """ - If pipeline is executed with no commands it should - return a empty list. - """ - p = r.pipeline() - result = p.execute() - assert result == [] - - -@pytest.mark.onlycluster -class TestReadOnlyPipeline: - """ - Tests for ClusterPipeline class in readonly mode - """ - - def test_pipeline_readonly(self, r): - """ - On readonly mode, we supports get related stuff only. - """ - r.readonly(target_nodes="all") - r.set("foo71", "a1") # we assume this key is set on 127.0.0.1:7001 - r.zadd("foo88", {"z1": 1}) # we assume this key is set on 127.0.0.1:7002 - r.zadd("foo88", {"z2": 4}) - - with r.pipeline() as readonly_pipe: - readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] - - def test_moved_redirection_on_slave_with_default(self, r): - """ - On Pipeline, we redirected once and finally get from master with - readonly client when data is completely moved. - """ - key = "bar" - r.set(key, "foo") - # set read_from_replicas to True - r.read_from_replicas = True - primary = r.get_node_from_key(key, False) - replica = r.get_node_from_key(key, True) - with r.pipeline() as readwrite_pipe: - mock_node_resp(primary, "MOCK_FOO") - if replica is not None: - moved_error = f"{r.keyslot(key)} {primary.host}:{primary.port}" - - def raise_moved_error(): - raise MovedError(moved_error) - - mock_node_resp_func(replica, raise_moved_error) - assert readwrite_pipe.reinitialize_counter == 0 - readwrite_pipe.get(key).get(key) - assert readwrite_pipe.execute() == ["MOCK_FOO", "MOCK_FOO"] - if replica is not None: - # the slot has a replica as well, so MovedError should have - # occurred. If MovedError occurs, we should see the - # reinitialize_counter increase. - assert readwrite_pipe.reinitialize_counter == 1 - conn = replica.redis_connection.connection - assert conn.read_response.called is True - - def test_readonly_pipeline_from_readonly_client(self, request): - """ - Test that the pipeline is initialized with readonly mode if the client - has it enabled - """ - # Create a cluster with reading from replications - ro = _get_client(RedisCluster, request, read_from_replicas=True) - key = "bar" - ro.set(key, "foo") - import time - - time.sleep(0.2) - with ro.pipeline() as readonly_pipe: - mock_all_nodes_resp(ro, "MOCK_OK") - assert readonly_pipe.read_from_replicas is True - assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] - if len(slot_nodes) > 1: - executed_on_replica = False - for node in slot_nodes: - if node.server_type == REPLICA: - conn = node.redis_connection.connection - executed_on_replica = conn.read_response.called - if executed_on_replica: - break - assert executed_on_replica is True - - -@pytest.mark.onlycluster -class TestClusterMonitor: - def test_wait_command_not_found(self, r): - "Make sure the wait_for_command func works when command is not found" - key = "foo" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - response = wait_for_command(r, m, "nothing", key=key) - assert response is None - - def test_response_values(self, r): - db = 0 - key = "foo" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - r.ping(target_nodes=node) - response = wait_for_command(r, m, "PING", key=key) - assert isinstance(response["time"], float) - assert response["db"] == db - assert response["client_type"] in ("tcp", "unix") - assert isinstance(response["client_address"], str) - assert isinstance(response["client_port"], str) - assert response["command"] == "PING" - - def test_command_with_quoted_key(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(node) as m: - r.get('{foo}"bar') - response = wait_for_command(r, m, 'GET {foo}"bar', key=key) - assert response["command"] == 'GET {foo}"bar' - - def test_command_with_binary_data(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - byte_string = b"{foo}bar\x92" - r.get(byte_string) - response = wait_for_command(r, m, "GET {foo}bar\\x92", key=key) - assert response["command"] == "GET {foo}bar\\x92" - - def test_command_with_escaped_data(self, r): - key = "{foo}1" - node = r.get_node_from_key(key) - with r.monitor(target_node=node) as m: - byte_string = b"{foo}bar\\x92" - r.get(byte_string) - response = wait_for_command(r, m, "GET {foo}bar\\\\x92", key=key) - assert response["command"] == "GET {foo}bar\\\\x92" - - def test_flush(self, r): - r.set("x", "1") - r.set("z", "1") - r.flushall() - assert r.get("x") is None - assert r.get("y") is None From 96d992ceed819131ecd9bb33e744bd0111a2a0b3 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Mon, 11 Apr 2022 01:38:25 +0530 Subject: [PATCH 03/23] Add Async RedisCluster --- redis/asyncio/__init__.py | 4 + redis/asyncio/client.py | 2 + redis/asyncio/cluster.py | 1383 ++------------------ redis/asyncio/parser.py | 50 +- redis/cluster.py | 9 +- redis/commands/__init__.py | 3 +- redis/commands/cluster.py | 83 +- redis/commands/parser.py | 1 - tests/test_asyncio/conftest.py | 66 +- tests/test_asyncio/test_cluster.py | 212 +-- tests/test_asyncio/test_commands.py | 14 +- tests/test_asyncio/test_connection.py | 2 - tests/test_asyncio/test_connection_pool.py | 7 +- tests/test_asyncio/test_retry.py | 2 - tox.ini | 2 +- 15 files changed, 366 insertions(+), 1474 deletions(-) diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index c655c7da4b..598791ac15 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -1,4 +1,5 @@ from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.cluster import RedisCluster from redis.asyncio.connection import ( BlockingConnectionPool, Connection, @@ -6,6 +7,7 @@ SSLConnection, UnixDomainSocketConnection, ) +from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -35,6 +37,7 @@ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", + "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", @@ -44,6 +47,7 @@ "PubSubError", "ReadOnlyError", "Redis", + "RedisCluster", "RedisError", "ResponseError", "Sentinel", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 1fd46e5b16..3670dca6ed 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -172,6 +172,7 @@ def __init__( username: Optional[str] = None, retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, + redis_connect_func=None, ): """ Initialize a new Redis client. @@ -200,6 +201,7 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "redis_connect_func": redis_connect_func, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 92469e4f37..7651cbebff 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,15 +1,24 @@ +import asyncio import copy import logging import random import socket -import sys -import threading -import time -from collections import OrderedDict - -from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.asyncio.parser import CommandsParser +from redis.client import CaseInsensitiveDict +from redis.cluster import ( + PRIMARY, + READ_COMMANDS, + REPLICA, + SLOT_ID, + AbstractRedisCluster, + LoadBalancer, + cleanup_kwargs, + get_node_name, +) +from redis.commands import AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -18,178 +27,25 @@ ClusterDownError, ClusterError, ConnectionError, - DataError, MasterDownError, MovedError, RedisClusterException, - RedisError, ResponseError, SlotNotCoveredError, TimeoutError, TryAgainError, ) -from redis.lock import Lock -from redis.utils import ( - dict_merge, - list_keys_to_dict, - merge_result, - safe_str, - str_if_bytes, -) +from redis.utils import dict_merge, str_if_bytes log = logging.getLogger(__name__) -def get_node_name(host, port): - return f"{host}:{port}" - - -def get_connection(redis_node, *args, **options): - return redis_node.connection or redis_node.connection_pool.get_connection( +async def get_connection(redis_node, *args, **options): + return redis_node.connection or await redis_node.connection_pool.get_connection( args[0], **options ) -def parse_scan_result(command, res, **options): - cursors = {} - ret = [] - for node_name, response in res.items(): - cursor, r = parse_scan(response, **options) - cursors[node_name] = cursor - ret += r - - return cursors, ret - - -def parse_pubsub_numsub(command, res, **options): - numsub_d = OrderedDict() - for numsub_tups in res.values(): - for channel, numsubbed in numsub_tups: - try: - numsub_d[channel] += numsubbed - except KeyError: - numsub_d[channel] = numsubbed - - ret_numsub = [(channel, numsub) for channel, numsub in numsub_d.items()] - return ret_numsub - - -def parse_cluster_slots(resp, **options): - current_host = options.get("current_host", "") - - def fix_server(*args): - return str_if_bytes(args[0]) or current_host, args[1] - - slots = {} - for slot in resp: - start, end, primary = slot[:3] - replicas = slot[3:] - slots[start, end] = { - "primary": fix_server(*primary), - "replicas": [fix_server(*replica) for replica in replicas], - } - - return slots - - -PRIMARY = "primary" -REPLICA = "replica" -SLOT_ID = "slot-id" - -REDIS_ALLOWED_KEYS = ( - "charset", - "connection_class", - "connection_pool", - "client_name", - "db", - "decode_responses", - "encoding", - "encoding_errors", - "errors", - "host", - "max_connections", - "nodes_flag", - "redis_connect_func", - "password", - "port", - "retry", - "retry_on_timeout", - "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_timeout", - "ssl", - "ssl_ca_certs", - "ssl_ca_data", - "ssl_certfile", - "ssl_cert_reqs", - "ssl_keyfile", - "ssl_password", - "unix_socket_path", - "username", -) -KWARGS_DISABLED_KEYS = ("host", "port") - -# Not complete, but covers the major ones -# https://redis.io/commands -READ_COMMANDS = frozenset( - [ - "BITCOUNT", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUS", - "GEORADIUSBYMEMBER", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "KEYS", - "LINDEX", - "LLEN", - "LRANGE", - "MGET", - "PTTL", - "RANDOMKEY", - "SCARD", - "SDIFF", - "SINTER", - "SISMEMBER", - "SMEMBERS", - "SRANDMEMBER", - "STRLEN", - "SUNION", - "TTL", - "ZCARD", - "ZCOUNT", - "ZRANGE", - "ZSCORE", - ] -) - - -def cleanup_kwargs(**kwargs): - """ - Remove unsupported or disabled keys from kwargs - """ - connection_kwargs = { - k: v - for k, v in kwargs.items() - if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS - } - - return connection_kwargs - - class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( DefaultParser.EXCEPTION_CLASSES, @@ -204,220 +60,7 @@ class ClusterParser(DefaultParser): ) -class RedisCluster(RedisClusterCommands): - RedisClusterRequestTTL = 16 - - PRIMARIES = "primaries" - REPLICAS = "replicas" - ALL_NODES = "all" - RANDOM = "random" - DEFAULT_NODE = "default-node" - - NODE_FLAGS = {PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE} - - COMMAND_FLAGS = dict_merge( - list_keys_to_dict( - [ - "ACL CAT", - "ACL DELUSER", - "ACL GENPASS", - "ACL GETUSER", - "ACL HELP", - "ACL LIST", - "ACL LOG", - "ACL LOAD", - "ACL SAVE", - "ACL SETUSER", - "ACL USERS", - "ACL WHOAMI", - "AUTH", - "CLIENT LIST", - "CLIENT SETNAME", - "CLIENT GETNAME", - "CONFIG SET", - "CONFIG REWRITE", - "CONFIG RESETSTAT", - "TIME", - "PUBSUB CHANNELS", - "PUBSUB NUMPAT", - "PUBSUB NUMSUB", - "PING", - "INFO", - "SHUTDOWN", - "KEYS", - "DBSIZE", - "BGSAVE", - "SLOWLOG GET", - "SLOWLOG LEN", - "SLOWLOG RESET", - "WAIT", - "SAVE", - "MEMORY PURGE", - "MEMORY MALLOC-STATS", - "MEMORY STATS", - "LASTSAVE", - "CLIENT TRACKINGINFO", - "CLIENT PAUSE", - "CLIENT UNPAUSE", - "CLIENT UNBLOCK", - "CLIENT ID", - "CLIENT REPLY", - "CLIENT GETREDIR", - "CLIENT INFO", - "CLIENT KILL", - "READONLY", - "READWRITE", - "CLUSTER INFO", - "CLUSTER MEET", - "CLUSTER NODES", - "CLUSTER REPLICAS", - "CLUSTER RESET", - "CLUSTER SET-CONFIG-EPOCH", - "CLUSTER SLOTS", - "CLUSTER COUNT-FAILURE-REPORTS", - "CLUSTER KEYSLOT", - "COMMAND", - "COMMAND COUNT", - "COMMAND GETKEYS", - "CONFIG GET", - "DEBUG", - "RANDOMKEY", - "READONLY", - "READWRITE", - "TIME", - "GRAPH.CONFIG", - ], - DEFAULT_NODE, - ), - list_keys_to_dict( - [ - "FLUSHALL", - "FLUSHDB", - "FUNCTION DELETE", - "FUNCTION FLUSH", - "FUNCTION LIST", - "FUNCTION LOAD", - "FUNCTION RESTORE", - "SCAN", - "SCRIPT EXISTS", - "SCRIPT FLUSH", - "SCRIPT LOAD", - ], - PRIMARIES, - ), - list_keys_to_dict(["FUNCTION DUMP"], RANDOM), - list_keys_to_dict( - [ - "CLUSTER COUNTKEYSINSLOT", - "CLUSTER DELSLOTS", - "CLUSTER DELSLOTSRANGE", - "CLUSTER GETKEYSINSLOT", - "CLUSTER SETSLOT", - ], - SLOT_ID, - ), - ) - - SEARCH_COMMANDS = ( - [ - "FT.CREATE", - "FT.SEARCH", - "FT.AGGREGATE", - "FT.EXPLAIN", - "FT.EXPLAINCLI", - "FT,PROFILE", - "FT.ALTER", - "FT.DROPINDEX", - "FT.ALIASADD", - "FT.ALIASUPDATE", - "FT.ALIASDEL", - "FT.TAGVALS", - "FT.SUGADD", - "FT.SUGGET", - "FT.SUGDEL", - "FT.SUGLEN", - "FT.SYNUPDATE", - "FT.SYNDUMP", - "FT.SPELLCHECK", - "FT.DICTADD", - "FT.DICTDEL", - "FT.DICTDUMP", - "FT.INFO", - "FT._LIST", - "FT.CONFIG", - "FT.ADD", - "FT.DEL", - "FT.DROP", - "FT.GET", - "FT.MGET", - "FT.SYNADD", - ], - ) - - CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { - "CLUSTER ADDSLOTS": bool, - "CLUSTER ADDSLOTSRANGE": bool, - "CLUSTER COUNT-FAILURE-REPORTS": int, - "CLUSTER COUNTKEYSINSLOT": int, - "CLUSTER DELSLOTS": bool, - "CLUSTER DELSLOTSRANGE": bool, - "CLUSTER FAILOVER": bool, - "CLUSTER FORGET": bool, - "CLUSTER GETKEYSINSLOT": list, - "CLUSTER KEYSLOT": int, - "CLUSTER MEET": bool, - "CLUSTER REPLICATE": bool, - "CLUSTER RESET": bool, - "CLUSTER SAVECONFIG": bool, - "CLUSTER SET-CONFIG-EPOCH": bool, - "CLUSTER SETSLOT": bool, - "CLUSTER SLOTS": parse_cluster_slots, - "ASKING": bool, - "READONLY": bool, - "READWRITE": bool, - } - - RESULT_CALLBACKS = dict_merge( - list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), - list_keys_to_dict( - ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) - ), - list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), - list_keys_to_dict( - [ - "PING", - "CONFIG SET", - "CONFIG REWRITE", - "CONFIG RESETSTAT", - "CLIENT SETNAME", - "BGSAVE", - "SLOWLOG RESET", - "SAVE", - "MEMORY PURGE", - "CLIENT PAUSE", - "CLIENT UNPAUSE", - ], - lambda command, res: all(res.values()) if isinstance(res, dict) else res, - ), - list_keys_to_dict( - ["DBSIZE", "WAIT"], - lambda command, res: sum(res.values()) if isinstance(res, dict) else res, - ), - list_keys_to_dict( - ["CLIENT UNBLOCK"], lambda command, res: 1 if sum(res.values()) > 0 else 0 - ), - list_keys_to_dict(["SCAN"], parse_scan_result), - list_keys_to_dict( - ["SCRIPT LOAD"], lambda command, res: list(res.values()).pop() - ), - list_keys_to_dict( - ["SCRIPT EXISTS"], lambda command, res: [all(k) for k in zip(*res.values())] - ), - list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), - ) - - ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) - +class RedisCluster(AbstractRedisCluster, AsyncRedisClusterCommands): @classmethod def from_url(cls, url, **kwargs): """ @@ -565,9 +208,6 @@ def __init__( # Update the connection arguments # Whenever a new connection is established, RedisCluster's on_connect # method should be run - # If the user passed on_connect function we'll save it and run it - # inside the RedisCluster.on_connect() function - self.user_on_connect_func = kwargs.pop("redis_connect_func", None) kwargs.update({"redis_connect_func": self.on_connect}) kwargs = cleanup_kwargs(**kwargs) @@ -594,34 +234,44 @@ def __init__( self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) - self.commands_parser = CommandsParser(self) - self._lock = threading.Lock() - - def __enter__(self): + self.commands_parser = CommandsParser() + self._initialize = True + self._lock = asyncio.Lock() + + async def initialize(self, force=False): + if self._initialize or force: + self._initialize = False + await self.nodes_manager.initialize() + await self.commands_parser.initialize(self) return self - def __exit__(self, exc_type, exc_value, traceback): - self.close() + async def __aenter__(self): + return await self.initialize() - def __del__(self): - self.close() + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() - def disconnect_connection_pools(self): - for node in self.get_nodes(): - if node.redis_connection: - try: - node.redis_connection.connection_pool.disconnect() - except OSError: - # Client was already disconnected. do nothing - pass + def __await__(self): + return self.initialize().__await__() - def on_connect(self, connection): + def __del__(self): + try: + loop = asyncio.get_event_loop() + coro = self.close() + if loop.is_running(): + loop.create_task(coro) + else: + loop.run_until_complete(coro) + except Exception: + pass + + async def on_connect(self, connection): """ Initialize the connection, authenticate and select a database and send READONLY if it is set during object initialization. """ connection.set_parser(ClusterParser) - connection.on_connect() + await connection.on_connect() if self.read_from_replicas: # Sending READONLY command to server to configure connection as @@ -629,18 +279,15 @@ def on_connect(self, connection): # to a failover, we should establish a READONLY connection # regardless of the server type. If this is a primary connection, # READONLY would not affect executing write commands. - connection.send_command("READONLY") - if str_if_bytes(connection.read_response()) != "OK": + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response()) != "OK": raise ConnectionError("READONLY command failed") - if self.user_on_connect_func is not None: - self.user_on_connect_func(connection) - - def get_redis_connection(self, node): + async def get_redis_connection(self, node): if not node.redis_connection: - with self._lock: + async with self._lock: if not node.redis_connection: - self.nodes_manager.create_redis_connections([node]) + await self.nodes_manager.create_redis_connections([node]) return node.redis_connection def get_node(self, host=None, port=None, node_name=None): @@ -700,127 +347,11 @@ def set_default_node(self, node): log.info(f"Changed the default cluster node to {node}") return True - def monitor(self, target_node=None): - """ - Returns a Monitor object for the specified target node. - The default cluster node will be selected if no target node was - specified. - Monitor is useful for handling the MONITOR command to the redis server. - next_command() method returns one command from monitor - listen() method yields commands from monitor. - """ - if target_node is None: - target_node = self.get_default_node() - if target_node.redis_connection is None: - raise RedisClusterException( - f"Cluster Node {target_node.name} has no redis_connection" - ) - return target_node.redis_connection.monitor() - - def pubsub(self, node=None, host=None, port=None, **kwargs): - """ - Allows passing a ClusterNode, or host&port, to get a pubsub instance - connected to the specified node - """ - return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) - - def pipeline(self, transaction=None, shard_hint=None): - """ - Cluster impl: - Pipelines do not work in cluster mode the same way they - do in normal mode. Create a clone of this object so - that simulating pipelines will work correctly. Each - command will be called directly when used and - when calling execute() will only return the result stack. - """ - if shard_hint: - raise RedisClusterException("shard_hint is deprecated in cluster mode") - - if transaction: - raise RedisClusterException("transaction is deprecated in cluster mode") - - return ClusterPipeline( - nodes_manager=self.nodes_manager, - commands_parser=self.commands_parser, - startup_nodes=self.nodes_manager.startup_nodes, - result_callbacks=self.result_callbacks, - cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.cluster_error_retry_attempts, - read_from_replicas=self.read_from_replicas, - reinitialize_steps=self.reinitialize_steps, - ) - - def lock( - self, - name, - timeout=None, - sleep=0.1, - blocking_timeout=None, - lock_class=None, - thread_local=True, - ): - """ - Return a new Lock object using key ``name`` that mimics - the behavior of threading.Lock. - - If specified, ``timeout`` indicates a maximum life for the lock. - By default, it will remain locked until release() is called. - - ``sleep`` indicates the amount of time to sleep per loop iteration - when the lock is in blocking mode and another client is currently - holding the lock. - - ``blocking_timeout`` indicates the maximum amount of time in seconds to - spend trying to acquire the lock. A value of ``None`` indicates - continue trying forever. ``blocking_timeout`` can be specified as a - float or integer, both representing the number of seconds to wait. - - ``lock_class`` forces the specified lock implementation. Note that as - of redis-py 3.0, the only lock class we implement is ``Lock`` (which is - a Lua-based lock). So, it's unlikely you'll need this parameter, unless - you have created your own custom lock class. - - ``thread_local`` indicates whether the lock token is placed in - thread-local storage. By default, the token is placed in thread local - storage so that a thread only sees its token, not a token set by - another thread. Consider the following timeline: - - time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. - thread-1 sets the token to "abc" - time: 1, thread-2 blocks trying to acquire `my-lock` using the - Lock instance. - time: 5, thread-1 has not yet completed. redis expires the lock - key. - time: 5, thread-2 acquired `my-lock` now that it's available. - thread-2 sets the token to "xyz" - time: 6, thread-1 finishes its work and calls release(). if the - token is *not* stored in thread local storage, then - thread-1 would see the token value as "xyz" and would be - able to successfully release the thread-2's lock. - - In some use cases it's necessary to disable thread local storage. For - example, if you have code where one thread acquires a lock and passes - that lock instance to a worker thread to release later. If thread - local storage isn't disabled in this case, the worker thread won't see - the token set by the thread that acquired the lock. Our assumption - is that these cases aren't common and as such default to using - thread local storage.""" - if lock_class is None: - lock_class = Lock - return lock_class( - self, - name, - timeout=timeout, - sleep=sleep, - blocking_timeout=blocking_timeout, - thread_local=thread_local, - ) - def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs): + async def _determine_nodes(self, *args, **kwargs): command = args[0] nodes_flag = kwargs.pop("nodes_flag", None) if nodes_flag is not None: @@ -850,8 +381,8 @@ def _determine_nodes(self, *args, **kwargs): return [self.nodes_manager.default_node] else: # get the node that holds the key's slot - slot = self.determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( + slot = await self.determine_slot(*args) + node = await self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) log.debug(f"Target for {args}: slot {slot}") @@ -875,7 +406,7 @@ def keyslot(self, key): k = self.encoder.encode(key) return key_slot(k) - def _get_command_keys(self, *args): + async def _get_command_keys(self, *args): """ Get the keys in the command. If the command has no keys in in, None is returned. @@ -888,9 +419,9 @@ def _get_command_keys(self, *args): So, don't use this function with EVAL or EVALSHA. """ redis_conn = self.get_default_node().redis_connection - return self.commands_parser.get_keys(redis_conn, *args) + return await self.commands_parser.get_keys(redis_conn, *args) - def determine_slot(self, *args): + async def determine_slot(self, *args): """ Figure out what slot to use based on args. @@ -921,7 +452,7 @@ def determine_slot(self, *args): return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) keys = eval_keys else: - keys = self._get_command_keys(*args) + keys = await self._get_command_keys(*args) if keys is None or len(keys) == 0: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot @@ -947,9 +478,6 @@ def determine_slot(self, *args): return slots.pop() - def reinitialize_caches(self): - self.nodes_manager.initialize() - def get_encoder(self): """ Get the connections' encoder @@ -985,7 +513,7 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def execute_command(self, *args, **kwargs): + async def execute_command(self, *args, **kwargs): """ Wrapper for ERRORS_ALLOW_RETRY error handling. @@ -1001,6 +529,7 @@ def execute_command(self, *args, **kwargs): list dict """ + await self.initialize() target_nodes_specified = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) @@ -1024,7 +553,7 @@ def execute_command(self, *args, **kwargs): res = {} if not target_nodes_specified: # Determine the nodes to execute the command on - target_nodes = self._determine_nodes( + target_nodes = await self._determine_nodes( *args, **kwargs, nodes_flag=passed_targets ) if not target_nodes: @@ -1032,11 +561,11 @@ def execute_command(self, *args, **kwargs): f"No targets were found to execute {args} command on" ) for node in target_nodes: - res[node.name] = self._execute_command(node, *args, **kwargs) + res[node.name] = await self._execute_command(node, *args, **kwargs) # Return the processed result return self._process_result(args[0], res, **kwargs) except BaseException as e: - if type(e) in RedisCluster.ERRORS_ALLOW_RETRY: + if type(e) in AbstractRedisCluster.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e @@ -1048,7 +577,7 @@ def execute_command(self, *args, **kwargs): # to caller of this method raise exception - def _execute_command(self, target_node, *args, **kwargs): + async def _execute_command(self, target_node, *args, **kwargs): """ Send a command to a node in the cluster """ @@ -1069,8 +598,8 @@ def _execute_command(self, target_node, *args, **kwargs): elif moved: # MOVED occurred and the slots cache was updated, # refresh the target node - slot = self.determine_slot(*args) - target_node = self.nodes_manager.get_node_from_slot( + slot = await self.determine_slot(*args) + target_node = await self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) moved = False @@ -1079,15 +608,17 @@ def _execute_command(self, target_node, *args, **kwargs): f"Executing command {command} on target node: " f"{target_node.server_type} {target_node.name}" ) - redis_node = self.get_redis_connection(target_node) - connection = get_connection(redis_node, *args, **kwargs) + redis_node = await self.get_redis_connection(target_node) + connection = await get_connection(redis_node, *args, **kwargs) if asking: - connection.send_command("ASKING") - redis_node.parse_response(connection, "ASKING", **kwargs) + await connection.send_command("ASKING") + await redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) + await connection.send_command(*args) + response = await redis_node.parse_response( + connection, command, **kwargs + ) if command in self.cluster_response_callbacks: response = self.cluster_response_callbacks[command]( response, **kwargs @@ -1103,7 +634,7 @@ def _execute_command(self, target_node, *args, **kwargs): # connection from the pool before timing out, so check that # this is an actual connection before attempting to disconnect. if connection is not None: - connection.disconnect() + await connection.disconnect() connection_error_retry_counter += 1 # Give the node 0.25 seconds to get back up and retry again @@ -1111,11 +642,11 @@ def _execute_command(self, target_node, *args, **kwargs): # to reinitialize the cluster and see if the nodes # configuration has changed or not if connection_error_retry_counter < 5: - time.sleep(0.25) + await asyncio.sleep(0.25) else: # Hard force of reinitialize of the node/slots setup # and try again with the new setup - self.nodes_manager.initialize() + await self.initialize(force=True) raise except MovedError as e: # First, we will try to patch the slots/nodes cache with the @@ -1129,7 +660,7 @@ def _execute_command(self, target_node, *args, **kwargs): log.exception("MovedError") self.reinitialize_counter += 1 if self._should_reinitialized(): - self.nodes_manager.initialize() + await self.initialize(force=True) # Reset the counter self.reinitialize_counter = 0 else: @@ -1139,7 +670,7 @@ def _execute_command(self, target_node, *args, **kwargs): log.exception("TryAgainError") if ttl < self.RedisClusterRequestTTL / 2: - time.sleep(0.05) + await asyncio.sleep(0.05) except AskError as e: log.exception("AskError") @@ -1150,8 +681,8 @@ def _execute_command(self, target_node, *args, **kwargs): # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command - time.sleep(0.25) - self.nodes_manager.initialize() + await asyncio.sleep(0.25) + await self.initialize(force=True) raise e except ResponseError as e: message = e.__str__() @@ -1160,19 +691,19 @@ def _execute_command(self, target_node, *args, **kwargs): except BaseException as e: log.exception("BaseException") if connection: - connection.disconnect() + await connection.disconnect() raise e finally: if connection is not None: - redis_node.connection_pool.release(connection) + await redis_node.connection_pool.release(connection) raise ClusterError("TTL exhausted.") - def close(self): + async def close(self): try: - with self._lock: + async with self._lock: if self.nodes_manager: - self.nodes_manager.close() + await self.nodes_manager.close() except AttributeError: # RedisCluster's __init__ can fail before nodes_manager is set pass @@ -1197,16 +728,6 @@ def _process_result(self, command, res, **kwargs): else: return res - def load_external_module(self, funcname, func): - """ - This function can be used to add externally defined redis modules, - and their namespaces to the redis client. - - ``funcname`` - A string containing the name of the function to create - ``func`` - The function, being added to this class. - """ - setattr(self, funcname, func) - class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1232,27 +753,16 @@ def __eq__(self, obj): return isinstance(obj, ClusterNode) and obj.name == self.name def __del__(self): - if self.redis_connection is not None: - self.redis_connection.close() - - -class LoadBalancer: - """ - Round-Robin Load Balancing - """ - - def __init__(self, start_index=0): - self.primary_to_idx = {} - self.start_index = start_index - - def get_server_index(self, primary, list_size): - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index - - def reset(self): - self.primary_to_idx.clear() + try: + if self.redis_connection is not None: + loop = asyncio.get_event_loop() + coro = self.redis_connection.close(True) + if loop.is_running(): + loop.create_task(coro) + else: + loop.run_until_complete(coro) + except Exception: + pass class NodesManager: @@ -1275,9 +785,8 @@ def __init__( self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() if lock is None: - lock = threading.Lock() + lock = asyncio.Lock() self._lock = lock - self.initialize() def get_node(self, host=None, port=None, node_name=None): """ @@ -1342,12 +851,14 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + async def get_node_from_slot( + self, slot, read_from_replicas=False, server_type=None + ): """ Gets a node that servers this hash slot """ if self._moved_exception: - with self._lock: + async with self._lock: if self._moved_exception: self._update_moved_slots() @@ -1404,17 +915,17 @@ def check_slots_coverage(self, slots_cache): return False return True - def create_redis_connections(self, nodes): + async def create_redis_connections(self, nodes): """ This function will create a redis connection to all nodes in :nodes: """ for node in nodes: if node.redis_connection is None: - node.redis_connection = self.create_redis_node( + node.redis_connection = await self.create_redis_node( host=node.host, port=node.port, **self.connection_kwargs ) - def create_redis_node(self, host, port, **kwargs): + async def create_redis_node(self, host, port, **kwargs): if self.from_url: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) @@ -1422,9 +933,10 @@ def create_redis_node(self, host, port, **kwargs): r = Redis(connection_pool=ConnectionPool(**kwargs)) else: r = Redis(host=host, port=port, **kwargs) + await r.initialize() return r - def initialize(self): + async def initialize(self): """ Initializes the nodes cache, slots cache and redis connections. :startup_nodes: @@ -1445,18 +957,27 @@ def initialize(self): else: # Create a new Redis connection and let Redis decode the # responses so we won't need to handle that + # TODO: redis_connect_func shouldn't need to be removed & readded + redis_connect_func = kwargs.pop("redis_connect_func") copy_kwargs = copy.deepcopy(kwargs) - copy_kwargs.update({"decode_responses": True, "encoding": "utf-8"}) - r = self.create_redis_node( + kwargs.setdefault("redis_connect_func", redis_connect_func) + copy_kwargs.update( + { + "decode_responses": True, + "encoding": "utf-8", + "redis_connect_func": redis_connect_func, + } + ) + r = await self.create_redis_node( startup_node.host, startup_node.port, **copy_kwargs ) self.startup_nodes[startup_node.name].redis_connection = r # Make sure cluster mode is enabled on this node - if bool(r.info().get("cluster_enabled")) is False: + if bool((await r.info()).get("cluster_enabled")) is False: raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + cluster_slots = str_if_bytes(await r.execute_command("CLUSTER SLOTS")) startup_nodes_reachable = True except (ConnectionError, TimeoutError) as e: msg = e.__str__ @@ -1561,7 +1082,7 @@ def initialize(self): ) # Create Redis connections to all nodes - self.create_redis_connections(list(tmp_nodes_cache.values())) + await self.create_redis_connections(list(tmp_nodes_cache.values())) # Check if the slots are not fully covered if not fully_covered and self._require_full_coverage: @@ -1583,11 +1104,11 @@ def initialize(self): # If initialize was called after a MovedError, clear it self._moved_exception = None - def close(self): + async def close(self): self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: - node.redis_connection.close() + await node.redis_connection.close(True) def reset(self): try: @@ -1595,657 +1116,3 @@ def reset(self): except TypeError: # The read_load_balancer is None, do nothing pass - - -class ClusterPubSub(PubSub): - """ - Wrapper for PubSub class. - - IMPORTANT: before using ClusterPubSub, read about the known limitations - with pubsub in Cluster mode and learn how to workaround them: - https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html - """ - - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): - """ - When a pubsub instance is created without specifying a node, a single - node will be transparently chosen for the pubsub connection on the - first command execution. The node will be determined by: - 1. Hashing the channel name in the request to find its keyslot - 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true, a replica can be selected. - - :type redis_cluster: RedisCluster - :type node: ClusterNode - :type host: str - :type port: int - """ - self.node = None - self.set_pubsub_node(redis_cluster, node, host, port) - connection_pool = ( - None - if self.node is None - else redis_cluster.get_redis_connection(self.node).connection_pool - ) - self.cluster = redis_cluster - super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder - ) - - def set_pubsub_node(self, cluster, node=None, host=None, port=None): - """ - The pubsub node will be set according to the passed node, host and port - When none of the node, host, or port are specified - the node is set - to None and will be determined by the keyslot of the channel in the - first command to be executed. - RedisClusterException will be thrown if the passed node does not exist - in the cluster. - If host is passed without port, or vice versa, a DataError will be - thrown. - :type cluster: RedisCluster - :type node: ClusterNode - :type host: str - :type port: int - """ - if node is not None: - # node is passed by the user - self._raise_on_invalid_node(cluster, node, node.host, node.port) - pubsub_node = node - elif host is not None and port is not None: - # host and port passed by the user - node = cluster.get_node(host=host, port=port) - self._raise_on_invalid_node(cluster, node, host, port) - pubsub_node = node - elif any([host, port]) is True: - # only 'host' or 'port' passed - raise DataError("Passing a host requires passing a port, " "and vice versa") - else: - # nothing passed by the user. set node to None - pubsub_node = None - - self.node = pubsub_node - - def get_pubsub_node(self): - """ - Get the node that is being used as the pubsub connection - """ - return self.node - - def _raise_on_invalid_node(self, redis_cluster, node, host, port): - """ - Raise a RedisClusterException if the node is None or doesn't exist in - the cluster. - """ - if node is None or redis_cluster.get_node(node_name=node.name) is None: - raise RedisClusterException( - f"Node {host}:{port} doesn't exist in the cluster" - ) - - def execute_command(self, *args, **kwargs): - """ - Execute a publish/subscribe command. - - Taken code from redis-py and tweak to make it work within a cluster. - """ - # NOTE: don't parse the response in this function -- it could pull a - # legitimate message off the stack if the connection is already - # subscribed to one or more channels - - if self.connection is None: - if self.connection_pool is None: - if len(args) > 1: - # Hash the first channel and get one of the nodes holding - # this slot - channel = args[1] - slot = self.cluster.keyslot(channel) - node = self.cluster.nodes_manager.get_node_from_slot( - slot, self.cluster.read_from_replicas - ) - else: - # Get a random node - node = self.cluster.get_random_node() - self.node = node - redis_connection = self.cluster.get_redis_connection(node) - self.connection_pool = redis_connection.connection_pool - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - connection = self.connection - self._execute(connection, connection.send_command, *args) - - def get_redis_connection(self): - """ - Get the Redis connection of the pubsub connected node. - """ - if self.node is not None: - return self.node.redis_connection - - -class ClusterPipeline(RedisCluster): - """ - Support for Redis pipeline - in cluster mode - """ - - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - MovedError, - AskError, - TryAgainError, - ) - - def __init__( - self, - nodes_manager, - commands_parser, - result_callbacks=None, - cluster_response_callbacks=None, - startup_nodes=None, - read_from_replicas=False, - cluster_error_retry_attempts=5, - reinitialize_steps=10, - **kwargs, - ): - """ """ - self.command_stack = [] - self.nodes_manager = nodes_manager - self.commands_parser = commands_parser - self.refresh_table_asap = False - self.result_callbacks = ( - result_callbacks or self.__class__.RESULT_CALLBACKS.copy() - ) - self.startup_nodes = startup_nodes if startup_nodes else [] - self.read_from_replicas = read_from_replicas - self.command_flags = self.__class__.COMMAND_FLAGS.copy() - self.cluster_response_callbacks = cluster_response_callbacks - self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.reinitialize_counter = 0 - self.reinitialize_steps = reinitialize_steps - self.encoder = Encoder( - kwargs.get("encoding", "utf-8"), - kwargs.get("encoding_errors", "strict"), - kwargs.get("decode_responses", False), - ) - - def __repr__(self): - """ """ - return f"{type(self).__name__}" - - def __enter__(self): - """ """ - return self - - def __exit__(self, exc_type, exc_value, traceback): - """ """ - self.reset() - - def __del__(self): - try: - self.reset() - except Exception: - pass - - def __len__(self): - """ """ - return len(self.command_stack) - - def __nonzero__(self): - "Pipeline instances should always evaluate to True on Python 2.7" - return True - - def __bool__(self): - "Pipeline instances should always evaluate to True on Python 3+" - return True - - def execute_command(self, *args, **kwargs): - """ - Wrapper function for pipeline_execute_command - """ - return self.pipeline_execute_command(*args, **kwargs) - - def pipeline_execute_command(self, *args, **options): - """ - Appends the executed command to the pipeline's command stack - """ - self.command_stack.append( - PipelineCommand(args, options, len(self.command_stack)) - ) - return self - - def raise_first_error(self, stack): - """ - Raise the first exception on the stack - """ - for c in stack: - r = c.result - if isinstance(r, Exception): - self.annotate_exception(r, c.position + 1, c.args) - raise r - - def annotate_exception(self, exception, number, command): - """ - Provides extra context to the exception prior to it being handled - """ - cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" - ) - exception.args = (msg,) + exception.args[1:] - - def execute(self, raise_on_error=True): - """ - Execute all the commands in the current pipeline - """ - stack = self.command_stack - try: - return self.send_cluster_commands(stack, raise_on_error) - finally: - self.reset() - - def reset(self): - """ - Reset back to empty pipeline. - """ - self.command_stack = [] - - self.scripts = set() - - # TODO: Implement - # make sure to reset the connection state in the event that we were - # watching something - # if self.watching and self.connection: - # try: - # # call this manually since our unwatch or - # # immediate_execute_command methods can call reset() - # self.connection.send_command('UNWATCH') - # self.connection.read_response() - # except ConnectionError: - # # disconnect will also remove any previous WATCHes - # self.connection.disconnect() - - # clean up the other instance attributes - self.watching = False - self.explicit_transaction = False - - # TODO: Implement - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - # if self.connection: - # self.connection_pool.release(self.connection) - # self.connection = None - - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Wrapper for CLUSTERDOWN error handling. - - If the cluster reports it is down it is assumed that: - - connection_pool was disconnected - - connection_pool was reseted - - refereh_table_asap set to True - - It will try the number of times specified by - the config option "self.cluster_error_retry_attempts" - which defaults to 3 unless manually configured. - - If it reaches the number of times, the command will - raises ClusterDownException. - """ - if not stack: - return [] - - for _ in range(0, self.cluster_error_retry_attempts): - try: - return self._send_cluster_commands( - stack, - raise_on_error=raise_on_error, - allow_redirections=allow_redirections, - ) - except ClusterDownError: - # Try again with the new cluster setup. All other errors - # should be raised. - pass - - # If it fails the configured number of times then raise - # exception back to caller of this method - raise ClusterDownError("CLUSTERDOWN error. Unable to rebuild the cluster") - - def _send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Send a bunch of cluster commands to the redis cluster. - - `allow_redirections` If the pipeline should follow - `ASK` & `MOVED` responses automatically. If set - to false it will raise RedisClusterException. - """ - # the first time sending the commands we send all of - # the commands that were queued up. - # if we have to run through it again, we only retry - # the commands that failed. - attempt = sorted(stack, key=lambda x: x.position) - - # build a list of node objects based on node names we need to - nodes = {} - - # as we move through each command that still needs to be processed, - # we figure out the slot number that command maps to, then from - # the slot determine the node. - for c in attempt: - # refer to our internal node -> slot table that - # tells us where a given - # command should route to. - node = self._determine_nodes(*c.args) - - # now that we know the name of the node - # ( it's just a string in the form of host:port ) - # we can build a list of commands for each node. - node_name = node[0].name - if node_name not in nodes: - redis_node = self.get_redis_connection(node[0]) - connection = get_connection(redis_node, c.args) - nodes[node_name] = NodeCommands( - redis_node.parse_response, redis_node.connection_pool, connection - ) - - nodes[node_name].append(c) - - # send the commands in sequence. - # we write to all the open sockets for each node first, - # before reading anything - # this allows us to flush all the requests out across the - # network essentially in parallel - # so that we can read them all in parallel as they come back. - # we dont' multiplex on the sockets as they come available, - # but that shouldn't make too much difference. - node_commands = nodes.values() - for n in node_commands: - n.write() - - for n in node_commands: - n.read() - - # release all of the redis connections we allocated earlier - # back into the connection pool. - # we used to do this step as part of a try/finally block, - # but it is really dangerous to - # release connections back into the pool if for some - # reason the socket has data still left in it - # from a previous operation. The write and - # read operations already have try/catch around them for - # all known types of errors including connection - # and socket level errors. - # So if we hit an exception, something really bad - # happened and putting any oF - # these connections back into the pool is a very bad idea. - # the socket might have unread buffer still sitting in it, - # and then the next time we read from it we pass the - # buffered result back from a previous command and - # every single request after to that connection will always get - # a mismatched result. - for n in nodes.values(): - n.connection_pool.release(n.connection) - - # if the response isn't an exception it is a - # valid response from the node - # we're all done with that command, YAY! - # if we have more commands to attempt, we've run into problems. - # collect all the commands we are allowed to retry. - # (MOVED, ASK, or connection errors or timeout errors) - attempt = sorted( - ( - c - for c in attempt - if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) - ), - key=lambda x: x.position, - ) - if attempt and allow_redirections: - # RETRY MAGIC HAPPENS HERE! - # send these remaing comamnds one at a time using `execute_command` - # in the main client. This keeps our retry logic - # in one place mostly, - # and allows us to be more confident in correctness of behavior. - # at this point any speed gains from pipelining have been lost - # anyway, so we might as well make the best - # attempt to get the correct behavior. - # - # The client command will handle retries for each - # individual command sequentially as we pass each - # one into `execute_command`. Any exceptions - # that bubble out should only appear once all - # retries have been exhausted. - # - # If a lot of commands have failed, we'll be setting the - # flag to rebuild the slots table from scratch. - # So MOVED errors should correct themselves fairly quickly. - log.exception( - f"An exception occurred during pipeline execution. " - f"args: {attempt[-1].args}, " - f"error: {type(attempt[-1].result).__name__} " - f"{str(attempt[-1].result)}" - ) - self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() - for c in attempt: - try: - # send each command individually like we - # do in the main client. - c.result = super().execute_command(*c.args, **c.options) - except RedisError as e: - c.result = e - - # turn the response back into a simple flat array that corresponds - # to the sequence of commands issued in the stack in pipeline.execute() - response = [] - for c in sorted(stack, key=lambda x: x.position): - if c.args[0] in self.cluster_response_callbacks: - c.result = self.cluster_response_callbacks[c.args[0]]( - c.result, **c.options - ) - response.append(c.result) - - if raise_on_error: - self.raise_first_error(stack) - - return response - - def _fail_on_redirect(self, allow_redirections): - """ """ - if not allow_redirections: - raise RedisClusterException( - "ASK & MOVED redirection not allowed in this pipeline" - ) - - def exists(self, *keys): - return self.execute_command("EXISTS", *keys) - - def eval(self): - """ """ - raise RedisClusterException("method eval() is not implemented") - - def multi(self): - """ """ - raise RedisClusterException("method multi() is not implemented") - - def immediate_execute_command(self, *args, **options): - """ """ - raise RedisClusterException( - "method immediate_execute_command() is not implemented" - ) - - def _execute_transaction(self, *args, **kwargs): - """ """ - raise RedisClusterException("method _execute_transaction() is not implemented") - - def load_scripts(self): - """ """ - raise RedisClusterException("method load_scripts() is not implemented") - - def watch(self, *names): - """ """ - raise RedisClusterException("method watch() is not implemented") - - def unwatch(self): - """ """ - raise RedisClusterException("method unwatch() is not implemented") - - def script_load_for_pipeline(self, *args, **kwargs): - """ """ - raise RedisClusterException( - "method script_load_for_pipeline() is not implemented" - ) - - def delete(self, *names): - """ - "Delete a key specified by ``names``" - """ - if len(names) != 1: - raise RedisClusterException( - "deleting multiple keys is not " "implemented in pipeline command" - ) - - return self.execute_command("DEL", names[0]) - - -def block_pipeline_command(func): - """ - Prints error because some pipelined commands should - be blocked when running in cluster-mode - """ - - def inner(*args, **kwargs): - raise RedisClusterException( - f"ERROR: Calling pipelined function {func.__name__} is blocked " - f"when running redis in cluster mode..." - ) - - return inner - - -# Blocked pipeline commands -ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) -ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) -ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname) -ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) -ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname) -ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) -ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) -ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) -ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb) -ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys) -ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget) -ClusterPipeline.move = block_pipeline_command(RedisCluster.move) -ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset) -ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx) -ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge) -ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount) -ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping) -ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish) -ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey) -ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename) -ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) -ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) -ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) -ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) -ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) -ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) -ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) -ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) -ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) -ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) -ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) -ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite) -ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly) - - -class PipelineCommand: - """ """ - - def __init__(self, args, options=None, position=None): - self.args = args - if options is None: - options = {} - self.options = options - self.position = position - self.result = None - self.node = None - self.asking = False - - -class NodeCommands: - """ """ - - def __init__(self, parse_response, connection_pool, connection): - """ """ - self.parse_response = parse_response - self.connection_pool = connection_pool - self.connection = connection - self.commands = [] - - def append(self, c): - """ """ - self.commands.append(c) - - def write(self): - """ - Code borrowed from Redis so it can be fixed - """ - connection = self.connection - commands = self.commands - - # We are going to clobber the commands with the write, so go ahead - # and ensure that nothing is sitting there from a previous run. - for c in commands: - c.result = None - - # build up all commands into a single request to increase network perf - # send all the commands and catch connection and timeout errors. - try: - connection.send_packed_command( - connection.pack_commands([c.args for c in commands]) - ) - except (ConnectionError, TimeoutError) as e: - for c in commands: - c.result = e - - def read(self): - """ """ - connection = self.connection - for c in self.commands: - - # if there is a result on this command, - # it means we ran into an exception - # like a connection error. Trying to parse - # a response on a connection that - # is no longer open will result in a - # connection error raised by redis-py. - # but redis-py doesn't check in parse_response - # that the sock object is - # still set and if you try to - # read from a closed connection, it will - # result in an AttributeError because - # it will do a readline() call on None. - # This can have all kinds of nasty side-effects. - # Treating this case as a connection error - # is fine because it will dump - # the connection object back into the - # pool and on the next write, it will - # explicitly open the connection and all will be well. - if c.result is None: - try: - c.result = self.parse_response(connection, c.args[0], **c.options) - except (ConnectionError, TimeoutError) as e: - for c in self.commands: - c.result = e - return - except RedisError: - c.result = sys.exc_info()[1] diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 89292ab2d3..0c82e4f2c6 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,5 +1,4 @@ from redis.exceptions import RedisError, ResponseError -from redis.utils import str_if_bytes class CommandsParser: @@ -11,13 +10,11 @@ class CommandsParser: 'COMMAND GETKEYS'. """ - def __init__(self, redis_connection): - self.initialized = False + def __init__(self): self.commands = {} - self.initialize(redis_connection) - def initialize(self, r): - commands = r.execute_command("COMMAND") + async def initialize(self, r): + commands = await r.execute_command("COMMAND") uppercase_commands = [] for cmd in commands: if any(x.isupper() for x in cmd): @@ -29,7 +26,7 @@ def initialize(self, r): # As soon as this PR is merged into Redis, we should reimplement # our logic to use COMMAND INFO changes to determine the key positions # https://github.com/redis/redis/pull/8324 - def get_keys(self, redis_conn, *args): + async def get_keys(self, redis_conn, *args): """ Get the keys from the passed command. @@ -56,7 +53,7 @@ def get_keys(self, redis_conn, *args): else: # We'll try to reinitialize the commands cache, if the engine # version has changed, the commands may not be current - self.initialize(redis_conn) + await self.initialize(redis_conn) if cmd_name not in self.commands: raise RedisError( f"{cmd_name.upper()} command doesn't exist in Redis commands" @@ -64,9 +61,7 @@ def get_keys(self, redis_conn, *args): command = self.commands.get(cmd_name) if "movablekeys" in command["flags"]: - keys = self._get_moveable_keys(redis_conn, *args) - elif "pubsub" in command["flags"]: - keys = self._get_pubsub_keys(*args) + keys = await self._get_moveable_keys(redis_conn, *args) else: if ( command["step_count"] == 0 @@ -85,7 +80,7 @@ def get_keys(self, redis_conn, *args): return keys - def _get_moveable_keys(self, redis_conn, *args): + async def _get_moveable_keys(self, redis_conn, *args): """ NOTE: Due to a bug in redis<7.0, this function does not work properly for EVAL or EVALSHA when the `numkeys` arg is 0. @@ -101,7 +96,7 @@ def _get_moveable_keys(self, redis_conn, *args): pieces = pieces + cmd_name.split() pieces = pieces + list(args[1:]) try: - keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) + keys = await redis_conn.execute_command("COMMAND GETKEYS", *pieces) except ResponseError as e: message = e.__str__() if ( @@ -112,32 +107,3 @@ def _get_moveable_keys(self, redis_conn, *args): else: raise e return keys - - def _get_pubsub_keys(self, *args): - """ - Get the keys from pubsub command. - Although PubSub commands have predetermined key locations, they are not - supported in the 'COMMAND's output, so the key positions are hardcoded - in this method - """ - if len(args) < 2: - # The command has no keys in it - return None - args = [str_if_bytes(arg) for arg in args] - command = args[0].upper() - keys = None - if command == "PUBSUB": - # the second argument is a part of the command name, e.g. - # ['PUBSUB', 'NUMSUB', 'foo']. - pubsub_type = args[1].upper() - if pubsub_type in ["CHANNELS", "NUMSUB"]: - keys = args[2:] - elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: - # format example: - # SUBSCRIBE channel [channel ...] - keys = list(args[1:]) - elif command == "PUBLISH": - # format example: - # PUBLISH channel message - keys = [args[1]] - return keys diff --git a/redis/cluster.py b/redis/cluster.py index 92469e4f37..9ede6b6eb9 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -204,7 +204,7 @@ class ClusterParser(DefaultParser): ) -class RedisCluster(RedisClusterCommands): +class AbstractRedisCluster: RedisClusterRequestTTL = 16 PRIMARIES = "primaries" @@ -418,6 +418,8 @@ class RedisCluster(RedisClusterCommands): ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + +class RedisCluster(AbstractRedisCluster, RedisClusterCommands): @classmethod def from_url(cls, url, **kwargs): """ @@ -947,9 +949,6 @@ def determine_slot(self, *args): return slots.pop() - def reinitialize_caches(self): - self.nodes_manager.initialize() - def get_encoder(self): """ Get the connections' encoder @@ -1036,7 +1035,7 @@ def execute_command(self, *args, **kwargs): # Return the processed result return self._process_result(args[0], res, **kwargs) except BaseException as e: - if type(e) in RedisCluster.ERRORS_ALLOW_RETRY: + if type(e) in AbstractRedisCluster.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index b9dd0b7210..43af1d1ebf 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,4 +1,4 @@ -from .cluster import RedisClusterCommands +from .cluster import AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser @@ -6,6 +6,7 @@ from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ + "AsyncRedisClusterCommands", "RedisClusterCommands", "CommandsParser", "AsyncCoreCommands", diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index ddeafc43e4..5d9a8b3819 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -1,3 +1,4 @@ +import asyncio from typing import Iterator, Union from redis.crc import key_slot @@ -6,6 +7,11 @@ from .core import ( ACLCommands, + AsyncACLCommands, + AsyncDataAccessCommands, + AsyncFunctionCommands, + AsyncManagementCommands, + AsyncScriptCommands, DataAccessCommands, FunctionCommands, ManagementCommands, @@ -185,7 +191,7 @@ def _partition_keys_by_slot(self, keys): return slots_to_keys - def mget_nonatomic(self, keys, *args): + async def mget_nonatomic(self, keys, *args): """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -210,18 +216,22 @@ def mget_nonatomic(self, keys, *args): # Call MGET for every slot and concatenate # the results # We must make sure that the keys are returned in order - all_results = {} - for slot_keys in slots_to_keys.values(): - slot_values = self.execute_command("MGET", *slot_keys, **options) + all_values = await asyncio.gather( + *[ + self.execute_command("MGET", *slot_keys, **options) + for slot_keys in slots_to_keys.values() + ] + ) - slot_results = dict(zip(slot_keys, slot_values)) - all_results.update(slot_results) + all_results = {} + for slot_keys, slot_values in zip(slots_to_keys.values(), all_values): + all_results.update(dict(zip(slot_keys, slot_values))) # Sort the results vals_in_order = [all_results[key] for key in keys] return vals_in_order - def mset_nonatomic(self, mapping): + async def mset_nonatomic(self, mapping): """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -244,13 +254,11 @@ def mset_nonatomic(self, mapping): # Call MSET for every slot and concatenate # the results (one result per slot) - res = [] - for pairs in slots_to_pairs.values(): - res.append(self.execute_command("MSET", *pairs)) - - return res + return await asyncio.gather( + *[self.execute_command("MSET", *pairs) for pairs in slots_to_pairs.values()] + ) - def _split_command_across_slots(self, command, *keys): + async def _split_command_across_slots(self, command, *keys): """ Runs the given command once for the keys of each slot. Returns the sum of the return values. @@ -259,11 +267,14 @@ def _split_command_across_slots(self, command, *keys): slots_to_keys = self._partition_keys_by_slot(keys) # Sum up the reply from each command - total = 0 - for slot_keys in slots_to_keys.values(): - total += self.execute_command(command, *slot_keys) - - return total + return sum( + await asyncio.gather( + *[ + self.execute_command(command, *slot_keys) + for slot_keys in slots_to_keys.values() + ] + ) + ) def exists(self, *keys): """ @@ -351,7 +362,7 @@ def swapdb(self, *args, **kwargs): raise RedisClusterException("SWAPDB is not supported in cluster" " mode") -class AsyncClusterManagementCommands(ManagementCommands): +class AsyncClusterManagementCommands(AsyncManagementCommands): """ A class for Redis Cluster management commands @@ -475,7 +486,7 @@ def scan_iter( } -class AsyncClusterDataAccessCommands(DataAccessCommands): +class AsyncClusterDataAccessCommands(AsyncDataAccessCommands): """ A class for Redis Cluster Data Access Commands @@ -530,7 +541,7 @@ def stralgo( **kwargs, ) - def scan_iter( + async def scan_iter( self, match: Union[PatternT, None] = None, count: Union[int, None] = None, @@ -538,8 +549,9 @@ def scan_iter( **kwargs, ) -> Iterator: # Do the first query with cursor=0 for all nodes - cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) - yield from data + cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) + for value in data: + yield value cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} if cursors: @@ -550,7 +562,7 @@ def scan_iter( kwargs.pop("target_nodes", None) while cursors: for name, cursor in cursors.items(): - cur, data = self.scan( + cur, data = await self.scan( cursor=cursor, match=match, count=count, @@ -558,7 +570,8 @@ def scan_iter( target_nodes=nodes[name], **kwargs, ) - yield from data + for value in data: + yield value cursors[name] = cur[name] cursors = { @@ -872,14 +885,12 @@ def readwrite(self, target_nodes=None): class AsyncRedisClusterCommands( - ClusterMultiKeyCommands, - ClusterManagementCommands, - ACLCommands, - PubSubCommands, - ClusterDataAccessCommands, - ScriptCommands, - FunctionCommands, - RedisModuleCommands, + AsyncClusterMultiKeyCommands, + AsyncClusterManagementCommands, + AsyncACLCommands, + AsyncClusterDataAccessCommands, + AsyncScriptCommands, + AsyncFunctionCommands, ): """ A class for all Redis Cluster commands @@ -959,7 +970,7 @@ def cluster_count_failure_report(self, node_id): """ return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - def cluster_delslots(self, *slots): + async def cluster_delslots(self, *slots): """ Set hash slots as unbound in the cluster. It determines by it self what node the slot is in and sends it there @@ -968,7 +979,9 @@ def cluster_delslots(self, *slots): For more information see https://redis.io/commands/cluster-delslots """ - return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + return await asyncio.gather( + *[self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + ) def cluster_delslotsrange(self, *slots): """ diff --git a/redis/commands/parser.py b/redis/commands/parser.py index 89292ab2d3..936f2ec97d 100644 --- a/redis/commands/parser.py +++ b/redis/commands/parser.py @@ -12,7 +12,6 @@ class CommandsParser: """ def __init__(self, redis_connection): - self.initialized = False self.commands = {} self.initialize(redis_connection) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 34c22c5a55..f578458f64 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -36,12 +36,18 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ - (True, PythonParser), + pytest.param( + (True, PythonParser), + marks=pytest.mark.skipif( + REDIS_INFO["cluster_enabled"], reason="cluster mode enabled" + ), + ), (False, PythonParser), pytest.param( (True, HiredisParser), marks=pytest.mark.skipif( - not HIREDIS_AVAILABLE, reason="hiredis is not installed" + not HIREDIS_AVAILABLE or REDIS_INFO["cluster_enabled"], + reason="hiredis is not installed or cluster mode enabled", ), ), pytest.param( @@ -62,29 +68,51 @@ def create_redis(request, event_loop: asyncio.BaseEventLoop): """Wrapper around redis.create_redis.""" single_connection, parser_cls = request.param - async def f(url: str = request.config.getoption("--redis-url"), **kwargs): - single = kwargs.pop("single_connection_client", False) or single_connection - parser_class = kwargs.pop("parser_class", None) or parser_cls - url_options = parse_url(url) - url_options.update(kwargs) - pool = redis.ConnectionPool(parser_class=parser_class, **url_options) - client: redis.Redis = redis.Redis(connection_pool=pool) + async def f( + url: str = request.config.getoption("--redis-url"), + cls=redis.Redis, + flushdb=True, + **kwargs, + ): + cluster_mode = REDIS_INFO["cluster_enabled"] + if not cluster_mode: + single = kwargs.pop("single_connection_client", False) or single_connection + parser_class = kwargs.pop("parser_class", None) or parser_cls + url_options = parse_url(url) + url_options.update(kwargs) + pool = redis.ConnectionPool(parser_class=parser_class, **url_options) + client = cls(connection_pool=pool) + else: + client = redis.RedisCluster.from_url(url, **kwargs) + await client.initialize() + single = False if single: client = client.client() await client.initialize() def teardown(): async def ateardown(): - if "username" in kwargs: - return - try: - await client.flushdb() - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - await client.flushdb() - await client.close() - await client.connection_pool.disconnect() + if not cluster_mode: + if "username" in kwargs: + return + if flushdb: + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + else: + if flushdb: + try: + await client.flushdb(target_nodes="primaries") + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb(target_nodes="primaries") + await client.close() if event_loop.is_running(): event_loop.create_task(ateardown()) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index f6634e8782..7da05cf6da 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,23 +1,28 @@ +import asyncio import binascii import datetime +import sys import warnings -from time import sleep -from unittest.mock import DEFAULT, Mock, call, patch import pytest -from redis import Redis -from redis.cluster import ( +from .compat import mock + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +from redis.asyncio import Connection, Redis, RedisCluster +from redis.asyncio.cluster import ( PRIMARY, REDIS_CLUSTER_HASH_SLOTS, REPLICA, ClusterNode, NodesManager, - RedisCluster, get_node_name, ) -from redis.commands import CommandsParser -from redis.connection import Connection +from redis.asyncio.parser import CommandsParser from redis.crc import key_slot from redis.exceptions import ( AskError, @@ -31,14 +36,14 @@ ResponseError, ) from redis.utils import str_if_bytes - -from .conftest import ( - _get_client, +from tests.conftest import ( skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, ) +pytestmark = pytest.mark.asyncio + default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ @@ -47,7 +52,7 @@ ] -@pytest.fixture() +@pytest_asyncio.fixture() async def slowlog(request, r): """ Set the slowlog threshold to 0, and the @@ -70,7 +75,7 @@ async def slowlog(request, r): await r.config_set("slowlog-max-len", old_max_length_value) -async def get_mocked_redis_client(func=None, *args, **kwargs): +async def get_mocked_redis_client(*args, **kwargs): """ Return a stable RedisCluster object that have deterministic nodes and slots setup to remove the problem of different IP addresses @@ -79,7 +84,7 @@ async def get_mocked_redis_client(func=None, *args, **kwargs): cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) coverage_res = kwargs.pop("coverage_result", "yes") cluster_enabled = kwargs.pop("cluster_enabled", True) - with patch.object(Redis, "execute_command") as execute_command_mock: + with mock.patch.object(Redis, "execute_command") as execute_command_mock: async def execute_command(*_args, **_kwargs): if _args[0] == "CLUSTER SLOTS": @@ -91,14 +96,12 @@ async def execute_command(*_args, **_kwargs): return {"cluster_enabled": cluster_enabled} elif len(_args) > 1 and _args[1] == "cluster-require-full-coverage": return {"cluster-require-full-coverage": coverage_res} - elif func is not None: - return func(*args, **kwargs) else: - return execute_command_mock(*_args, **_kwargs) + return await execute_command_mock(*_args, **_kwargs) execute_command_mock.side_effect = execute_command - with patch.object( + with mock.patch.object( CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: @@ -116,36 +119,23 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock - return RedisCluster(*args, **kwargs) + return await RedisCluster(*args, **kwargs) def mock_node_resp(node, response): - connection = Mock() + connection = mock.AsyncMock() connection.read_response.return_value = response node.redis_connection.connection = connection return node -def mock_node_resp_func(node, func): - connection = Mock() - connection.read_response.side_effect = func - node.redis_connection.connection = connection - return node - - def mock_all_nodes_resp(rc, response): for node in rc.get_nodes(): mock_node_resp(node, response) return rc -def find_node_ip_based_on_port(cluster_client, port): - for node in cluster_client.get_nodes(): - if node.port == port: - return node.host - - -async def moved_redirection_helper(request, failover=False): +async def moved_redirection_helper(request, create_redis, failover=False): """ Test that the client handles MOVED response after a failover. Redirection after a failover means that the redirection address is of a @@ -162,11 +152,11 @@ async def moved_redirection_helper(request, failover=False): 3. the redirected node's server type updated to 'primary' 4. the server type of the previous slot owner updated to 'replica' """ - rc = _get_client(RedisCluster, request, flushdb=False) + rc = await create_redis(cls=RedisCluster, flushdb=False) slot = 12182 redirect_node = None # Get the current primary that holds this slot - prev_primary = rc.nodes_manager.get_node_from_slot(slot) + prev_primary = await rc.nodes_manager.get_node_from_slot(slot) if failover: if len(rc.nodes_manager.slots_cache[slot]) < 2: warnings.warn("Skipping this test since it requires to have a " "replica") @@ -177,7 +167,7 @@ async def moved_redirection_helper(request, failover=False): redirect_node = rc.get_primaries()[0] r_host = redirect_node.host r_port = redirect_node.port - with patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object(Redis, "parse_response") as parse_response: def moved_redirect_effect(connection, *args, **options): def ok_response(connection, *args, **options): @@ -212,6 +202,8 @@ async def test_host_port_startup_node(self): cluster = await get_mocked_redis_client(host=default_host, port=default_port) assert cluster.get_node(host=default_host, port=default_port) is not None + await cluster.close() + async def test_startup_nodes(self): """ Test that it is possible to use startup_nodes @@ -229,6 +221,8 @@ async def test_startup_nodes(self): and cluster.get_node(host=default_host, port=port_2) is not None ) + await cluster.close() + async def test_empty_startup_nodes(self): """ Test that exception is raised when empty providing empty startup_nodes @@ -242,7 +236,7 @@ async def test_empty_startup_nodes(self): async def test_from_url(self, r): redis_url = f"redis://{default_host}:{default_port}/0" - with patch.object(RedisCluster, "from_url") as from_url: + with mock.patch.object(RedisCluster, "from_url") as from_url: async def from_url_mocked(_url, **_kwargs): return await get_mocked_redis_client(url=_url, **_kwargs) @@ -251,6 +245,8 @@ async def from_url_mocked(_url, **_kwargs): cluster = await RedisCluster.from_url(redis_url) assert cluster.get_node(host=default_host, port=default_port) is not None + await cluster.close() + async def test_execute_command_errors(self, r): """ Test that if no key is provided then exception should be raised. @@ -293,6 +289,8 @@ async def test_execute_command_node_flag_replicas(self, r): conn = primary.redis_connection.connection assert conn.read_response.called is not True + await r.close() + async def test_execute_command_node_flag_all_nodes(self, r): """ Test command execution with nodes flag ALL_NODES @@ -337,7 +335,7 @@ async def test_ask_redirection(self, r): Important thing to verify is that it tries to talk to the second node. """ redirect_node = r.get_nodes()[0] - with patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object(Redis, "parse_response") as parse_response: def ask_redirect_effect(connection, *args, **options): def ok_response(connection, *args, **options): @@ -353,29 +351,34 @@ def ok_response(connection, *args, **options): assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" - async def test_moved_redirection(self, request): + async def test_moved_redirection(self, request, create_redis): """ Test that the client handles MOVED response. """ - await moved_redirection_helper(request, failover=False) + await moved_redirection_helper(request, create_redis, failover=False) - async def test_moved_redirection_after_failover(self, request): + async def test_moved_redirection_after_failover(self, request, create_redis): """ Test that the client handles MOVED response after a failover. """ - await moved_redirection_helper(request, failover=True) + await moved_redirection_helper(request, create_redis, failover=True) - async def test_refresh_using_specific_nodes(self, request): + async def test_refresh_using_specific_nodes(self, request, create_redis): """ Test making calls on specific nodes when the cluster has failed over to another node """ node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) - with patch.object(Redis, "parse_response") as parse_response: - with patch.object(NodesManager, "initialize", autospec=True) as initialize: - with patch.multiple( - Connection, send_command=DEFAULT, connect=DEFAULT, can_read=DEFAULT + with mock.patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object( + NodesManager, "initialize", autospec=True + ) as initialize: + with mock.patch.multiple( + Connection, + send_command=mock.DEFAULT, + connect=mock.DEFAULT, + can_read=mock.DEFAULT, ) as mocks: # simulate 7006 as a failed node def parse_response_mock(connection, command_name, **options): @@ -418,7 +421,7 @@ def map_7007(self): mocks["can_read"].return_value = False mocks["send_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None - with patch.object( + with mock.patch.object( CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: @@ -436,7 +439,7 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock - rc = _get_client(RedisCluster, request, flushdb=False) + rc = await create_redis(cls=RedisCluster, flushdb=False) assert len(rc.get_nodes()) == 1 assert rc.get_node(node_name=node_7006.name) is not None @@ -451,15 +454,15 @@ def cmd_init_mock(self, r): assert parse_response.successful_calls == 1 async def test_reading_from_replicas_in_round_robin(self): - with patch.multiple( + with mock.patch.multiple( Connection, - send_command=DEFAULT, - read_response=DEFAULT, - _connect=DEFAULT, - can_read=DEFAULT, - on_connect=DEFAULT, + send_command=mock.DEFAULT, + read_response=mock.DEFAULT, + _connect=mock.DEFAULT, + can_read=mock.DEFAULT, + on_connect=mock.DEFAULT, ) as mocks: - with patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object(Redis, "parse_response") as parse_response: def parse_response_mock_first(connection, *args, **options): # Primary @@ -500,7 +503,9 @@ def parse_response_mock_third(connection, *args, **options): await read_cluster.get("foo") await read_cluster.get("foo") await read_cluster.get("foo") - mocks["send_command"].assert_has_calls([call("READONLY")]) + mocks["send_command"].assert_has_calls([mock.call("READONLY")]) + + await read_cluster.close() async def test_keyslot(self, r): """ @@ -555,7 +560,7 @@ async def test_cluster_down_overreaches_retry_attempts(self, error): the command as many times as configured in cluster_error_retry_attempts and then raise the exception """ - with patch.object(RedisCluster, "_execute_command") as execute_command: + with mock.patch.object(RedisCluster, "_execute_command") as execute_command: def raise_error(target_node, *args, **kwargs): execute_command.failed_calls += 1 @@ -569,18 +574,7 @@ def raise_error(target_node, *args, **kwargs): await rc.get("bar") assert execute_command.failed_calls == rc.cluster_error_retry_attempts - async def test_user_on_connect_function(self, request): - """ - Test support in passing on_connect function by the user - """ - - def on_connect(connection): - assert connection is not None - - mock = Mock(side_effect=on_connect) - - _get_client(RedisCluster, request, redis_connect_func=mock) - assert mock.called is True + await rc.close() async def test_set_default_node_success(self, r): """ @@ -737,7 +731,7 @@ async def test_unlink(self, r): assert await r.unlink(*d.keys()) == len(d) # Unlink is non-blocking so we sleep before # verifying the deletion - sleep(0.1) + await asyncio.sleep(0.1) assert await r.unlink(*d.keys()) == 0 @skip_if_redis_enterprise() @@ -770,7 +764,7 @@ async def test_cluster_addslotsrange(self, r): @skip_if_redis_enterprise() async def test_cluster_countkeysinslot(self, r): - node = r.nodes_manager.get_node_from_slot(1) + node = await r.nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) assert await r.cluster_countkeysinslot(1) == 2 @@ -794,6 +788,8 @@ async def test_cluster_delslots(self): assert node0.redis_connection.connection.read_response.called assert node1.redis_connection.connection.read_response.called + await r.close() + @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() async def test_cluster_delslotsrange(self, r): @@ -930,7 +926,7 @@ async def test_cluster_save_config(self, r): @skip_if_redis_enterprise() async def test_cluster_get_keys_in_slot(self, r): response = [b"{foo}1", b"{foo}2"] - node = r.nodes_manager.get_node_from_slot(12182) + node = await r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = await r.cluster_get_keys_in_slot(12182, 4) assert keys == response @@ -956,7 +952,7 @@ async def test_cluster_setslot(self, r): await r.cluster_failover(node, "STATE") async def test_cluster_setslot_stable(self, r): - node = r.nodes_manager.get_node_from_slot(12182) + node = await r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @@ -1003,6 +999,8 @@ async def test_readonly(self): for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called + await r.close() + @skip_if_redis_enterprise() async def test_readwrite(self): r = await get_mocked_redis_client(host=default_host, port=default_port) @@ -1014,10 +1012,12 @@ async def test_readwrite(self): for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called + await r.close() + @skip_if_redis_enterprise() async def test_bgsave(self, r): assert await r.bgsave() - sleep(0.3) + await asyncio.sleep(0.3) assert await r.bgsave(True) async def test_info(self, r): @@ -1027,7 +1027,7 @@ async def test_info(self, r): await r.set("z{1}", 3) # Get node that handles the slot slot = r.keyslot("x{1}") - node = r.nodes_manager.get_node_from_slot(slot) + node = await r.nodes_manager.get_node_from_slot(slot) # Run info on that node info = await r.info(target_nodes=node) assert isinstance(info, dict) @@ -1085,7 +1085,7 @@ async def test_slowlog_get_limit(self, r, slowlog): async def test_slowlog_length(self, r, slowlog): await r.get("foo") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = await r.nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = await r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) @@ -1111,7 +1111,7 @@ async def test_memory_stats(self, r): # put a key into the current db to make sure that "db." # has data await r.set("foo", "bar") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = await r.nodes_manager.get_node_from_slot(key_slot(b"foo")) stats = await r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): @@ -1864,15 +1864,11 @@ async def test_cluster_randomkey(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() - async def test_acl_log(self, r, request): + async def test_acl_log(self, r, request, create_redis): key = "{cache}:" node = r.get_node_from_key(key) username = "redis-py-user" - async def teardown(): - await r.acl_deluser(username, target_nodes="primaries") - - request.addfinalizer(teardown) await r.acl_setuser( username, enabled=True, @@ -1884,8 +1880,8 @@ async def teardown(): ) await r.acl_log_reset(target_nodes=node) - user_client = _get_client( - RedisCluster, request, flushdb=False, username=username + user_client = await create_redis( + cls=RedisCluster, flushdb=False, username=username ) # Valid operation and key @@ -1907,6 +1903,10 @@ async def teardown(): assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] assert await r.acl_log_reset(target_nodes=node) + yield + + await r.acl_deluser(username, target_nodes="primaries") + @pytest.mark.onlycluster class TestNodesManager: @@ -1957,12 +1957,13 @@ async def test_init_slots_cache_not_all_slots_covered(self): [10923, 16383, ["127.0.0.1", 7002], ["127.0.0.1", 7005]], ] with pytest.raises(RedisClusterException) as ex: - await get_mocked_redis_client( + rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_slots=cluster_slots, require_full_coverage=True, ) + await rc.close() assert str(ex.value).startswith( "All slots are not covered after query all startup_nodes." ) @@ -1988,6 +1989,8 @@ async def test_init_slots_cache_not_require_full_coverage_success(self): assert 5460 not in rc.nodes_manager.slots_cache + await rc.close() + async def test_init_slots_cache(self): """ Test that slots cache can in initialized and all slots are covered @@ -2017,16 +2020,19 @@ async def test_init_slots_cache(self): assert len(n_manager.nodes_cache) == 6 + await rc.close() + async def test_init_slots_cache_cluster_mode_disabled(self): """ Test that creating a RedisCluster failes if one of the startup nodes has cluster mode disabled """ with pytest.raises(RedisClusterException) as e: - await get_mocked_redis_client( + rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_enabled=False ) - assert "Cluster mode is not enabled on this node" in str(e.value) + await rc.close() + assert "Cluster mode is not enabled on this node" in str(e.value) async def test_empty_startup_nodes(self): """ @@ -2034,7 +2040,7 @@ async def test_empty_startup_nodes(self): specified """ with pytest.raises(RedisClusterException): - NodesManager([]) + await NodesManager([]).initialize() async def test_wrong_startup_nodes_type(self): """ @@ -2042,7 +2048,7 @@ async def test_wrong_startup_nodes_type(self): fail """ with pytest.raises(RedisClusterException): - NodesManager({}) + await NodesManager({}).initialize() async def test_init_slots_cache_slots_collision(self, request): """ @@ -2050,9 +2056,9 @@ async def test_init_slots_cache_slots_collision(self, request): raise an error. In this test both nodes will say that the first slots block should be bound to different servers. """ - with patch.object(NodesManager, "create_redis_node") as create_redis_node: + with mock.patch.object(NodesManager, "create_redis_node") as create_redis_node: - def create_mocked_redis_node(host, port, **kwargs): + async def create_mocked_redis_node(host, port, **kwargs): """ Helper function to return custom slots cache data from different redis nodes @@ -2075,7 +2081,7 @@ def create_mocked_redis_node(host, port, **kwargs): orig_execute_command = r_node.execute_command - def execute_command(*args, **kwargs): + async def execute_command(*args, **kwargs): if args[0] == "CLUSTER SLOTS": return result elif args[0] == "INFO": @@ -2093,7 +2099,9 @@ def execute_command(*args, **kwargs): with pytest.raises(RedisClusterException) as ex: node_1 = ClusterNode("127.0.0.1", 7000) node_2 = ClusterNode("127.0.0.1", 7001) - RedisCluster(startup_nodes=[node_1, node_2]) + rc = RedisCluster(startup_nodes=[node_1, node_2]) + await rc.initialize() + await rc.close() assert str(ex.value).startswith( "startup_nodes could not agree on a valid slots cache" ), str(ex.value) @@ -2119,20 +2127,22 @@ async def test_cluster_one_instance(self): for i in range(0, REDIS_CLUSTER_HASH_SLOTS): assert n.slots_cache[i] == [n_node] + await rc.close() + async def test_init_with_down_node(self): """ If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. """ - with patch.object(NodesManager, "create_redis_node") as create_redis_node: + with mock.patch.object(NodesManager, "create_redis_node") as create_redis_node: - def create_mocked_redis_node(host, port, **kwargs): + async def create_mocked_redis_node(host, port, **kwargs): if port == 7000: raise ConnectionError("mock connection error for 7000") r_node = Redis(host=host, port=port, decode_responses=True) - def execute_command(*args, **kwargs): + async def execute_command(*args, **kwargs): if args[0] == "CLUSTER SLOTS": return [ [0, 8191, ["127.0.0.1", 7001, "node_1"]], @@ -2155,10 +2165,12 @@ def execute_command(*args, **kwargs): # If all startup nodes fail to connect, connection error should be # thrown with pytest.raises(RedisClusterException) as e: - RedisCluster(startup_nodes=[node_1]) + rc = RedisCluster(startup_nodes=[node_1]) + await rc.initialize() + await rc.close() assert "Redis Cluster cannot be connected" in str(e.value) - with patch.object( + with mock.patch.object( CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: @@ -2178,5 +2190,7 @@ def cmd_init_mock(self, r): # When at least one startup node is reachable, the cluster # initialization should succeeds rc = RedisCluster(startup_nodes=[node_1, node_2]) + await rc.initialize() assert rc.get_node(host=default_host, port=7001) is not None assert rc.get_node(host=default_host, port=7002) is not None + await rc.close() diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 9bb7480a42..e0be541e4a 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -4,11 +4,17 @@ import binascii import datetime import re +import sys import time from string import ascii_letters import pytest +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + import redis from redis import exceptions from redis.client import parse_info @@ -21,10 +27,10 @@ REDIS_6_VERSION = "5.9.0" -pytestmark = [pytest.mark.asyncio, pytest.mark.onlynoncluster] +pytestmark = pytest.mark.asyncio -@pytest.fixture() +@pytest_asyncio.fixture() async def slowlog(r: redis.Redis, event_loop): current_config = await r.config_get() old_slower_than_value = current_config["slowlog-log-slower-than"] @@ -1923,7 +1929,8 @@ async def test_hmget(self, r: redis.Redis): async def test_hmset(self, r: redis.Redis): warning_message = ( - r"^Redis\.hmset\(\) is deprecated\. " r"Use Redis\.hset\(\) instead\.$" + r"^Redis(?:Cluster)*\.hmset\(\) is deprecated\. " + r"Use Redis(?:Cluster)*\.hset\(\) instead\.$" ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} with pytest.warns(DeprecationWarning, match=warning_message): @@ -2929,6 +2936,7 @@ async def test_xtrim(self, r: redis.Redis): # 1 message is trimmed assert await r.xtrim(stream, 3, approximate=False) == 1 + @pytest.mark.onlynoncluster async def test_bitfield_operations(self, r: redis.Redis): # comments show affected bits await r.execute_command("SELECT", 10) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 46abec01d6..f6259adbd2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -51,14 +51,12 @@ def inner(): # assert mod.get('fookey') == d -@pytest.mark.onlynoncluster async def test_socket_param_regression(r): """A regression test for issue #1060""" conn = UnixDomainSocketConnection() _ = await conn.disconnect() is True -@pytest.mark.onlynoncluster async def test_can_run_concurrent_commands(r): assert await r.ping() is True assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index b5a5e77f1f..2cd9480c08 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -20,6 +20,7 @@ pytestmark = pytest.mark.asyncio +@pytest.mark.onlynoncluster class TestRedisAutoReleaseConnectionPool: @pytest_asyncio.fixture async def r(self, create_redis) -> redis.Redis: @@ -112,7 +113,6 @@ async def can_read(self, timeout: float = 0): return False -@pytest.mark.onlynoncluster class TestConnectionPool: def get_pool( self, @@ -189,7 +189,6 @@ def test_repr_contains_db_info_unix(self): assert repr(pool) == expected -@pytest.mark.onlynoncluster class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} @@ -296,7 +295,6 @@ def test_repr_contains_db_info_unix(self): assert repr(pool) == expected -@pytest.mark.onlynoncluster class TestConnectionPoolURLParsing: def test_hostname(self): pool = redis.ConnectionPool.from_url("redis://my.host") @@ -439,7 +437,6 @@ def test_invalid_scheme_raises_error(self): ) -@pytest.mark.onlynoncluster class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") @@ -508,7 +505,6 @@ def test_extra_querystring_options(self): assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} -@pytest.mark.onlynoncluster class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") @@ -538,7 +534,6 @@ def get_connection(self, *args, **kwargs): assert pool.get_connection("_").check_hostname is True -@pytest.mark.onlynoncluster class TestConnection: async def test_on_connect_error(self): """ diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index e83e001847..6e277ae38f 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -19,7 +19,6 @@ def compute(self, failures): return 0 -@pytest.mark.onlynoncluster class TestConnectionConstructorWithRetry: "Test that the Connection constructors properly handles Retry objects" @@ -41,7 +40,6 @@ def test_retry_on_timeout_retry(self, Class, retries: int): assert c.retry._retries == retries -@pytest.mark.onlynoncluster class TestRetry: "Test that Retry calls backoff and retries the expected number of times" diff --git a/tox.ini b/tox.ini index a880da45a5..4641ec3638 100644 --- a/tox.ini +++ b/tox.ini @@ -291,7 +291,7 @@ commands = standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs} standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} {posargs} - cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} + cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --uvloop {posargs} [testenv:redis5] deps = From 921152bcb22dcaadd4e7d6acd255049305dbafb6 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Thu, 14 Apr 2022 16:45:20 +0530 Subject: [PATCH 04/23] cluster: use ERRORS_ALLOW_RETRY from self.__class__ --- redis/asyncio/cluster.py | 2 +- redis/cluster.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7651cbebff..f489cdcdbb 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -565,7 +565,7 @@ async def execute_command(self, *args, **kwargs): # Return the processed result return self._process_result(args[0], res, **kwargs) except BaseException as e: - if type(e) in AbstractRedisCluster.ERRORS_ALLOW_RETRY: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e diff --git a/redis/cluster.py b/redis/cluster.py index 9ede6b6eb9..b61d384791 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1035,7 +1035,7 @@ def execute_command(self, *args, **kwargs): # Return the processed result return self._process_result(args[0], res, **kwargs) except BaseException as e: - if type(e) in AbstractRedisCluster.ERRORS_ALLOW_RETRY: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. exception = e From 48c3530dd4f65f2757424618657ffc8e4cc9e221 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 15 Apr 2022 06:52:41 +0530 Subject: [PATCH 05/23] async_cluster: rework redis_connection, initialize, & close - move redis_connection from NodesManager to ClusterNode & handle all related logic in ClusterNode class - use Locks while initializing or closing - in case of error, close connections instead of instantly reinitializing - create ResourceWarning instead of manually deleting client object - use asyncio.gather to run commands/initialize/close in parallel - inline single use functions - fix test_acl_log for py3.6 --- redis/asyncio/cluster.py | 321 ++++++++++++++--------------- tests/test_asyncio/test_cluster.py | 44 ++-- 2 files changed, 172 insertions(+), 193 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index f489cdcdbb..b21de3a68a 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,8 +1,8 @@ import asyncio -import copy import logging import random import socket +import warnings from redis.asyncio.client import Redis from redis.asyncio.connection import ConnectionPool, DefaultParser, Encoder, parse_url @@ -40,12 +40,6 @@ log = logging.getLogger(__name__) -async def get_connection(redis_node, *args, **options): - return redis_node.connection or await redis_node.connection_pool.get_connection( - args[0], **options - ) - - class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( DefaultParser.EXCEPTION_CLASSES, @@ -222,7 +216,6 @@ def __init__( self.read_from_replicas = read_from_replicas self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - self.nodes_manager = None self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, @@ -238,13 +231,28 @@ def __init__( self._initialize = True self._lock = asyncio.Lock() - async def initialize(self, force=False): - if self._initialize or force: - self._initialize = False - await self.nodes_manager.initialize() - await self.commands_parser.initialize(self) + async def initialize(self): + if self._initialize: + async with self._lock: + if self._initialize: + self._initialize = False + try: + await self.nodes_manager.initialize() + await self.commands_parser.initialize(self) + except BaseException: + self._initialize = True + await self.nodes_manager.close() + await self.nodes_manager.close("startup_nodes") + raise return self + async def close(self): + if not self._initialize: + async with self._lock: + if not self._initialize: + self._initialize = True + await self.nodes_manager.close() + async def __aenter__(self): return await self.initialize() @@ -254,16 +262,19 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __await__(self): return self.initialize().__await__() - def __del__(self): - try: - loop = asyncio.get_event_loop() - coro = self.close() - if loop.is_running(): - loop.create_task(coro) - else: - loop.run_until_complete(coro) - except Exception: - pass + _DEL_MESSAGE = "Unclosed RedisCluster client" + + def __del__(self, _warnings=warnings): + if hasattr(self, "_initialize") and not self._initialize: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... async def on_connect(self, connection): """ @@ -283,13 +294,6 @@ async def on_connect(self, connection): if str_if_bytes(await connection.read_response()) != "OK": raise ConnectionError("READONLY command failed") - async def get_redis_connection(self, node): - if not node.redis_connection: - async with self._lock: - if not node.redis_connection: - await self.nodes_manager.create_redis_connections([node]) - return node.redis_connection - def get_node(self, host=None, port=None, node_name=None): return self.nodes_manager.get_node(host, port, node_name) @@ -388,16 +392,6 @@ async def _determine_nodes(self, *args, **kwargs): log.debug(f"Target for {args}: slot {slot}") return [node] - def _should_reinitialized(self): - # To reinitialize the cluster on every MOVED error, - # set reinitialize_steps to 1. - # To avoid reinitializing the cluster on moved errors, set - # reinitialize_steps to 0. - if self.reinitialize_steps == 0: - return False - else: - return self.reinitialize_counter % self.reinitialize_steps == 0 - def keyslot(self, key): """ Calculate keyslot for a given key. @@ -406,21 +400,6 @@ def keyslot(self, key): k = self.encoder.encode(key) return key_slot(k) - async def _get_command_keys(self, *args): - """ - Get the keys in the command. If the command has no keys in in, None is - returned. - - NOTE: Due to a bug in redis<7.0, this function does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this function with EVAL or EVALSHA. - """ - redis_conn = self.get_default_node().redis_connection - return await self.commands_parser.get_keys(redis_conn, *args) - async def determine_slot(self, *args): """ Figure out what slot to use based on args. @@ -440,6 +419,8 @@ async def determine_slot(self, *args): # redis server to parse the keys. Besides, there is a bug in redis<7.0 # where `self._get_command_keys()` fails anyway. So, we special case # EVAL/EVALSHA. + # - issue: https://github.com/redis/redis/issues/9493 + # - fix: https://github.com/redis/redis/pull/9733 if command in ("EVAL", "EVALSHA"): # command syntax: EVAL "script body" num_keys ... if len(args) <= 2: @@ -452,7 +433,10 @@ async def determine_slot(self, *args): return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) keys = eval_keys else: - keys = await self._get_command_keys(*args) + redis_connection = await self.get_default_node().initialize( + **self.get_connection_kwargs() + ) + keys = await self.commands_parser.get_keys(redis_connection, *args) if keys is None or len(keys) == 0: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot @@ -529,7 +513,6 @@ async def execute_command(self, *args, **kwargs): list dict """ - await self.initialize() target_nodes_specified = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) @@ -549,8 +532,8 @@ async def execute_command(self, *args, **kwargs): ) exception = None for _ in range(0, retry_attempts): + await self.initialize() try: - res = {} if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = await self._determine_nodes( @@ -560,10 +543,16 @@ async def execute_command(self, *args, **kwargs): raise RedisClusterException( f"No targets were found to execute {args} command on" ) - for node in target_nodes: - res[node.name] = await self._execute_command(node, *args, **kwargs) + + keys = [node.name for node in target_nodes] + values = await asyncio.gather( + *[ + self._execute_command(node, *args, **kwargs) + for node in target_nodes + ] + ) # Return the processed result - return self._process_result(args[0], res, **kwargs) + return self._process_result(args[0], dict(zip(keys, values)), **kwargs) except BaseException as e: if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. @@ -582,7 +571,7 @@ async def _execute_command(self, target_node, *args, **kwargs): Send a command to a node in the cluster """ command = args[0] - redis_node = None + redis_connection = None connection = None redirect_addr = None asking = False @@ -608,15 +597,25 @@ async def _execute_command(self, target_node, *args, **kwargs): f"Executing command {command} on target node: " f"{target_node.server_type} {target_node.name}" ) - redis_node = await self.get_redis_connection(target_node) - connection = await get_connection(redis_node, *args, **kwargs) + redis_connection = await target_node.initialize( + **self.get_connection_kwargs() + ) + connection = ( + redis_connection.connection + or await redis_connection.connection_pool.get_connection( + command, **kwargs + ) + ) + if asking: await connection.send_command("ASKING") - await redis_node.parse_response(connection, "ASKING", **kwargs) + await redis_connection.parse_response( + connection, "ASKING", **kwargs + ) asking = False await connection.send_command(*args) - response = await redis_node.parse_response( + response = await redis_connection.parse_response( connection, command, **kwargs ) if command in self.cluster_response_callbacks: @@ -646,7 +645,7 @@ async def _execute_command(self, target_node, *args, **kwargs): else: # Hard force of reinitialize of the node/slots setup # and try again with the new setup - await self.initialize(force=True) + await self.close() raise except MovedError as e: # First, we will try to patch the slots/nodes cache with the @@ -659,12 +658,15 @@ async def _execute_command(self, target_node, *args, **kwargs): # RedisCluster constructor. log.exception("MovedError") self.reinitialize_counter += 1 - if self._should_reinitialized(): - await self.initialize(force=True) + if ( + self.reinitialize_steps + and self.reinitialize_counter % self.reinitialize_steps == 0 + ): + await self.close() # Reset the counter self.reinitialize_counter = 0 else: - self.nodes_manager.update_moved_exception(e) + self.nodes_manager._moved_exception = e moved = True except TryAgainError: log.exception("TryAgainError") @@ -682,7 +684,7 @@ async def _execute_command(self, target_node, *args, **kwargs): # self-healed, we will try to reinitialize the cluster layout # and retry executing the command await asyncio.sleep(0.25) - await self.initialize(force=True) + await self.close() raise e except ResponseError as e: message = e.__str__() @@ -695,19 +697,10 @@ async def _execute_command(self, target_node, *args, **kwargs): raise e finally: if connection is not None: - await redis_node.connection_pool.release(connection) + await redis_connection.connection_pool.release(connection) raise ClusterError("TTL exhausted.") - async def close(self): - try: - async with self._lock: - if self.nodes_manager: - await self.nodes_manager.close() - except AttributeError: - # RedisCluster's __init__ can fail before nodes_manager is set - pass - def _process_result(self, command, res, **kwargs): """ Process the result of the executed command. @@ -739,6 +732,7 @@ def __init__(self, host, port, server_type=None, redis_connection=None): self.name = get_node_name(host, port) self.server_type = server_type self.redis_connection = redis_connection + self._lock = asyncio.Lock() def __repr__(self): return ( @@ -752,41 +746,58 @@ def __repr__(self): def __eq__(self, obj): return isinstance(obj, ClusterNode) and obj.name == self.name - def __del__(self): - try: - if self.redis_connection is not None: - loop = asyncio.get_event_loop() - coro = self.redis_connection.close(True) - if loop.is_running(): - loop.create_task(coro) - else: - loop.run_until_complete(coro) - except Exception: - pass + _DEL_MESSAGE = "Unclosed ClusterNode object" + + def __del__(self, _warnings=warnings): + if hasattr(self, "redis_connection") and self.redis_connection: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + + async def initialize(self, from_url=False, **kwargs): + if not self.redis_connection: + async with self._lock: + if not self.redis_connection: + if from_url: + # Create a redis node with a costumed connection pool + kwargs.update(host=self.host, port=self.port) + conn = Redis(connection_pool=ConnectionPool(**kwargs)) + else: + conn = Redis(host=self.host, port=self.port, **kwargs) + + self.redis_connection = await conn.initialize() + + return self.redis_connection + + async def close(self): + if self.redis_connection: + async with self._lock: + if self.redis_connection: + conn = self.redis_connection + self.redis_connection = None + await conn.close(True) class NodesManager: def __init__( - self, - startup_nodes, - from_url=False, - require_full_coverage=False, - lock=None, - **kwargs, + self, startup_nodes, from_url=False, require_full_coverage=False, **kwargs ): self.nodes_cache = {} self.slots_cache = {} - self.startup_nodes = {} + self.startup_nodes = {node.name: node for node in startup_nodes} self.default_node = None - self.populate_startup_nodes(startup_nodes) self.from_url = from_url self._require_full_coverage = require_full_coverage self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() - if lock is None: - lock = asyncio.Lock() - self._lock = lock + self._lock = asyncio.Lock() def get_node(self, host=None, port=None, node_name=None): """ @@ -809,10 +820,17 @@ def get_node(self, host=None, port=None, node_name=None): ) return None - def update_moved_exception(self, exception): - self._moved_exception = exception + async def set_nodes(self, old, new): + tasks = [node.close() for name, node in old.items() if name not in new] + for name, node in new.items(): + if name in old: + if old[name] is node: + continue + tasks.append(old[name].close()) + old[name] = node + await asyncio.gather(*tasks) - def _update_moved_slots(self): + async def _update_moved_slots(self): """ Update the slot's node with the redirected one """ @@ -826,7 +844,9 @@ def _update_moved_slots(self): else: # This is a new node, we will add it to the nodes cache redirected_node = ClusterNode(e.host, e.port, PRIMARY) - self.nodes_cache[redirected_node.name] = redirected_node + await self.set_nodes( + self.nodes_cache, {redirected_node.name: redirected_node} + ) if redirected_node in self.slots_cache[e.slot_id]: # The MOVED error resulted from a failover, and the new slot owner # had previously been a replica. @@ -860,7 +880,7 @@ async def get_node_from_slot( if self._moved_exception: async with self._lock: if self._moved_exception: - self._update_moved_slots() + await self._update_moved_slots() if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: raise SlotNotCoveredError( @@ -900,13 +920,6 @@ def get_nodes_by_server_type(self, server_type): if node.server_type == server_type ] - def populate_startup_nodes(self, nodes): - """ - Populate all startup nodes and filters out any duplicates - """ - for n in nodes: - self.startup_nodes[n.name] = n - def check_slots_coverage(self, slots_cache): # Validate if all slots are covered or if we should try next # startup node @@ -915,27 +928,6 @@ def check_slots_coverage(self, slots_cache): return False return True - async def create_redis_connections(self, nodes): - """ - This function will create a redis connection to all nodes in :nodes: - """ - for node in nodes: - if node.redis_connection is None: - node.redis_connection = await self.create_redis_node( - host=node.host, port=node.port, **self.connection_kwargs - ) - - async def create_redis_node(self, host, port, **kwargs): - if self.from_url: - # Create a redis node with a costumed connection pool - kwargs.update({"host": host}) - kwargs.update({"port": port}) - r = Redis(connection_pool=ConnectionPool(**kwargs)) - else: - r = Redis(host=host, port=port, **kwargs) - await r.initialize() - return r - async def initialize(self): """ Initializes the nodes cache, slots cache and redis connections. @@ -949,35 +941,20 @@ async def initialize(self): disagreements = [] startup_nodes_reachable = False fully_covered = False - kwargs = self.connection_kwargs for startup_node in self.startup_nodes.values(): try: - if startup_node.redis_connection: - r = startup_node.redis_connection - else: - # Create a new Redis connection and let Redis decode the - # responses so we won't need to handle that - # TODO: redis_connect_func shouldn't need to be removed & readded - redis_connect_func = kwargs.pop("redis_connect_func") - copy_kwargs = copy.deepcopy(kwargs) - kwargs.setdefault("redis_connect_func", redis_connect_func) - copy_kwargs.update( - { - "decode_responses": True, - "encoding": "utf-8", - "redis_connect_func": redis_connect_func, - } - ) - r = await self.create_redis_node( - startup_node.host, startup_node.port, **copy_kwargs - ) - self.startup_nodes[startup_node.name].redis_connection = r + redis_connection = await startup_node.initialize( + **self.connection_kwargs + ) + # Make sure cluster mode is enabled on this node - if bool((await r.info()).get("cluster_enabled")) is False: + if not (await redis_connection.info()).get("cluster_enabled"): raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = str_if_bytes(await r.execute_command("CLUSTER SLOTS")) + cluster_slots = str_if_bytes( + await redis_connection.execute_command("CLUSTER SLOTS") + ) startup_nodes_reachable = True except (ConnectionError, TimeoutError) as e: msg = e.__str__ @@ -1022,6 +999,8 @@ async def initialize(self): cluster_slots[0][2][0] = startup_node.host for slot in cluster_slots: + for i in range(2, len(slot)): + slot[i] = [str_if_bytes(val) for val in slot[i]] primary_node = slot[2] host = primary_node[0] if host == "": @@ -1081,9 +1060,6 @@ async def initialize(self): "one reachable node. " ) - # Create Redis connections to all nodes - await self.create_redis_connections(list(tmp_nodes_cache.values())) - # Check if the slots are not fully covered if not fully_covered and self._require_full_coverage: # Despite the requirement that the slots be covered, there @@ -1095,20 +1071,27 @@ async def initialize(self): ) # Set the tmp variables to the real variables - self.nodes_cache = tmp_nodes_cache self.slots_cache = tmp_slots + await self.set_nodes(self.nodes_cache, tmp_nodes_cache) + # Populate the startup nodes with all discovered nodes + await self.set_nodes(self.startup_nodes, self.nodes_cache) + + # Create Redis connections to all nodes + await asyncio.gather( + *[ + node.initialize(**self.connection_kwargs) + for node in self.nodes_cache.values() + ] + ) + # Set the default node self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] - # Populate the startup nodes with all discovered nodes - self.populate_startup_nodes(self.nodes_cache.values()) # If initialize was called after a MovedError, clear it self._moved_exception = None - async def close(self): + async def close(self, attr="nodes_cache"): self.default_node = None - for node in self.nodes_cache.values(): - if node.redis_connection: - await node.redis_connection.close(True) + await asyncio.gather(*[node.close() for node in getattr(self, attr).values()]) def reset(self): try: diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 7da05cf6da..b5954c7d66 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1903,10 +1903,10 @@ async def test_acl_log(self, r, request, create_redis): assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] assert await r.acl_log_reset(target_nodes=node) - yield - await r.acl_deluser(username, target_nodes="primaries") + await user_client.close() + @pytest.mark.onlycluster class TestNodesManager: @@ -2056,20 +2056,20 @@ async def test_init_slots_cache_slots_collision(self, request): raise an error. In this test both nodes will say that the first slots block should be bound to different servers. """ - with mock.patch.object(NodesManager, "create_redis_node") as create_redis_node: + with mock.patch.object(ClusterNode, "initialize", autospec=True) as initialize: - async def create_mocked_redis_node(host, port, **kwargs): + async def mocked_initialize(self, **kwargs): """ Helper function to return custom slots cache data from different redis nodes """ - if port == 7000: + if self.port == 7000: result = [ [0, 5460, ["127.0.0.1", 7000], ["127.0.0.1", 7003]], [5461, 10922, ["127.0.0.1", 7001], ["127.0.0.1", 7004]], ] - elif port == 7001: + elif self.port == 7001: result = [ [0, 5460, ["127.0.0.1", 7001], ["127.0.0.1", 7003]], [5461, 10922, ["127.0.0.1", 7000], ["127.0.0.1", 7004]], @@ -2077,7 +2077,7 @@ async def create_mocked_redis_node(host, port, **kwargs): else: result = [] - r_node = Redis(host=host, port=port) + r_node = Redis(host=self.host, port=self.port) orig_execute_command = r_node.execute_command @@ -2094,14 +2094,13 @@ async def execute_command(*args, **kwargs): r_node.execute_command = execute_command return r_node - create_redis_node.side_effect = create_mocked_redis_node + initialize.side_effect = mocked_initialize with pytest.raises(RedisClusterException) as ex: node_1 = ClusterNode("127.0.0.1", 7000) node_2 = ClusterNode("127.0.0.1", 7001) - rc = RedisCluster(startup_nodes=[node_1, node_2]) - await rc.initialize() - await rc.close() + async with RedisCluster(startup_nodes=[node_1, node_2]): + ... assert str(ex.value).startswith( "startup_nodes could not agree on a valid slots cache" ), str(ex.value) @@ -2134,13 +2133,13 @@ async def test_init_with_down_node(self): If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. """ - with mock.patch.object(NodesManager, "create_redis_node") as create_redis_node: + with mock.patch.object(ClusterNode, "initialize", autospec=True) as initialize: - async def create_mocked_redis_node(host, port, **kwargs): - if port == 7000: + async def mocked_initialize(self, **kwargs): + if self.port == 7000: raise ConnectionError("mock connection error for 7000") - r_node = Redis(host=host, port=port, decode_responses=True) + r_node = Redis(host=self.host, port=self.port, decode_responses=True) async def execute_command(*args, **kwargs): if args[0] == "CLUSTER SLOTS": @@ -2157,7 +2156,7 @@ async def execute_command(*args, **kwargs): return r_node - create_redis_node.side_effect = create_mocked_redis_node + initialize.side_effect = mocked_initialize node_1 = ClusterNode("127.0.0.1", 7000) node_2 = ClusterNode("127.0.0.1", 7001) @@ -2165,9 +2164,8 @@ async def execute_command(*args, **kwargs): # If all startup nodes fail to connect, connection error should be # thrown with pytest.raises(RedisClusterException) as e: - rc = RedisCluster(startup_nodes=[node_1]) - await rc.initialize() - await rc.close() + async with RedisCluster(startup_nodes=[node_1]): + ... assert "Redis Cluster cannot be connected" in str(e.value) with mock.patch.object( @@ -2189,8 +2187,6 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock # When at least one startup node is reachable, the cluster # initialization should succeeds - rc = RedisCluster(startup_nodes=[node_1, node_2]) - await rc.initialize() - assert rc.get_node(host=default_host, port=7001) is not None - assert rc.get_node(host=default_host, port=7002) is not None - await rc.close() + async with RedisCluster(startup_nodes=[node_1, node_2]) as rc: + assert rc.get_node(host=default_host, port=7001) is not None + assert rc.get_node(host=default_host, port=7002) is not None From 1707a5a7ee665d7d01f00a703bde7aea8381753b Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 15 Apr 2022 14:25:40 +0530 Subject: [PATCH 06/23] async_cluster: add types --- redis/asyncio/cluster.py | 174 ++++++++++----- redis/asyncio/parser.py | 20 +- redis/commands/cluster.py | 151 ++++++++----- redis/crc.py | 4 +- redis/typing.py | 12 +- tests/conftest.py | 11 +- tests/test_asyncio/test_cluster.py | 337 ++++++++++++++++------------- whitelist.py | 1 + 8 files changed, 444 insertions(+), 266 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index b21de3a68a..ccd2f6671f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,9 +3,16 @@ import random import socket import warnings +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from redis.asyncio.client import Redis -from redis.asyncio.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + DefaultParser, + Encoder, + parse_url, +) from redis.asyncio.parser import CommandsParser from redis.client import CaseInsensitiveDict from redis.cluster import ( @@ -35,10 +42,14 @@ TimeoutError, TryAgainError, ) +from redis.typing import EncodableT, KeyT from redis.utils import dict_merge, str_if_bytes log = logging.getLogger(__name__) +_RedisClusterT = TypeVar("_RedisClusterT", bound="RedisCluster") +_ClusterNodeT = TypeVar("_ClusterNodeT", bound="ClusterNode") + class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( @@ -56,7 +67,7 @@ class ClusterParser(DefaultParser): class RedisCluster(AbstractRedisCluster, AsyncRedisClusterCommands): @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls, url: str, **kwargs) -> _RedisClusterT: """ Return a Redis client object configured from the given URL @@ -99,18 +110,34 @@ class initializer. In the case of conflicting arguments, querystring """ return cls(url=url, **kwargs) + __slots__ = ( + "_initialize", + "_lock", + "cluster_error_retry_attempts", + "cluster_response_callbacks", + "command_flags", + "commands_parser", + "encoder", + "node_flags", + "nodes_manager", + "read_from_replicas", + "reinitialize_counter", + "reinitialize_steps", + "result_callbacks", + ) + def __init__( self, - host=None, - port=6379, - startup_nodes=None, - cluster_error_retry_attempts=3, - require_full_coverage=False, - reinitialize_steps=10, - read_from_replicas=False, - url=None, + host: Optional[str] = None, + port: int = 6379, + startup_nodes: Optional[List[_ClusterNodeT]] = None, + cluster_error_retry_attempts: int = 3, + require_full_coverage: bool = False, + reinitialize_steps: int = 10, + read_from_replicas: bool = False, + url: Optional[str] = None, **kwargs, - ): + ) -> None: """ Initialize a new RedisCluster client. @@ -231,7 +258,7 @@ def __init__( self._initialize = True self._lock = asyncio.Lock() - async def initialize(self): + async def initialize(self) -> _RedisClusterT: if self._initialize: async with self._lock: if self._initialize: @@ -246,17 +273,17 @@ async def initialize(self): raise return self - async def close(self): + async def close(self) -> None: if not self._initialize: async with self._lock: if not self._initialize: self._initialize = True await self.nodes_manager.close() - async def __aenter__(self): + async def __aenter__(self) -> _RedisClusterT: return await self.initialize() - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: await self.close() def __await__(self): @@ -276,7 +303,7 @@ def __del__(self, _warnings=warnings): except RuntimeError: ... - async def on_connect(self, connection): + async def on_connect(self, connection: Connection) -> None: """ Initialize the connection, authenticate and select a database and send READONLY if it is set during object initialization. @@ -294,22 +321,27 @@ async def on_connect(self, connection): if str_if_bytes(await connection.read_response()) != "OK": raise ConnectionError("READONLY command failed") - def get_node(self, host=None, port=None, node_name=None): + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional[_ClusterNodeT]: return self.nodes_manager.get_node(host, port, node_name) - def get_primaries(self): + def get_primaries(self) -> List[_ClusterNodeT]: return self.nodes_manager.get_nodes_by_server_type(PRIMARY) - def get_replicas(self): + def get_replicas(self) -> List[_ClusterNodeT]: return self.nodes_manager.get_nodes_by_server_type(REPLICA) - def get_random_node(self): + def get_random_node(self) -> _ClusterNodeT: return random.choice(list(self.nodes_manager.nodes_cache.values())) - def get_nodes(self): + def get_nodes(self) -> List[_ClusterNodeT]: return list(self.nodes_manager.nodes_cache.values()) - def get_node_from_key(self, key, replica=False): + def get_node_from_key(self, key: str, replica: bool = False) -> _ClusterNodeT: """ Get the node that holds the key's slot. If replica set to True but the slot doesn't have any replicas, None is @@ -329,13 +361,13 @@ def get_node_from_key(self, key, replica=False): return slot_cache[node_idx] - def get_default_node(self): + def get_default_node(self) -> _ClusterNodeT: """ Get the cluster's default node """ return self.nodes_manager.default_node - def set_default_node(self, node): + def set_default_node(self, node: Optional[_ClusterNodeT]) -> bool: """ Set the default node of the cluster. :param node: 'ClusterNode' @@ -351,11 +383,11 @@ def set_default_node(self, node): log.info(f"Changed the default cluster node to {node}") return True - def set_response_callback(self, command, callback): + def set_response_callback(self, command: KeyT, callback: Callable) -> None: """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - async def _determine_nodes(self, *args, **kwargs): + async def _determine_nodes(self, *args, **kwargs) -> List[_ClusterNodeT]: command = args[0] nodes_flag = kwargs.pop("nodes_flag", None) if nodes_flag is not None: @@ -392,7 +424,7 @@ async def _determine_nodes(self, *args, **kwargs): log.debug(f"Target for {args}: slot {slot}") return [node] - def keyslot(self, key): + def keyslot(self, key: Union[str, int, float, bytes]) -> int: """ Calculate keyslot for a given key. See Keys distribution model in https://redis.io/topics/cluster-spec @@ -400,7 +432,7 @@ def keyslot(self, key): k = self.encoder.encode(key) return key_slot(k) - async def determine_slot(self, *args): + async def determine_slot(self, *args) -> int: """ Figure out what slot to use based on args. @@ -462,22 +494,26 @@ async def determine_slot(self, *args): return slots.pop() - def get_encoder(self): + def get_encoder(self) -> Encoder: """ Get the connections' encoder """ return self.encoder - def get_connection_kwargs(self): + def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: """ Get the connections' key-word arguments """ return self.nodes_manager.connection_kwargs - def _is_nodes_flag(self, target_nodes): + def _is_nodes_flag( + self, target_nodes: Union[List[_ClusterNodeT], _ClusterNodeT, str] + ) -> bool: return isinstance(target_nodes, str) and target_nodes in self.node_flags - def _parse_target_nodes(self, target_nodes): + def _parse_target_nodes( + self, target_nodes: Union[List[_ClusterNodeT], _ClusterNodeT] + ) -> List[_ClusterNodeT]: if isinstance(target_nodes, list): nodes = target_nodes elif isinstance(target_nodes, ClusterNode): @@ -497,7 +533,7 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - async def execute_command(self, *args, **kwargs): + async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: """ Wrapper for ERRORS_ALLOW_RETRY error handling. @@ -566,7 +602,9 @@ async def execute_command(self, *args, **kwargs): # to caller of this method raise exception - async def _execute_command(self, target_node, *args, **kwargs): + async def _execute_command( + self, target_node: _ClusterNodeT, *args: Union[KeyT, EncodableT], **kwargs + ) -> Any: """ Send a command to a node in the cluster """ @@ -701,7 +739,7 @@ async def _execute_command(self, target_node, *args, **kwargs): raise ClusterError("TTL exhausted.") - def _process_result(self, command, res, **kwargs): + def _process_result(self, command: KeyT, res: Dict[str, Any], **kwargs) -> Any: """ Process the result of the executed command. The function would return a dict or a single value. @@ -723,7 +761,15 @@ def _process_result(self, command, res, **kwargs): class ClusterNode: - def __init__(self, host, port, server_type=None, redis_connection=None): + __slots__ = ("_lock", "host", "name", "port", "redis_connection", "server_type") + + def __init__( + self, + host: str, + port: int, + server_type: Optional[str] = None, + redis_connection: None = None, + ) -> None: if host == "localhost": host = socket.gethostbyname(host) @@ -734,7 +780,7 @@ def __init__(self, host, port, server_type=None, redis_connection=None): self.redis_connection = redis_connection self._lock = asyncio.Lock() - def __repr__(self): + def __repr__(self) -> str: return ( f"[host={self.host}," f"port={self.port}," @@ -743,7 +789,7 @@ def __repr__(self): f"redis_connection={self.redis_connection}]" ) - def __eq__(self, obj): + def __eq__(self, obj: _ClusterNodeT) -> bool: return isinstance(obj, ClusterNode) and obj.name == self.name _DEL_MESSAGE = "Unclosed ClusterNode object" @@ -760,7 +806,7 @@ def __del__(self, _warnings=warnings): except RuntimeError: ... - async def initialize(self, from_url=False, **kwargs): + async def initialize(self, from_url: bool = False, **kwargs) -> Redis: if not self.redis_connection: async with self._lock: if not self.redis_connection: @@ -775,7 +821,7 @@ async def initialize(self, from_url=False, **kwargs): return self.redis_connection - async def close(self): + async def close(self) -> None: if self.redis_connection: async with self._lock: if self.redis_connection: @@ -785,9 +831,26 @@ async def close(self): class NodesManager: + __slots__ = ( + "_lock", + "_moved_exception", + "_require_full_coverage", + "connection_kwargs", + "default_node", + "from_url", + "nodes_cache", + "read_load_balancer", + "slots_cache", + "startup_nodes", + ) + def __init__( - self, startup_nodes, from_url=False, require_full_coverage=False, **kwargs - ): + self, + startup_nodes: List[_ClusterNodeT], + from_url: bool = False, + require_full_coverage: bool = False, + **kwargs, + ) -> None: self.nodes_cache = {} self.slots_cache = {} self.startup_nodes = {node.name: node for node in startup_nodes} @@ -799,7 +862,12 @@ def __init__( self.read_load_balancer = LoadBalancer() self._lock = asyncio.Lock() - def get_node(self, host=None, port=None, node_name=None): + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional[_ClusterNodeT]: """ Get the requested node from the cluster's nodes. nodes. @@ -820,7 +888,9 @@ def get_node(self, host=None, port=None, node_name=None): ) return None - async def set_nodes(self, old, new): + async def set_nodes( + self, old: Dict[str, _ClusterNodeT], new: Dict[str, _ClusterNodeT] + ) -> None: tasks = [node.close() for name, node in old.items() if name not in new] for name, node in new.items(): if name in old: @@ -830,7 +900,7 @@ async def set_nodes(self, old, new): old[name] = node await asyncio.gather(*tasks) - async def _update_moved_slots(self): + async def _update_moved_slots(self) -> None: """ Update the slot's node with the redirected one """ @@ -872,8 +942,8 @@ async def _update_moved_slots(self): self._moved_exception = None async def get_node_from_slot( - self, slot, read_from_replicas=False, server_type=None - ): + self, slot: int, read_from_replicas: bool = False, server_type: None = None + ) -> _ClusterNodeT: """ Gets a node that servers this hash slot """ @@ -908,7 +978,7 @@ async def get_node_from_slot( return self.slots_cache[slot][node_idx] - def get_nodes_by_server_type(self, server_type): + def get_nodes_by_server_type(self, server_type: str) -> List[_ClusterNodeT]: """ Get all nodes with the specified server type :param server_type: 'primary' or 'replica' @@ -920,7 +990,7 @@ def get_nodes_by_server_type(self, server_type): if node.server_type == server_type ] - def check_slots_coverage(self, slots_cache): + def check_slots_coverage(self, slots_cache: Dict[int, List[_ClusterNodeT]]) -> bool: # Validate if all slots are covered or if we should try next # startup node for i in range(0, REDIS_CLUSTER_HASH_SLOTS): @@ -928,7 +998,7 @@ def check_slots_coverage(self, slots_cache): return False return True - async def initialize(self): + async def initialize(self) -> None: """ Initializes the nodes cache, slots cache and redis connections. :startup_nodes: @@ -1089,11 +1159,11 @@ async def initialize(self): # If initialize was called after a MovedError, clear it self._moved_exception = None - async def close(self, attr="nodes_cache"): + async def close(self, attr: str = "nodes_cache") -> None: self.default_node = None await asyncio.gather(*[node.close() for node in getattr(self, attr).values()]) - def reset(self): + def reset(self) -> None: try: self.read_load_balancer.reset() except TypeError: diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 0c82e4f2c6..7a84373c26 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,5 +1,13 @@ +from typing import TYPE_CHECKING, List, Optional, TypeVar, Union + +from redis.asyncio.client import Redis from redis.exceptions import RedisError, ResponseError +if TYPE_CHECKING: + from redis.asyncio.cluster import RedisCluster + +_RedisClusterT = TypeVar("_RedisClusterT", bound="RedisCluster") + class CommandsParser: """ @@ -10,10 +18,12 @@ class CommandsParser: 'COMMAND GETKEYS'. """ - def __init__(self): + __slots__ = ("commands",) + + def __init__(self) -> None: self.commands = {} - async def initialize(self, r): + async def initialize(self, r: _RedisClusterT) -> None: commands = await r.execute_command("COMMAND") uppercase_commands = [] for cmd in commands: @@ -26,7 +36,9 @@ async def initialize(self, r): # As soon as this PR is merged into Redis, we should reimplement # our logic to use COMMAND INFO changes to determine the key positions # https://github.com/redis/redis/pull/8324 - async def get_keys(self, redis_conn, *args): + async def get_keys( + self, redis_conn: Redis, *args + ) -> Optional[Union[List[str], List[bytes]]]: """ Get the keys from the passed command. @@ -80,7 +92,7 @@ async def get_keys(self, redis_conn, *args): return keys - async def _get_moveable_keys(self, redis_conn, *args): + async def _get_moveable_keys(self, redis_conn: Redis, *args) -> Optional[List[str]]: """ NOTE: Due to a bug in redis<7.0, this function does not work properly for EVAL or EVALSHA when the `numkeys` arg is 0. diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 5d9a8b3819..f85db27170 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -1,9 +1,31 @@ import asyncio -from typing import Iterator, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Mapping, + NoReturn, + Optional, + TypeVar, + Union, +) +from redis.compat import Literal from redis.crc import key_slot from redis.exceptions import RedisClusterException, RedisError -from redis.typing import PatternT +from redis.typing import ( + AnyKeyT, + ClusterCommandsProtocol, + EncodableT, + KeysT, + KeyT, + PatternT, +) from .core import ( ACLCommands, @@ -21,8 +43,15 @@ from .helpers import list_or_args from .redismodules import RedisModuleCommands +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + +_TargetNodesT = TypeVar( + "_TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] +) -class ClusterMultiKeyCommands: + +class ClusterMultiKeyCommands(ClusterCommandsProtocol): """ A class containing commands that handle more than one key """ @@ -173,12 +202,12 @@ def unlink(self, *keys): return self._split_command_across_slots("UNLINK", *keys) -class AsyncClusterMultiKeyCommands: +class AsyncClusterMultiKeyCommands(ClusterCommandsProtocol): """ A class containing commands that handle more than one key """ - def _partition_keys_by_slot(self, keys): + def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: """ Split keys into a dictionary that maps a slot to a list of keys. @@ -191,7 +220,7 @@ def _partition_keys_by_slot(self, keys): return slots_to_keys - async def mget_nonatomic(self, keys, *args): + async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -231,7 +260,7 @@ async def mget_nonatomic(self, keys, *args): vals_in_order = [all_results[key] for key in keys] return vals_in_order - async def mset_nonatomic(self, mapping): + async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -258,7 +287,7 @@ async def mset_nonatomic(self, mapping): *[self.execute_command("MSET", *pairs) for pairs in slots_to_pairs.values()] ) - async def _split_command_across_slots(self, command, *keys): + async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: """ Runs the given command once for the keys of each slot. Returns the sum of the return values. @@ -276,7 +305,7 @@ async def _split_command_across_slots(self, command, *keys): ) ) - def exists(self, *keys): + def exists(self, *keys: KeyT) -> Awaitable: """ Returns the number of ``names`` that exist in the whole cluster. The keys are first split up into slots @@ -286,7 +315,7 @@ def exists(self, *keys): """ return self._split_command_across_slots("EXISTS", *keys) - def delete(self, *keys): + def delete(self, *keys: KeyT) -> Awaitable: """ Deletes the given keys in the cluster. The keys are first split up into slots @@ -299,7 +328,7 @@ def delete(self, *keys): """ return self._split_command_across_slots("DEL", *keys) - def touch(self, *keys): + def touch(self, *keys: KeyT) -> Awaitable: """ Updates the last access time of given keys across the cluster. @@ -314,7 +343,7 @@ def touch(self, *keys): """ return self._split_command_across_slots("TOUCH", *keys) - def unlink(self, *keys): + def unlink(self, *keys: KeyT) -> Awaitable: """ Remove the specified keys in a different thread. @@ -370,7 +399,7 @@ class AsyncClusterManagementCommands(AsyncManagementCommands): required adjustments to work with cluster mode """ - def slaveof(self, *args, **kwargs): + def slaveof(self, *args, **kwargs) -> NoReturn: """ Make the server a replica of another instance, or promote it as master. @@ -378,7 +407,7 @@ def slaveof(self, *args, **kwargs): """ raise RedisClusterException("SLAVEOF is not supported in cluster mode") - def replicaof(self, *args, **kwargs): + def replicaof(self, *args, **kwargs) -> NoReturn: """ Make the server a replica of another instance, or promote it as master. @@ -386,7 +415,7 @@ def replicaof(self, *args, **kwargs): """ raise RedisClusterException("REPLICAOF is not supported in cluster" " mode") - def swapdb(self, *args, **kwargs): + def swapdb(self, *args, **kwargs) -> NoReturn: """ Swaps two Redis databases. @@ -496,16 +525,16 @@ class AsyncClusterDataAccessCommands(AsyncDataAccessCommands): def stralgo( self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Optional[int] = None, + withmatchlen: bool = False, **kwargs, - ): + ) -> Awaitable: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -543,11 +572,11 @@ def stralgo( async def scan_iter( self, - match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, **kwargs, - ) -> Iterator: + ) -> AsyncIterator: # Do the first query with cursor=0 for all nodes cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) for value in data: @@ -912,7 +941,7 @@ class AsyncRedisClusterCommands( r.cluster_info(target_nodes=RedisCluster.ALL_NODES) """ - def cluster_myid(self, target_node): + def cluster_myid(self, target_node: _TargetNodesT) -> Awaitable: """ Returns the node’s id. @@ -923,7 +952,9 @@ def cluster_myid(self, target_node): """ return self.execute_command("CLUSTER MYID", target_nodes=target_node) - def cluster_addslots(self, target_node, *slots): + def cluster_addslots( + self, target_node: _TargetNodesT, *slots: EncodableT + ) -> Awaitable: """ Assign new hash slots to receiving node. Sends to specified node. @@ -936,7 +967,9 @@ def cluster_addslots(self, target_node, *slots): "CLUSTER ADDSLOTS", *slots, target_nodes=target_node ) - def cluster_addslotsrange(self, target_node, *slots): + def cluster_addslotsrange( + self, target_node: _TargetNodesT, *slots: EncodableT + ) -> Awaitable: """ Similar to the CLUSTER ADDSLOTS command. The difference between the two commands is that ADDSLOTS takes a list of slots @@ -952,7 +985,7 @@ def cluster_addslotsrange(self, target_node, *slots): "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node ) - def cluster_countkeysinslot(self, slot_id): + def cluster_countkeysinslot(self, slot_id: int) -> Awaitable: """ Return the number of local keys in the specified hash slot Send to node based on specified slot_id @@ -961,7 +994,7 @@ def cluster_countkeysinslot(self, slot_id): """ return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) - def cluster_count_failure_report(self, node_id): + def cluster_count_failure_report(self, node_id: str) -> Awaitable: """ Return the number of failure reports active for a given node Sends to a random node @@ -970,7 +1003,7 @@ def cluster_count_failure_report(self, node_id): """ return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - async def cluster_delslots(self, *slots): + async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: """ Set hash slots as unbound in the cluster. It determines by it self what node the slot is in and sends it there @@ -983,7 +1016,7 @@ async def cluster_delslots(self, *slots): *[self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] ) - def cluster_delslotsrange(self, *slots): + def cluster_delslotsrange(self, *slots: EncodableT) -> Awaitable: """ Similar to the CLUSTER DELSLOTS command. The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove @@ -994,7 +1027,9 @@ def cluster_delslotsrange(self, *slots): """ return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) - def cluster_failover(self, target_node, option=None): + def cluster_failover( + self, target_node: _TargetNodesT, option: Optional[str] = None + ) -> Awaitable: """ Forces a slave to perform a manual failover of its master Sends to specified node @@ -1016,7 +1051,7 @@ def cluster_failover(self, target_node, option=None): else: return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - def cluster_info(self, target_nodes=None): + def cluster_info(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: """ Provides info about Redis Cluster node state. The command will be sent to a random node in the cluster if no target @@ -1026,7 +1061,7 @@ def cluster_info(self, target_nodes=None): """ return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - def cluster_keyslot(self, key): + def cluster_keyslot(self, key: str) -> Awaitable: """ Returns the hash slot of the specified key Sends to random node in the cluster @@ -1035,7 +1070,9 @@ def cluster_keyslot(self, key): """ return self.execute_command("CLUSTER KEYSLOT", key) - def cluster_meet(self, host, port, target_nodes=None): + def cluster_meet( + self, host: str, port: int, target_nodes: Optional[_TargetNodesT] = None + ) -> Awaitable: """ Force a node cluster to handshake with another node. Sends to specified node. @@ -1046,7 +1083,7 @@ def cluster_meet(self, host, port, target_nodes=None): "CLUSTER MEET", host, port, target_nodes=target_nodes ) - def cluster_nodes(self): + def cluster_nodes(self) -> Awaitable: """ Get Cluster config for the node. Sends to random node in the cluster @@ -1055,7 +1092,7 @@ def cluster_nodes(self): """ return self.execute_command("CLUSTER NODES") - def cluster_replicate(self, target_nodes, node_id): + def cluster_replicate(self, target_nodes: _TargetNodesT, node_id: str) -> Awaitable: """ Reconfigure a node as a slave of the specified master node @@ -1065,7 +1102,9 @@ def cluster_replicate(self, target_nodes, node_id): "CLUSTER REPLICATE", node_id, target_nodes=target_nodes ) - def cluster_reset(self, soft=True, target_nodes=None): + def cluster_reset( + self, soft: bool = True, target_nodes: Optional[_TargetNodesT] = None + ) -> Awaitable: """ Reset a Redis Cluster node @@ -1078,7 +1117,9 @@ def cluster_reset(self, soft=True, target_nodes=None): "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes ) - def cluster_save_config(self, target_nodes=None): + def cluster_save_config( + self, target_nodes: Optional[_TargetNodesT] = None + ) -> Awaitable: """ Forces the node to save cluster state on disk @@ -1086,7 +1127,7 @@ def cluster_save_config(self, target_nodes=None): """ return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) - def cluster_get_keys_in_slot(self, slot, num_keys): + def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> Awaitable: """ Returns the number of keys in the specified cluster slot @@ -1094,7 +1135,9 @@ def cluster_get_keys_in_slot(self, slot, num_keys): """ return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) - def cluster_set_config_epoch(self, epoch, target_nodes=None): + def cluster_set_config_epoch( + self, epoch: int, target_nodes: Optional[_TargetNodesT] = None + ) -> Awaitable: """ Set the configuration epoch in a new node @@ -1104,7 +1147,9 @@ def cluster_set_config_epoch(self, epoch, target_nodes=None): "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes ) - def cluster_setslot(self, target_node, node_id, slot_id, state): + def cluster_setslot( + self, target_node: _TargetNodesT, node_id: str, slot_id: int, state: str + ) -> Awaitable: """ Bind an hash slot to a specific node @@ -1122,7 +1167,7 @@ def cluster_setslot(self, target_node, node_id, slot_id, state): else: raise RedisError(f"Invalid slot state: {state}") - def cluster_setslot_stable(self, slot_id): + def cluster_setslot_stable(self, slot_id: int) -> Awaitable: """ Clears migrating / importing state from the slot. It determines by it self what node the slot is in and sends it there. @@ -1131,7 +1176,9 @@ def cluster_setslot_stable(self, slot_id): """ return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") - def cluster_replicas(self, node_id, target_nodes=None): + def cluster_replicas( + self, node_id: str, target_nodes: Optional[_TargetNodesT] = None + ) -> Awaitable: """ Provides a list of replica nodes replicating from the specified primary target node. @@ -1142,7 +1189,7 @@ def cluster_replicas(self, node_id, target_nodes=None): "CLUSTER REPLICAS", node_id, target_nodes=target_nodes ) - def cluster_slots(self, target_nodes=None): + def cluster_slots(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: """ Get array of Cluster slot to node mappings @@ -1150,7 +1197,7 @@ def cluster_slots(self, target_nodes=None): """ return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - def cluster_links(self, target_node): + def cluster_links(self, target_node: _TargetNodesT) -> Awaitable: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each peer in the cluster: One for sending outbound messages towards the peer and one @@ -1162,7 +1209,7 @@ def cluster_links(self, target_node): """ return self.execute_command("CLUSTER LINKS", target_nodes=target_node) - def readonly(self, target_nodes=None): + def readonly(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: """ Enables read queries. The command will be sent to the default cluster node if target_nodes is @@ -1176,7 +1223,7 @@ def readonly(self, target_nodes=None): self.read_from_replicas = True return self.execute_command("READONLY", target_nodes=target_nodes) - def readwrite(self, target_nodes=None): + def readwrite(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: """ Disables read queries. The command will be sent to the default cluster node if target_nodes is diff --git a/redis/crc.py b/redis/crc.py index c47e2acede..e261241178 100644 --- a/redis/crc.py +++ b/redis/crc.py @@ -1,5 +1,7 @@ from binascii import crc_hqx +from redis.typing import EncodedT + # Redis Cluster's key space is divided into 16384 slots. # For more information see: https://github.com/redis/redis/issues/2576 REDIS_CLUSTER_HASH_SLOTS = 16384 @@ -7,7 +9,7 @@ __all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] -def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): +def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: """Calculate key slot for a given key. See Keys distribution model in https://redis.io/topics/cluster-spec :param key - bytes diff --git a/redis/typing.py b/redis/typing.py index 73ae411f4d..6748612ff1 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,13 +1,14 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Iterable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union from redis.compat import Protocol if TYPE_CHECKING: from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.connection import ConnectionPool + from redis.asyncio.connection import Encoder as AsyncEncoder + from redis.connection import ConnectionPool, Encoder EncodedT = Union[bytes, memoryview] @@ -43,3 +44,10 @@ class CommandsProtocol(Protocol): def execute_command(self, *args, **options): ... + + +class ClusterCommandsProtocol(CommandsProtocol): + encoder: Union["AsyncEncoder", "Encoder"] + + def execute_command(self, *args, **options) -> Union[Any, Awaitable]: + ... diff --git a/tests/conftest.py b/tests/conftest.py index 903e961b46..e83c866f77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,20 +199,21 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): def skip_if_server_version_lt(min_version: str) -> _TestDecorator: - redis_version = REDIS_INFO["version"] + redis_version = REDIS_INFO.get("version", "0") check = Version(redis_version) < Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") def skip_if_server_version_gte(min_version: str) -> _TestDecorator: - redis_version = REDIS_INFO["version"] + redis_version = REDIS_INFO.get("version", "0") check = Version(redis_version) >= Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( - REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" + REDIS_INFO.get("arch_bits", "") != arch_bits, + reason=f"server is not {arch_bits}-bit", ) @@ -235,12 +236,12 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): def skip_if_redis_enterprise() -> _TestDecorator: - check = REDIS_INFO["enterprise"] is True + check = REDIS_INFO.get("enterprise", False) is True return pytest.mark.skipif(check, reason="Redis enterprise") def skip_ifnot_redis_enterprise() -> _TestDecorator: - check = REDIS_INFO["enterprise"] is False + check = REDIS_INFO.get("enterprise", False) is False return pytest.mark.skipif(check, reason="Not running in redis enterprise") diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index b5954c7d66..39f6abc13a 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -13,6 +13,10 @@ else: import pytest_asyncio +from typing import Callable, Dict, List, Optional, Type, Union + +from _pytest.fixtures import FixtureRequest, SubRequest + from redis.asyncio import Connection, Redis, RedisCluster from redis.asyncio.cluster import ( PRIMARY, @@ -53,7 +57,7 @@ @pytest_asyncio.fixture() -async def slowlog(request, r): +async def slowlog(request: SubRequest, r: RedisCluster) -> None: """ Set the slowlog threshold to 0, and the max length to 128. This will force every @@ -75,7 +79,7 @@ async def slowlog(request, r): await r.config_set("slowlog-max-len", old_max_length_value) -async def get_mocked_redis_client(*args, **kwargs): +async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: """ Return a stable RedisCluster object that have deterministic nodes and slots setup to remove the problem of different IP addresses @@ -122,20 +126,32 @@ def cmd_init_mock(self, r): return await RedisCluster(*args, **kwargs) -def mock_node_resp(node, response): +def mock_node_resp( + node: ClusterNode, + response: Union[ + List[List[Union[int, List[Union[str, int]]]]], List[bytes], str, int + ], +) -> ClusterNode: connection = mock.AsyncMock() connection.read_response.return_value = response node.redis_connection.connection = connection return node -def mock_all_nodes_resp(rc, response): +def mock_all_nodes_resp( + rc: RedisCluster, + response: Union[ + List[List[Union[int, List[Union[str, int]]]]], List[bytes], int, str + ], +) -> RedisCluster: for node in rc.get_nodes(): mock_node_resp(node, response) return rc -async def moved_redirection_helper(request, create_redis, failover=False): +async def moved_redirection_helper( + request: FixtureRequest, create_redis: Callable, failover: bool = False +) -> None: """ Test that the client handles MOVED response after a failover. Redirection after a failover means that the redirection address is of a @@ -194,7 +210,7 @@ class TestRedisClusterObj: Tests for the RedisCluster class """ - async def test_host_port_startup_node(self): + async def test_host_port_startup_node(self) -> None: """ Test that it is possible to use host & port arguments as startup node args @@ -204,7 +220,7 @@ async def test_host_port_startup_node(self): await cluster.close() - async def test_startup_nodes(self): + async def test_startup_nodes(self) -> None: """ Test that it is possible to use startup_nodes argument to init the cluster @@ -223,7 +239,7 @@ async def test_startup_nodes(self): await cluster.close() - async def test_empty_startup_nodes(self): + async def test_empty_startup_nodes(self) -> None: """ Test that exception is raised when empty providing empty startup_nodes """ @@ -234,7 +250,7 @@ async def test_empty_startup_nodes(self): "RedisCluster requires at least one node to discover the " "cluster" ), str_if_bytes(ex.value) - async def test_from_url(self, r): + async def test_from_url(self, r: RedisCluster) -> None: redis_url = f"redis://{default_host}:{default_port}/0" with mock.patch.object(RedisCluster, "from_url") as from_url: @@ -247,7 +263,7 @@ async def from_url_mocked(_url, **_kwargs): await cluster.close() - async def test_execute_command_errors(self, r): + async def test_execute_command_errors(self, r: RedisCluster) -> None: """ Test that if no key is provided then exception should be raised. """ @@ -257,7 +273,7 @@ async def test_execute_command_errors(self, r): "No way to dispatch this command to " "Redis Cluster. Missing key." ) - async def test_execute_command_node_flag_primaries(self, r): + async def test_execute_command_node_flag_primaries(self, r: RedisCluster) -> None: """ Test command execution with nodes flag PRIMARIES """ @@ -272,7 +288,7 @@ async def test_execute_command_node_flag_primaries(self, r): conn = replica.redis_connection.connection assert conn.read_response.called is not True - async def test_execute_command_node_flag_replicas(self, r): + async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None: """ Test command execution with nodes flag REPLICAS """ @@ -291,7 +307,7 @@ async def test_execute_command_node_flag_replicas(self, r): await r.close() - async def test_execute_command_node_flag_all_nodes(self, r): + async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> None: """ Test command execution with nodes flag ALL_NODES """ @@ -301,7 +317,7 @@ async def test_execute_command_node_flag_all_nodes(self, r): conn = node.redis_connection.connection assert conn.read_response.called is True - async def test_execute_command_node_flag_random(self, r): + async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: """ Test command execution with nodes flag RANDOM """ @@ -314,7 +330,7 @@ async def test_execute_command_node_flag_random(self, r): called_count += 1 assert called_count == 1 - async def test_execute_command_default_node(self, r): + async def test_execute_command_default_node(self, r: RedisCluster) -> None: """ Test command execution without node flag is being executed on the default node @@ -325,7 +341,7 @@ async def test_execute_command_default_node(self, r): conn = def_node.redis_connection.connection assert conn.read_response.called - async def test_ask_redirection(self, r): + async def test_ask_redirection(self, r: RedisCluster) -> None: """ Test that the server handles ASK response. @@ -351,19 +367,25 @@ def ok_response(connection, *args, **options): assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" - async def test_moved_redirection(self, request, create_redis): + async def test_moved_redirection( + self, request: FixtureRequest, create_redis: Callable + ) -> None: """ Test that the client handles MOVED response. """ await moved_redirection_helper(request, create_redis, failover=False) - async def test_moved_redirection_after_failover(self, request, create_redis): + async def test_moved_redirection_after_failover( + self, request: FixtureRequest, create_redis: Callable + ) -> None: """ Test that the client handles MOVED response after a failover. """ await moved_redirection_helper(request, create_redis, failover=True) - async def test_refresh_using_specific_nodes(self, request, create_redis): + async def test_refresh_using_specific_nodes( + self, request: FixtureRequest, create_redis: Callable + ) -> None: """ Test making calls on specific nodes when the cluster has failed over to another node @@ -453,7 +475,7 @@ def cmd_init_mock(self, r): assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 - async def test_reading_from_replicas_in_round_robin(self): + async def test_reading_from_replicas_in_round_robin(self) -> None: with mock.patch.multiple( Connection, send_command=mock.DEFAULT, @@ -507,7 +529,7 @@ def parse_response_mock_third(connection, *args, **options): await read_cluster.close() - async def test_keyslot(self, r): + async def test_keyslot(self, r: RedisCluster) -> None: """ Test that method will compute correct key in all supported cases """ @@ -524,13 +546,13 @@ async def test_keyslot(self, r): assert r.keyslot(1337) == r.keyslot("1337") assert r.keyslot(b"abc") == r.keyslot("abc") - async def test_get_node_name(self): + async def test_get_node_name(self) -> None: assert ( get_node_name(default_host, default_port) == f"{default_host}:{default_port}" ) - async def test_all_nodes(self, r): + async def test_all_nodes(self, r: RedisCluster) -> None: """ Set a list of nodes and it should be possible to iterate over all """ @@ -539,7 +561,7 @@ async def test_all_nodes(self, r): for i, node in enumerate(r.get_nodes()): assert node in nodes - async def test_all_nodes_masters(self, r): + async def test_all_nodes_masters(self, r: RedisCluster) -> None: """ Set a list of nodes with random primaries/replicas config and it shold be possible to iterate over all of them. @@ -554,7 +576,10 @@ async def test_all_nodes_masters(self, r): assert node in nodes @pytest.mark.parametrize("error", RedisCluster.ERRORS_ALLOW_RETRY) - async def test_cluster_down_overreaches_retry_attempts(self, error): + async def test_cluster_down_overreaches_retry_attempts( + self, + error: Union[Type[TimeoutError], Type[ClusterDownError], Type[ConnectionError]], + ) -> None: """ When error that allows retry is thrown, test that we retry executing the command as many times as configured in cluster_error_retry_attempts @@ -576,7 +601,7 @@ def raise_error(target_node, *args, **kwargs): await rc.close() - async def test_set_default_node_success(self, r): + async def test_set_default_node_success(self, r: RedisCluster) -> None: """ test successful replacement of the default cluster node """ @@ -590,7 +615,7 @@ async def test_set_default_node_success(self, r): assert r.set_default_node(new_def_node) is True assert r.get_default_node() == new_def_node - async def test_set_default_node_failure(self, r): + async def test_set_default_node_failure(self, r: RedisCluster) -> None: """ test failed replacement of the default cluster node """ @@ -600,7 +625,7 @@ async def test_set_default_node_failure(self, r): assert r.set_default_node(new_def_node) is False assert r.get_default_node() == default_node - async def test_get_node_from_key(self, r): + async def test_get_node_from_key(self, r: RedisCluster) -> None: """ Test that get_node_from_key function returns the correct node """ @@ -615,7 +640,9 @@ async def test_get_node_from_key(self, r): assert replica in slot_nodes @skip_if_redis_enterprise() - async def test_not_require_full_coverage_cluster_down_error(self, r): + async def test_not_require_full_coverage_cluster_down_error( + self, r: RedisCluster + ) -> None: """ When require_full_coverage is set to False (default client config) and not all slots are covered, if one of the nodes has 'cluster-require_full_coverage' @@ -648,13 +675,13 @@ class TestClusterRedisCommands: Tests for RedisCluster unique commands """ - async def test_case_insensitive_command_names(self, r): + async def test_case_insensitive_command_names(self, r: RedisCluster) -> None: assert ( r.cluster_response_callbacks["cluster addslots"] == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] ) - async def test_get_and_set(self, r): + async def test_get_and_set(self, r: RedisCluster) -> None: # get and set can't be tested independently of each other assert await r.get("a") is None byte_string = b"value" @@ -667,7 +694,7 @@ async def test_get_and_set(self, r): assert await r.get("integer") == str(integer).encode() assert (await r.get("unicode_string")).decode("utf-8") == unicode_string - async def test_mget_nonatomic(self, r): + async def test_mget_nonatomic(self, r: RedisCluster) -> None: assert await r.mget_nonatomic([]) == [] assert await r.mget_nonatomic(["a", "b"]) == [None, None] await r.set("a", "1") @@ -681,16 +708,16 @@ async def test_mget_nonatomic(self, r): b"3", ] - async def test_mset_nonatomic(self, r): + async def test_mset_nonatomic(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} assert await r.mset_nonatomic(d) for k, v in d.items(): assert await r.get(k) == v - async def test_config_set(self, r): + async def test_config_set(self, r: RedisCluster) -> None: assert await r.config_set("slowlog-log-slower-than", 0) - async def test_cluster_config_resetstat(self, r): + async def test_cluster_config_resetstat(self, r: RedisCluster) -> None: await r.ping(target_nodes="all") all_info = await r.info(target_nodes="all") prior_commands_processed = -1 @@ -703,29 +730,29 @@ async def test_cluster_config_resetstat(self, r): reset_commands_processed = node_info["total_commands_processed"] assert reset_commands_processed < prior_commands_processed - async def test_client_setname(self, r): + async def test_client_setname(self, r: RedisCluster) -> None: node = r.get_random_node() await r.client_setname("redis_py_test", target_nodes=node) client_name = await r.client_getname(target_nodes=node) assert client_name == "redis_py_test" - async def test_exists(self, r): + async def test_exists(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} await r.mset_nonatomic(d) assert await r.exists(*d.keys()) == len(d) - async def test_delete(self, r): + async def test_delete(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} await r.mset_nonatomic(d) assert await r.delete(*d.keys()) == len(d) assert await r.delete(*d.keys()) == 0 - async def test_touch(self, r): + async def test_touch(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} await r.mset_nonatomic(d) assert await r.touch(*d.keys()) == len(d) - async def test_unlink(self, r): + async def test_unlink(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} await r.mset_nonatomic(d) assert await r.unlink(*d.keys()) == len(d) @@ -735,13 +762,13 @@ async def test_unlink(self, r): assert await r.unlink(*d.keys()) == 0 @skip_if_redis_enterprise() - async def test_cluster_myid(self, r): + async def test_cluster_myid(self, r: RedisCluster) -> None: node = r.get_random_node() myid = await r.cluster_myid(node) assert len(myid) == 40 @skip_if_redis_enterprise() - async def test_cluster_slots(self, r): + async def test_cluster_slots(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, default_cluster_slots) cluster_slots = await r.cluster_slots() assert isinstance(cluster_slots, dict) @@ -750,7 +777,7 @@ async def test_cluster_slots(self, r): assert cluster_slots.get((0, 8191)).get("primary") == ("127.0.0.1", 7000) @skip_if_redis_enterprise() - async def test_cluster_addslots(self, r): + async def test_cluster_addslots(self, r: RedisCluster) -> None: node = r.get_random_node() mock_node_resp(node, "OK") assert await r.cluster_addslots(node, 1, 2, 3) is True @@ -763,17 +790,17 @@ async def test_cluster_addslotsrange(self, r): assert await r.cluster_addslotsrange(node, 1, 5) @skip_if_redis_enterprise() - async def test_cluster_countkeysinslot(self, r): + async def test_cluster_countkeysinslot(self, r: RedisCluster) -> None: node = await r.nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) assert await r.cluster_countkeysinslot(1) == 2 - async def test_cluster_count_failure_report(self, r): + async def test_cluster_count_failure_report(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, 0) assert await r.cluster_count_failure_report("node_0") == 0 @skip_if_redis_enterprise() - async def test_cluster_delslots(self): + async def test_cluster_delslots(self) -> None: cluster_slots = [ [0, 8191, ["127.0.0.1", 7000, "node_0"]], [8192, 16383, ["127.0.0.1", 7001, "node_1"]], @@ -799,7 +826,7 @@ async def test_cluster_delslotsrange(self, r): assert await r.cluster_delslotsrange(1, 5) @skip_if_redis_enterprise() - async def test_cluster_failover(self, r): + async def test_cluster_failover(self, r: RedisCluster) -> None: node = r.get_random_node() mock_node_resp(node, "OK") assert await r.cluster_failover(node) is True @@ -809,24 +836,24 @@ async def test_cluster_failover(self, r): await r.cluster_failover(node, "FORCT") @skip_if_redis_enterprise() - async def test_cluster_info(self, r): + async def test_cluster_info(self, r: RedisCluster) -> None: info = await r.cluster_info() assert isinstance(info, dict) assert info["cluster_state"] == "ok" @skip_if_redis_enterprise() - async def test_cluster_keyslot(self, r): + async def test_cluster_keyslot(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, 12182) assert await r.cluster_keyslot("foo") == 12182 @skip_if_redis_enterprise() - async def test_cluster_meet(self, r): + async def test_cluster_meet(self, r: RedisCluster) -> None: node = r.get_default_node() mock_node_resp(node, "OK") assert await r.cluster_meet("127.0.0.1", 6379) is True @skip_if_redis_enterprise() - async def test_cluster_nodes(self, r): + async def test_cluster_nodes(self, r: RedisCluster) -> None: response = ( "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " @@ -855,7 +882,7 @@ async def test_cluster_nodes(self, r): ) @skip_if_redis_enterprise() - async def test_cluster_nodes_importing_migrating(self, r): + async def test_cluster_nodes_importing_migrating(self, r: RedisCluster) -> None: response = ( "488ead2fcce24d8c0f158f9172cb1f4a9e040fe5 127.0.0.1:16381@26381 " "master - 0 1648975557664 3 connected 10923-16383\n" @@ -892,7 +919,7 @@ async def test_cluster_nodes_importing_migrating(self, r): assert node_16381.get("migrations") == [] @skip_if_redis_enterprise() - async def test_cluster_replicate(self, r): + async def test_cluster_replicate(self, r: RedisCluster) -> None: node = r.get_random_node() all_replicas = r.get_replicas() mock_all_nodes_resp(r, "OK") @@ -905,7 +932,7 @@ async def test_cluster_replicate(self, r): assert results is True @skip_if_redis_enterprise() - async def test_cluster_reset(self, r): + async def test_cluster_reset(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, "OK") assert await r.cluster_reset() is True assert await r.cluster_reset(False) is True @@ -914,7 +941,7 @@ async def test_cluster_reset(self, r): assert res is True @skip_if_redis_enterprise() - async def test_cluster_save_config(self, r): + async def test_cluster_save_config(self, r: RedisCluster) -> None: node = r.get_random_node() all_nodes = r.get_nodes() mock_all_nodes_resp(r, "OK") @@ -924,7 +951,7 @@ async def test_cluster_save_config(self, r): assert res is True @skip_if_redis_enterprise() - async def test_cluster_get_keys_in_slot(self, r): + async def test_cluster_get_keys_in_slot(self, r: RedisCluster) -> None: response = [b"{foo}1", b"{foo}2"] node = await r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) @@ -932,7 +959,7 @@ async def test_cluster_get_keys_in_slot(self, r): assert keys == response @skip_if_redis_enterprise() - async def test_cluster_set_config_epoch(self, r): + async def test_cluster_set_config_epoch(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, "OK") assert await r.cluster_set_config_epoch(3) is True all_results = await r.cluster_set_config_epoch(3, target_nodes="all") @@ -940,7 +967,7 @@ async def test_cluster_set_config_epoch(self, r): assert res is True @skip_if_redis_enterprise() - async def test_cluster_setslot(self, r): + async def test_cluster_setslot(self, r: RedisCluster) -> None: node = r.get_random_node() mock_node_resp(node, "OK") assert await r.cluster_setslot(node, "node_0", 1218, "IMPORTING") is True @@ -951,14 +978,14 @@ async def test_cluster_setslot(self, r): with pytest.raises(RedisError): await r.cluster_failover(node, "STATE") - async def test_cluster_setslot_stable(self, r): + async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: node = await r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @skip_if_redis_enterprise() - async def test_cluster_replicas(self, r): + async def test_cluster_replicas(self, r: RedisCluster) -> None: response = [ b"01eca22229cf3c652b6fca0d09ff6941e0d2e3 " b"127.0.0.1:6377@16377 slave " @@ -989,7 +1016,7 @@ async def test_cluster_links(self, r): assert res[i][3] == res[i + 1][3] @skip_if_redis_enterprise() - async def test_readonly(self): + async def test_readonly(self) -> None: r = await get_mocked_redis_client(host=default_host, port=default_port) mock_all_nodes_resp(r, "OK") assert await r.readonly() is True @@ -1002,7 +1029,7 @@ async def test_readonly(self): await r.close() @skip_if_redis_enterprise() - async def test_readwrite(self): + async def test_readwrite(self) -> None: r = await get_mocked_redis_client(host=default_host, port=default_port) mock_all_nodes_resp(r, "OK") assert await r.readwrite() is True @@ -1015,12 +1042,12 @@ async def test_readwrite(self): await r.close() @skip_if_redis_enterprise() - async def test_bgsave(self, r): + async def test_bgsave(self, r: RedisCluster) -> None: assert await r.bgsave() await asyncio.sleep(0.3) assert await r.bgsave(True) - async def test_info(self, r): + async def test_info(self, r: RedisCluster) -> None: # Map keys to same slot await r.set("x{1}", 1) await r.set("y{1}", 2) @@ -1033,20 +1060,24 @@ async def test_info(self, r): assert isinstance(info, dict) assert info["db0"]["keys"] == 3 - async def _init_slowlog_test(self, r, node): + async def _init_slowlog_test(self, r: RedisCluster, node: ClusterNode) -> str: slowlog_lim = await r.config_get("slowlog-log-slower-than", target_nodes=node) assert ( await r.config_set("slowlog-log-slower-than", 0, target_nodes=node) is True ) return slowlog_lim["slowlog-log-slower-than"] - async def _teardown_slowlog_test(self, r, node, prev_limit): + async def _teardown_slowlog_test( + self, r: RedisCluster, node: ClusterNode, prev_limit: str + ) -> None: assert ( await r.config_set("slowlog-log-slower-than", prev_limit, target_nodes=node) is True ) - async def test_slowlog_get(self, r, slowlog): + async def test_slowlog_get( + self, r: RedisCluster, slowlog: Optional[List[Dict[str, Union[int, bytes]]]] + ) -> None: unicode_string = chr(3456) + "abcd" + chr(3421) node = r.get_node_from_key(unicode_string) slowlog_limit = await self._init_slowlog_test(r, node) @@ -1072,7 +1103,9 @@ async def test_slowlog_get(self, r, slowlog): # rollback the slowlog limit to its original value await self._teardown_slowlog_test(r, node, slowlog_limit) - async def test_slowlog_get_limit(self, r, slowlog): + async def test_slowlog_get_limit( + self, r: RedisCluster, slowlog: Optional[List[Dict[str, Union[int, bytes]]]] + ) -> None: assert await r.slowlog_reset() node = r.get_node_from_key("foo") slowlog_limit = await self._init_slowlog_test(r, node) @@ -1083,31 +1116,31 @@ async def test_slowlog_get_limit(self, r, slowlog): assert len(slowlog) == 1 await self._teardown_slowlog_test(r, node, slowlog_limit) - async def test_slowlog_length(self, r, slowlog): + async def test_slowlog_length(self, r: RedisCluster, slowlog: None) -> None: await r.get("foo") node = await r.nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = await r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) - async def test_time(self, r): + async def test_time(self, r: RedisCluster) -> None: t = await r.time(target_nodes=r.get_primaries()[0]) assert len(t) == 2 assert isinstance(t[0], int) assert isinstance(t[1], int) @skip_if_server_version_lt("4.0.0") - async def test_memory_usage(self, r): + async def test_memory_usage(self, r: RedisCluster) -> None: await r.set("foo", "bar") assert isinstance(await r.memory_usage("foo"), int) @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise() - async def test_memory_malloc_stats(self, r): + async def test_memory_malloc_stats(self, r: RedisCluster) -> None: assert await r.memory_malloc_stats() @skip_if_server_version_lt("4.0.0") @skip_if_redis_enterprise() - async def test_memory_stats(self, r): + async def test_memory_stats(self, r: RedisCluster) -> None: # put a key into the current db to make sure that "db." # has data await r.set("foo", "bar") @@ -1119,30 +1152,30 @@ async def test_memory_stats(self, r): assert isinstance(value, dict) @skip_if_server_version_lt("4.0.0") - async def test_memory_help(self, r): + async def test_memory_help(self, r: RedisCluster) -> None: with pytest.raises(NotImplementedError): await r.memory_help() @skip_if_server_version_lt("4.0.0") - async def test_memory_doctor(self, r): + async def test_memory_doctor(self, r: RedisCluster) -> None: with pytest.raises(NotImplementedError): await r.memory_doctor() @skip_if_redis_enterprise() - async def test_lastsave(self, r): + async def test_lastsave(self, r: RedisCluster) -> None: node = r.get_primaries()[0] assert isinstance(await r.lastsave(target_nodes=node), datetime.datetime) - async def test_cluster_echo(self, r): + async def test_cluster_echo(self, r: RedisCluster) -> None: node = r.get_primaries()[0] assert await r.echo("foo bar", target_nodes=node) == b"foo bar" @skip_if_server_version_lt("1.0.0") - async def test_debug_segfault(self, r): + async def test_debug_segfault(self, r: RedisCluster) -> None: with pytest.raises(NotImplementedError): await r.debug_segfault() - async def test_config_resetstat(self, r): + async def test_config_resetstat(self, r: RedisCluster) -> None: node = r.get_primaries()[0] await r.ping(target_nodes=node) prior_commands_processed = int( @@ -1156,14 +1189,14 @@ async def test_config_resetstat(self, r): assert reset_commands_processed < prior_commands_processed @skip_if_server_version_lt("6.2.0") - async def test_client_trackinginfo(self, r): + async def test_client_trackinginfo(self, r: RedisCluster) -> None: node = r.get_primaries()[0] res = await r.client_trackinginfo(target_nodes=node) assert len(res) > 2 assert "prefixes" in res @skip_if_server_version_lt("2.9.50") - async def test_client_pause(self, r): + async def test_client_pause(self, r: RedisCluster) -> None: node = r.get_primaries()[0] assert await r.client_pause(1, target_nodes=node) assert await r.client_pause(timeout=1, target_nodes=node) @@ -1172,16 +1205,16 @@ async def test_client_pause(self, r): @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() - async def test_client_unpause(self, r): + async def test_client_unpause(self, r: RedisCluster) -> None: assert await r.client_unpause() @skip_if_server_version_lt("5.0.0") - async def test_client_id(self, r): + async def test_client_id(self, r: RedisCluster) -> None: node = r.get_primaries()[0] assert await r.client_id(target_nodes=node) > 0 @skip_if_server_version_lt("5.0.0") - async def test_client_unblock(self, r): + async def test_client_unblock(self, r: RedisCluster) -> None: node = r.get_primaries()[0] myid = await r.client_id(target_nodes=node) assert not await r.client_unblock(myid, target_nodes=node) @@ -1189,20 +1222,20 @@ async def test_client_unblock(self, r): assert not await r.client_unblock(myid, error=False, target_nodes=node) @skip_if_server_version_lt("6.0.0") - async def test_client_getredir(self, r): + async def test_client_getredir(self, r: RedisCluster) -> None: node = r.get_primaries()[0] assert isinstance(await r.client_getredir(target_nodes=node), int) assert await r.client_getredir(target_nodes=node) == -1 @skip_if_server_version_lt("6.2.0") - async def test_client_info(self, r): + async def test_client_info(self, r: RedisCluster) -> None: node = r.get_primaries()[0] info = await r.client_info(target_nodes=node) assert isinstance(info, dict) assert "addr" in info @skip_if_server_version_lt("2.6.9") - async def test_client_kill(self, r, r2): + async def test_client_kill(self, r: RedisCluster, r2: RedisCluster) -> None: node = r.get_primaries()[0] await r.client_setname("redis-py-c1", target_nodes="all") await r2.client_setname("redis-py-c2", target_nodes="all") @@ -1226,13 +1259,13 @@ async def test_client_kill(self, r, r2): assert clients[0].get("name") == "redis-py-c1" @skip_if_server_version_lt("2.6.0") - async def test_cluster_bitop_not_empty_string(self, r): + async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None: await r.set("{foo}a", "") await r.bitop("not", "{foo}r", "{foo}a") assert await r.get("{foo}r") is None @skip_if_server_version_lt("2.6.0") - async def test_cluster_bitop_not(self, r): + async def test_cluster_bitop_not(self, r: RedisCluster) -> None: test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("{foo}a", test_str) @@ -1240,7 +1273,7 @@ async def test_cluster_bitop_not(self, r): assert int(binascii.hexlify(await r.get("{foo}r")), 16) == correct @skip_if_server_version_lt("2.6.0") - async def test_cluster_bitop_not_in_place(self, r): + async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None: test_str = b"\xAA\x00\xFF\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("{foo}a", test_str) @@ -1248,7 +1281,7 @@ async def test_cluster_bitop_not_in_place(self, r): assert int(binascii.hexlify(await r.get("{foo}a")), 16) == correct @skip_if_server_version_lt("2.6.0") - async def test_cluster_bitop_single_string(self, r): + async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None: test_str = b"\x01\x02\xFF" await r.set("{foo}a", test_str) await r.bitop("and", "{foo}res1", "{foo}a") @@ -1259,7 +1292,7 @@ async def test_cluster_bitop_single_string(self, r): assert await r.get("{foo}res3") == test_str @skip_if_server_version_lt("2.6.0") - async def test_cluster_bitop_string_operands(self, r): + async def test_cluster_bitop_string_operands(self, r: RedisCluster) -> None: await r.set("{foo}a", b"\x01\x02\xFF\xFF") await r.set("{foo}b", b"\x01\x02\xFF") await r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") @@ -1270,7 +1303,7 @@ async def test_cluster_bitop_string_operands(self, r): assert int(binascii.hexlify(await r.get("{foo}res3")), 16) == 0x000000FF @skip_if_server_version_lt("6.2.0") - async def test_cluster_copy(self, r): + async def test_cluster_copy(self, r: RedisCluster) -> None: assert await r.copy("{foo}a", "{foo}b") == 0 await r.set("{foo}a", "bar") assert await r.copy("{foo}a", "{foo}b") == 1 @@ -1278,25 +1311,25 @@ async def test_cluster_copy(self, r): assert await r.get("{foo}b") == b"bar" @skip_if_server_version_lt("6.2.0") - async def test_cluster_copy_and_replace(self, r): + async def test_cluster_copy_and_replace(self, r: RedisCluster) -> None: await r.set("{foo}a", "foo1") await r.set("{foo}b", "foo2") assert await r.copy("{foo}a", "{foo}b") == 0 assert await r.copy("{foo}a", "{foo}b", replace=True) == 1 @skip_if_server_version_lt("6.2.0") - async def test_cluster_lmove(self, r): + async def test_cluster_lmove(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "one", "two", "three", "four") assert await r.lmove("{foo}a", "{foo}b") assert await r.lmove("{foo}a", "{foo}b", "right", "left") @skip_if_server_version_lt("6.2.0") - async def test_cluster_blmove(self, r): + async def test_cluster_blmove(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "one", "two", "three", "four") assert await r.blmove("{foo}a", "{foo}b", 5) assert await r.blmove("{foo}a", "{foo}b", 1, "RIGHT", "LEFT") - async def test_cluster_msetnx(self, r): + async def test_cluster_msetnx(self, r: RedisCluster) -> None: d = {"{foo}a": b"1", "{foo}b": b"2", "{foo}c": b"3"} assert await r.msetnx(d) d2 = {"{foo}a": b"x", "{foo}d": b"4"} @@ -1305,13 +1338,13 @@ async def test_cluster_msetnx(self, r): assert await r.get(k) == v assert await r.get("{foo}d") is None - async def test_cluster_rename(self, r): + async def test_cluster_rename(self, r: RedisCluster) -> None: await r.set("{foo}a", "1") assert await r.rename("{foo}a", "{foo}b") assert await r.get("{foo}a") is None assert await r.get("{foo}b") == b"1" - async def test_cluster_renamenx(self, r): + async def test_cluster_renamenx(self, r: RedisCluster) -> None: await r.set("{foo}a", "1") await r.set("{foo}b", "2") assert not await r.renamenx("{foo}a", "{foo}b") @@ -1319,7 +1352,7 @@ async def test_cluster_renamenx(self, r): assert await r.get("{foo}b") == b"2" # LIST COMMANDS - async def test_cluster_blpop(self, r): + async def test_cluster_blpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") @@ -1330,7 +1363,7 @@ async def test_cluster_blpop(self, r): await r.rpush("{foo}c", "1") assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - async def test_cluster_brpop(self, r): + async def test_cluster_brpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") @@ -1341,7 +1374,7 @@ async def test_cluster_brpop(self, r): await r.rpush("{foo}c", "1") assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") - async def test_cluster_brpoplpush(self, r): + async def test_cluster_brpoplpush(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") assert await r.brpoplpush("{foo}a", "{foo}b") == b"2" @@ -1350,24 +1383,24 @@ async def test_cluster_brpoplpush(self, r): assert await r.lrange("{foo}a", 0, -1) == [] assert await r.lrange("{foo}b", 0, -1) == [b"1", b"2", b"3", b"4"] - async def test_cluster_brpoplpush_empty_string(self, r): + async def test_cluster_brpoplpush_empty_string(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "") assert await r.brpoplpush("{foo}a", "{foo}b") == b"" - async def test_cluster_rpoplpush(self, r): + async def test_cluster_rpoplpush(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "a1", "a2", "a3") await r.rpush("{foo}b", "b1", "b2", "b3") assert await r.rpoplpush("{foo}a", "{foo}b") == b"a3" assert await r.lrange("{foo}a", 0, -1) == [b"a1", b"a2"] assert await r.lrange("{foo}b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] - async def test_cluster_sdiff(self, r): + async def test_cluster_sdiff(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2", "3") assert await r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"} await r.sadd("{foo}b", "2", "3") assert await r.sdiff("{foo}a", "{foo}b") == {b"1"} - async def test_cluster_sdiffstore(self, r): + async def test_cluster_sdiffstore(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2", "3") assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3 assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} @@ -1375,13 +1408,13 @@ async def test_cluster_sdiffstore(self, r): assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1 assert await r.smembers("{foo}c") == {b"1"} - async def test_cluster_sinter(self, r): + async def test_cluster_sinter(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2", "3") assert await r.sinter("{foo}a", "{foo}b") == set() await r.sadd("{foo}b", "2", "3") assert await r.sinter("{foo}a", "{foo}b") == {b"2", b"3"} - async def test_cluster_sinterstore(self, r): + async def test_cluster_sinterstore(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2", "3") assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0 assert await r.smembers("{foo}c") == set() @@ -1389,33 +1422,33 @@ async def test_cluster_sinterstore(self, r): assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2 assert await r.smembers("{foo}c") == {b"2", b"3"} - async def test_cluster_smove(self, r): + async def test_cluster_smove(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "a1", "a2") await r.sadd("{foo}b", "b1", "b2") assert await r.smove("{foo}a", "{foo}b", "a1") assert await r.smembers("{foo}a") == {b"a2"} assert await r.smembers("{foo}b") == {b"b1", b"b2", b"a1"} - async def test_cluster_sunion(self, r): + async def test_cluster_sunion(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2") await r.sadd("{foo}b", "2", "3") assert await r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"} - async def test_cluster_sunionstore(self, r): + async def test_cluster_sunionstore(self, r: RedisCluster) -> None: await r.sadd("{foo}a", "1", "2") await r.sadd("{foo}b", "2", "3") assert await r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3 assert await r.smembers("{foo}c") == {b"1", b"2", b"3"} @skip_if_server_version_lt("6.2.0") - async def test_cluster_zdiff(self, r): + async def test_cluster_zdiff(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] @skip_if_server_version_lt("6.2.0") - async def test_cluster_zdiffstore(self, r): + async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) @@ -1423,7 +1456,7 @@ async def test_cluster_zdiffstore(self, r): assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] @skip_if_server_version_lt("6.2.0") - async def test_cluster_zinter(self, r): + async def test_cluster_zinter(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1451,7 +1484,7 @@ async def test_cluster_zinter(self, r): {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True ) == [(b"a3", 20), (b"a1", 23)] - async def test_cluster_zinterstore_sum(self, r): + async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1461,7 +1494,7 @@ async def test_cluster_zinterstore_sum(self, r): (b"a1", 9), ] - async def test_cluster_zinterstore_max(self, r): + async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1476,7 +1509,7 @@ async def test_cluster_zinterstore_max(self, r): (b"a1", 6), ] - async def test_cluster_zinterstore_min(self, r): + async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 2, "a2": 3, "a3": 5}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1491,7 +1524,7 @@ async def test_cluster_zinterstore_min(self, r): (b"a3", 3), ] - async def test_cluster_zinterstore_with_weight(self, r): + async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1504,7 +1537,7 @@ async def test_cluster_zinterstore_with_weight(self, r): ] @skip_if_server_version_lt("4.9.0") - async def test_cluster_bzpopmax(self, r): + async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( @@ -1532,7 +1565,7 @@ async def test_cluster_bzpopmax(self, r): assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) @skip_if_server_version_lt("4.9.0") - async def test_cluster_bzpopmin(self, r): + async def test_cluster_bzpopmin(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( @@ -1560,7 +1593,7 @@ async def test_cluster_bzpopmin(self, r): assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) @skip_if_server_version_lt("6.2.0") - async def test_cluster_zrangestore(self, r): + async def test_cluster_zrangestore(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zrangestore("{foo}b", "{foo}a", 0, 1) assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1585,7 +1618,7 @@ async def test_cluster_zrangestore(self, r): assert await r.zrange("{foo}b", 0, -1) == [b"a2"] @skip_if_server_version_lt("6.2.0") - async def test_cluster_zunion(self, r): + async def test_cluster_zunion(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1615,7 +1648,7 @@ async def test_cluster_zunion(self, r): {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] - async def test_cluster_zunionstore_sum(self, r): + async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1627,7 +1660,7 @@ async def test_cluster_zunionstore_sum(self, r): (b"a1", 9), ] - async def test_cluster_zunionstore_max(self, r): + async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1644,7 +1677,7 @@ async def test_cluster_zunionstore_max(self, r): (b"a1", 6), ] - async def test_cluster_zunionstore_min(self, r): + async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 4}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1661,7 +1694,7 @@ async def test_cluster_zunionstore_min(self, r): (b"a4", 4), ] - async def test_cluster_zunionstore_with_weight(self, r): + async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) @@ -1676,7 +1709,7 @@ async def test_cluster_zunionstore_with_weight(self, r): ] @skip_if_server_version_lt("2.8.9") - async def test_cluster_pfcount(self, r): + async def test_cluster_pfcount(self, r: RedisCluster) -> None: members = {b"1", b"2", b"3"} await r.pfadd("{foo}a", *members) assert await r.pfcount("{foo}a") == len(members) @@ -1686,7 +1719,7 @@ async def test_cluster_pfcount(self, r): assert await r.pfcount("{foo}a", "{foo}b") == len(members_b.union(members)) @skip_if_server_version_lt("2.8.9") - async def test_cluster_pfmerge(self, r): + async def test_cluster_pfmerge(self, r: RedisCluster) -> None: mema = {b"1", b"2", b"3"} memb = {b"2", b"3", b"4"} memc = {b"5", b"6", b"7"} @@ -1698,14 +1731,14 @@ async def test_cluster_pfmerge(self, r): await r.pfmerge("{foo}d", "{foo}b") assert await r.pfcount("{foo}d") == 7 - async def test_cluster_sort_store(self, r): + async def test_cluster_sort_store(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "2", "3", "1") assert await r.sort("{foo}a", store="{foo}sorted_values") == 3 assert await r.lrange("{foo}sorted_values", 0, -1) == [b"1", b"2", b"3"] # GEO COMMANDS @skip_if_server_version_lt("6.2.0") - async def test_cluster_geosearchstore(self, r): + async def test_cluster_geosearchstore(self, r: RedisCluster) -> None: values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, @@ -1724,7 +1757,7 @@ async def test_cluster_geosearchstore(self, r): @skip_unless_arch_bits(64) @skip_if_server_version_lt("6.2.0") - async def test_geosearchstore_dist(self, r): + async def test_geosearchstore_dist(self, r: RedisCluster) -> None: values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, @@ -1744,7 +1777,7 @@ async def test_geosearchstore_dist(self, r): assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 @skip_if_server_version_lt("3.2.0") - async def test_cluster_georadius_store(self, r): + async def test_cluster_georadius_store(self, r: RedisCluster) -> None: values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, @@ -1759,7 +1792,7 @@ async def test_cluster_georadius_store(self, r): @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") - async def test_cluster_georadius_store_dist(self, r): + async def test_cluster_georadius_store_dist(self, r: RedisCluster) -> None: values = (2.1909389952632, 41.433791470673, "place1") + ( 2.1873744593677, 41.406342043777, @@ -1773,12 +1806,12 @@ async def test_cluster_georadius_store_dist(self, r): # instead of save the geo score, the distance is saved. assert await r.zscore("{foo}places_barcelona", "place1") == 88.05060698409301 - async def test_cluster_dbsize(self, r): + async def test_cluster_dbsize(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} assert await r.mset_nonatomic(d) assert await r.dbsize(target_nodes="primaries") == len(d) - async def test_cluster_keys(self, r): + async def test_cluster_keys(self, r: RedisCluster) -> None: assert await r.keys() == [] keys_with_underscores = {b"test_a", b"test_b"} keys = keys_with_underscores.union({b"testc"}) @@ -1792,7 +1825,7 @@ async def test_cluster_keys(self, r): # SCAN COMMANDS @skip_if_server_version_lt("2.8.0") - async def test_cluster_scan(self, r): + async def test_cluster_scan(self, r: RedisCluster) -> None: await r.set("a", 1) await r.set("b", 2) await r.set("c", 3) @@ -1811,7 +1844,7 @@ async def test_cluster_scan(self, r): assert all(cursor == 0 for cursor in cursors.values()) @skip_if_server_version_lt("6.0.0") - async def test_cluster_scan_type(self, r): + async def test_cluster_scan_type(self, r: RedisCluster) -> None: await r.sadd("a-set", 1) await r.sadd("b-set", 1) await r.sadd("c-set", 1) @@ -1834,7 +1867,7 @@ async def test_cluster_scan_type(self, r): assert all(cursor == 0 for cursor in cursors.values()) @skip_if_server_version_lt("2.8.0") - async def test_cluster_scan_iter(self, r): + async def test_cluster_scan_iter(self, r: RedisCluster) -> None: keys_all = [] keys_1 = [] for i in range(100): @@ -1855,7 +1888,7 @@ async def test_cluster_scan_iter(self, r): ] assert sorted(keys) == keys_1 - async def test_cluster_randomkey(self, r): + async def test_cluster_randomkey(self, r: RedisCluster) -> None: node = r.get_node_from_key("{foo}") assert await r.randomkey(target_nodes=node) is None for key in ("{foo}a", "{foo}b", "{foo}c"): @@ -1864,7 +1897,9 @@ async def test_cluster_randomkey(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() - async def test_acl_log(self, r, request, create_redis): + async def test_acl_log( + self, r: RedisCluster, request: FixtureRequest, create_redis: Callable + ) -> None: key = "{cache}:" node = r.get_node_from_key(key) username = "redis-py-user" @@ -1914,7 +1949,7 @@ class TestNodesManager: Tests for the NodesManager class """ - async def test_load_balancer(self, r): + async def test_load_balancer(self, r: RedisCluster) -> None: n_manager = r.nodes_manager lb = n_manager.read_load_balancer slot_1 = 1257 @@ -1946,7 +1981,7 @@ async def test_load_balancer(self, r): assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 - async def test_init_slots_cache_not_all_slots_covered(self): + async def test_init_slots_cache_not_all_slots_covered(self) -> None: """ Test that if not all slots are covered it should raise an exception """ @@ -1968,7 +2003,7 @@ async def test_init_slots_cache_not_all_slots_covered(self): "All slots are not covered after query all startup_nodes." ) - async def test_init_slots_cache_not_require_full_coverage_success(self): + async def test_init_slots_cache_not_require_full_coverage_success(self) -> None: """ When require_full_coverage is set to False and not all slots are covered the cluster client initialization should succeed @@ -1991,7 +2026,7 @@ async def test_init_slots_cache_not_require_full_coverage_success(self): await rc.close() - async def test_init_slots_cache(self): + async def test_init_slots_cache(self) -> None: """ Test that slots cache can in initialized and all slots are covered """ @@ -2022,7 +2057,7 @@ async def test_init_slots_cache(self): await rc.close() - async def test_init_slots_cache_cluster_mode_disabled(self): + async def test_init_slots_cache_cluster_mode_disabled(self) -> None: """ Test that creating a RedisCluster failes if one of the startup nodes has cluster mode disabled @@ -2034,7 +2069,7 @@ async def test_init_slots_cache_cluster_mode_disabled(self): await rc.close() assert "Cluster mode is not enabled on this node" in str(e.value) - async def test_empty_startup_nodes(self): + async def test_empty_startup_nodes(self) -> None: """ It should not be possible to create a node manager with no nodes specified @@ -2042,7 +2077,7 @@ async def test_empty_startup_nodes(self): with pytest.raises(RedisClusterException): await NodesManager([]).initialize() - async def test_wrong_startup_nodes_type(self): + async def test_wrong_startup_nodes_type(self) -> None: """ If something other then a list type itteratable is provided it should fail @@ -2050,7 +2085,9 @@ async def test_wrong_startup_nodes_type(self): with pytest.raises(RedisClusterException): await NodesManager({}).initialize() - async def test_init_slots_cache_slots_collision(self, request): + async def test_init_slots_cache_slots_collision( + self, request: FixtureRequest + ) -> None: """ Test that if 2 nodes do not agree on the same slots setup it should raise an error. In this test both nodes will say that the first @@ -2105,7 +2142,7 @@ async def execute_command(*args, **kwargs): "startup_nodes could not agree on a valid slots cache" ), str(ex.value) - async def test_cluster_one_instance(self): + async def test_cluster_one_instance(self) -> None: """ If the cluster exists of only 1 node then there is some hacks that must be validated they work. @@ -2128,7 +2165,7 @@ async def test_cluster_one_instance(self): await rc.close() - async def test_init_with_down_node(self): + async def test_init_with_down_node(self) -> None: """ If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. diff --git a/whitelist.py b/whitelist.py index 27210284c7..d5d9b9970a 100644 --- a/whitelist.py +++ b/whitelist.py @@ -14,4 +14,5 @@ exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) +AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) From 6627afac28bbd3824640bfea228348a4852abc47 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sat, 16 Apr 2022 04:17:21 +0530 Subject: [PATCH 07/23] async_cluster: add docs --- docs/connections.rst | 74 ++++++++-- redis/asyncio/cluster.py | 300 +++++++++++++++++++-------------------- redis/asyncio/parser.py | 6 +- 3 files changed, 209 insertions(+), 171 deletions(-) diff --git a/docs/connections.rst b/docs/connections.rst index 9804a15bf1..e4b82cdc50 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -1,20 +1,22 @@ Connecting to Redis -##################### +################### + Generic Client ************** -This is the client used to connect directly to a standard redis node. +This is the client used to connect directly to a standard Redis node. .. autoclass:: redis.Redis :members: + Sentinel Client *************** -Redis `Sentinel `_ provides high availability for Redis. There are commands that can only be executed against a redis node running in sentinel mode. Connecting to those nodes, and executing commands against them requires a Sentinel connection. +Redis `Sentinel `_ provides high availability for Redis. There are commands that can only be executed against a Redis node running in sentinel mode. Connecting to those nodes, and executing commands against them requires a Sentinel connection. -Connection example (assumes redis redis on the ports listed below): +Connection example (assumes Redis exists on the ports listed below): >>> from redis import Sentinel >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) @@ -23,33 +25,85 @@ Connection example (assumes redis redis on the ports listed below): >>> sentinel.discover_slaves('mymaster') [('127.0.0.1', 6380)] +Sentinel +======== .. autoclass:: redis.sentinel.Sentinel :members: +SentinelConnectionPool +====================== .. autoclass:: redis.sentinel.SentinelConnectionPool :members: + Cluster Client ************** -This client is used for connecting to a redis cluser. +This client is used for connecting to a Redis Cluster. +RedisCluster +============ .. autoclass:: redis.cluster.RedisCluster :members: -Connection Pools -***************** -.. autoclass:: redis.connection.ConnectionPool +ClusterNode +=========== +.. autoclass:: redis.cluster.ClusterNode :members: -More connection examples can be found `here `_. Async Client ************ +See complete example: `here `_ + This client is used for communicating with Redis, asynchronously. +.. autoclass:: redis.asyncio.client.Redis + :members: + + +Async Cluster Client +******************** + +RedisCluster (Async) +==================== +.. autoclass:: redis.asyncio.cluster.RedisCluster + :members: + +ClusterNode (Async) +=================== +.. autoclass:: redis.asyncio.cluster.ClusterNode + :members: + + +Connection +********** + +See complete example: `here `_ + +Connection +========== +.. autoclass:: redis.connection.Connection + :members: + +Connection (Async) +================== .. autoclass:: redis.asyncio.connection.Connection :members: -More connection examples can be found `here `_ \ No newline at end of file + +Connection Pools +**************** + +See complete example: `here `_ + +ConnectionPool +============== +.. autoclass:: redis.connection.ConnectionPool + :members: + +ConnectionPool (Async) +====================== +.. autoclass:: redis.asyncio.connection.ConnectionPool + :members: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index ccd2f6671f..820a3370e4 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,7 +3,7 @@ import random import socket import warnings -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Union from redis.asyncio.client import Redis from redis.asyncio.connection import ( @@ -47,9 +47,6 @@ log = logging.getLogger(__name__) -_RedisClusterT = TypeVar("_RedisClusterT", bound="RedisCluster") -_ClusterNodeT = TypeVar("_ClusterNodeT", bound="ClusterNode") - class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( @@ -66,10 +63,78 @@ class ClusterParser(DefaultParser): class RedisCluster(AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new RedisCluster client. + + Pass one of parameters: + + - `url` + - `host` + - `startup_nodes` + + | Use :meth:`initialize` to find cluster nodes & create connections. + | Use :meth:`close` to disconnect connections & close client. + + Many commands support the target_nodes kwarg. It can be one of the + :attr:`NODE_FLAGS`: + + - :attr:`PRIMARIES` + - :attr:`REPLICAS` + - :attr:`ALL_NODES` + - :attr:`RANDOM` + - :attr:`DEFAULT_NODE` + + :param host: + | Can be used to point to a startup node + :param port: + | Port used if **host** or **url** is provided + :param startup_nodes: + | :class:`~.ClusterNode` to used as a startup node + :param cluster_error_retry_attempts: + | Retry command execution attempts when encountering :class:`~.ClusterDownError` + or :class:`~.ConnectionError` + :param require_full_coverage: + | When set to ``False``: the client will not require a full coverage of the + slots. However, if not all slots are covered, and at least one node has + ``cluster-require-full-coverage`` set to ``yes``, the server will throw a + :class:`~.ClusterDownError` for some key-based commands. + | When set to ``True``: all slots must be covered to construct the cluster + client. If not all slots are covered, :class:`~.RedisClusterException` will be + thrown. + | See: + https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param reinitialize_steps: + | Specifies the number of MOVED errors that need to occur before reinitializing + the whole cluster topology. If a MOVED error occurs and the cluster does not + need to be reinitialized on this current error handling, only the MOVED slot + will be patched with the redirected node. + To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set reinitialize_steps to + 0. + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. + :param url: + | See :meth:`.from_url` + :param kwargs: + | Extra arguments that will be passed to the + :class:`~redis.asyncio.client.Redis` instance when created + + :raises RedisClusterException: + if any arguments are invalid. Eg: + + - db kwarg + - db != 0 in url + - unix socket connection + - none of host & url & startup_nodes were provided + + """ + @classmethod - def from_url(cls, url: str, **kwargs) -> _RedisClusterT: + def from_url(cls, url: str, **kwargs) -> "RedisCluster": """ - Return a Redis client object configured from the given URL + Return a Redis client object configured from the given URL. For example:: @@ -130,7 +195,7 @@ def __init__( self, host: Optional[str] = None, port: int = 6379, - startup_nodes: Optional[List[_ClusterNodeT]] = None, + startup_nodes: Optional[List["ClusterNode"]] = None, cluster_error_retry_attempts: int = 3, require_full_coverage: bool = False, reinitialize_steps: int = 10, @@ -138,52 +203,6 @@ def __init__( url: Optional[str] = None, **kwargs, ) -> None: - """ - Initialize a new RedisCluster client. - - :startup_nodes: 'list[ClusterNode]' - List of nodes from which initial bootstrapping can be done - :host: 'str' - Can be used to point to a startup node - :port: 'int' - Can be used to point to a startup node - :require_full_coverage: 'bool' - When set to False (default value): the client will not require a - full coverage of the slots. However, if not all slots are covered, - and at least one node has 'cluster-require-full-coverage' set to - 'yes,' the server will throw a ClusterDownError for some key-based - commands. See - - https://redis.io/topics/cluster-tutorial#redis-cluster-configuration-parameters - When set to True: all slots must be covered to construct the - cluster client. If not all slots are covered, RedisClusterException - will be thrown. - :read_from_replicas: 'bool' - Enable read from replicas in READONLY mode. You can read possibly - stale data. - When set to true, read commands will be assigned between the - primary and its replications in a Round-Robin manner. - :cluster_error_retry_attempts: 'int' - Retry command execution attempts when encountering ClusterDownError - or ConnectionError - :reinitialize_steps: 'int' - Specifies the number of MOVED errors that need to occur before - reinitializing the whole cluster topology. If a MOVED error occurs - and the cluster does not need to be reinitialized on this current - error handling, only the MOVED slot will be patched with the - redirected node. - To reinitialize the cluster on every MOVED error, set - reinitialize_steps to 1. - To avoid reinitializing the cluster on moved errors, set - reinitialize_steps to 0. - - :**kwargs: - Extra arguments that will be sent into Redis instance when created - (See Official redis-py doc for supported kwargs - [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) - Some kwargs are not supported and will raise a - RedisClusterException: - - db (Redis do not support database SELECT in cluster mode) - """ if startup_nodes is None: startup_nodes = [] @@ -258,7 +277,8 @@ def __init__( self._initialize = True self._lock = asyncio.Lock() - async def initialize(self) -> _RedisClusterT: + async def initialize(self) -> "RedisCluster": + """Get all nodes from startup nodes & creates connections if not initialized.""" if self._initialize: async with self._lock: if self._initialize: @@ -274,13 +294,14 @@ async def initialize(self) -> _RedisClusterT: return self async def close(self) -> None: + """Close all connections & client if initialized.""" if not self._initialize: async with self._lock: if not self._initialize: self._initialize = True await self.nodes_manager.close() - async def __aenter__(self) -> _RedisClusterT: + async def __aenter__(self) -> "RedisCluster": return await self.initialize() async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: @@ -304,10 +325,6 @@ def __del__(self, _warnings=warnings): ... async def on_connect(self, connection: Connection) -> None: - """ - Initialize the connection, authenticate and select a database and send - READONLY if it is set during object initialization. - """ connection.set_parser(ClusterParser) await connection.on_connect() @@ -326,26 +343,38 @@ def get_node( host: Optional[str] = None, port: Optional[int] = None, node_name: Optional[str] = None, - ) -> Optional[_ClusterNodeT]: + ) -> Optional["ClusterNode"]: + """Get node by (host, port) or node_name.""" return self.nodes_manager.get_node(host, port, node_name) - def get_primaries(self) -> List[_ClusterNodeT]: + def get_primaries(self) -> List["ClusterNode"]: + """Get the primary nodes of the cluster.""" return self.nodes_manager.get_nodes_by_server_type(PRIMARY) - def get_replicas(self) -> List[_ClusterNodeT]: + def get_replicas(self) -> List["ClusterNode"]: + """Get the replica nodes of the cluster.""" return self.nodes_manager.get_nodes_by_server_type(REPLICA) - def get_random_node(self) -> _ClusterNodeT: + def get_random_node(self) -> "ClusterNode": + """Get a random node of the cluster.""" return random.choice(list(self.nodes_manager.nodes_cache.values())) - def get_nodes(self) -> List[_ClusterNodeT]: + def get_nodes(self) -> List["ClusterNode"]: + """Get all nodes of the cluster.""" return list(self.nodes_manager.nodes_cache.values()) - def get_node_from_key(self, key: str, replica: bool = False) -> _ClusterNodeT: + def get_node_from_key( + self, key: str, replica: bool = False + ) -> Optional["ClusterNode"]: """ - Get the node that holds the key's slot. - If replica set to True but the slot doesn't have any replicas, None is - returned. + Get the cluster node corresponding to the provided key. + + :param key: + :param replica: + | Indicates if a replica should be returned + None will returned if no replica holds this key + + :raises SlotNotCoveredError: if the key is not covered by any slot. """ slot = self.keyslot(key) slot_cache = self.nodes_manager.slots_cache.get(slot) @@ -361,18 +390,12 @@ def get_node_from_key(self, key: str, replica: bool = False) -> _ClusterNodeT: return slot_cache[node_idx] - def get_default_node(self) -> _ClusterNodeT: - """ - Get the cluster's default node - """ + def get_default_node(self) -> "ClusterNode": + """Get the default node of the client.""" return self.nodes_manager.default_node - def set_default_node(self, node: Optional[_ClusterNodeT]) -> bool: - """ - Set the default node of the cluster. - :param node: 'ClusterNode' - :return True if the default node was set, else False - """ + def set_default_node(self, node: "ClusterNode") -> bool: + """Set the default node of the client.""" if node is None or self.get_node(node_name=node.name) is None: log.info( "The requested node does not exist in the cluster, so " @@ -384,10 +407,10 @@ def set_default_node(self, node: Optional[_ClusterNodeT]) -> bool: return True def set_response_callback(self, command: KeyT, callback: Callable) -> None: - """Set a custom Response Callback""" + """Set a custom response callback.""" self.cluster_response_callbacks[command] = callback - async def _determine_nodes(self, *args, **kwargs) -> List[_ClusterNodeT]: + async def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: command = args[0] nodes_flag = kwargs.pop("nodes_flag", None) if nodes_flag is not None: @@ -424,22 +447,16 @@ async def _determine_nodes(self, *args, **kwargs) -> List[_ClusterNodeT]: log.debug(f"Target for {args}: slot {slot}") return [node] - def keyslot(self, key: Union[str, int, float, bytes]) -> int: + def keyslot(self, key: EncodableT) -> int: """ - Calculate keyslot for a given key. - See Keys distribution model in https://redis.io/topics/cluster-spec + Find the keyslot for a given key. + + See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding """ k = self.encoder.encode(key) return key_slot(k) async def determine_slot(self, *args) -> int: - """ - Figure out what slot to use based on args. - - Raises a RedisClusterException if there's a missing key and we can't - determine what slots to map the command to; or, if the keys don't - all map to the same key slot. - """ command = args[0] if self.command_flags.get(command) == SLOT_ID: # The command contains the slot ID @@ -495,25 +512,23 @@ async def determine_slot(self, *args) -> int: return slots.pop() def get_encoder(self) -> Encoder: - """ - Get the connections' encoder - """ return self.encoder def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: """ - Get the connections' key-word arguments + Get the kwargs passed to the :class:`~redis.asyncio.client.Redis` object of + each node. """ return self.nodes_manager.connection_kwargs def _is_nodes_flag( - self, target_nodes: Union[List[_ClusterNodeT], _ClusterNodeT, str] + self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] ) -> bool: return isinstance(target_nodes, str) and target_nodes in self.node_flags def _parse_target_nodes( - self, target_nodes: Union[List[_ClusterNodeT], _ClusterNodeT] - ) -> List[_ClusterNodeT]: + self, target_nodes: Union[List["ClusterNode"], "ClusterNode"] + ) -> List["ClusterNode"]: if isinstance(target_nodes, list): nodes = target_nodes elif isinstance(target_nodes, ClusterNode): @@ -535,19 +550,21 @@ def _parse_target_nodes( async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: """ - Wrapper for ERRORS_ALLOW_RETRY error handling. + Execute a raw command on the appropriate cluster node or target_nodes. - It will try the number of times specified by the config option - "self.cluster_error_retry_attempts" which defaults to 3 unless manually - configured. + It will retry the command as specified by :attr:`cluster_error_retry_attempts` & + then raise an exception. - If it reaches the number of times, the command will raise the exception + :param args: + | Raw command args + :param kwargs: - Key argument :target_nodes: can be passed with the following types: - nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM - ClusterNode - list - dict + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot """ target_nodes_specified = False target_nodes = None @@ -603,11 +620,8 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any raise exception async def _execute_command( - self, target_node: _ClusterNodeT, *args: Union[KeyT, EncodableT], **kwargs + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs ) -> Any: - """ - Send a command to a node in the cluster - """ command = args[0] redis_connection = None connection = None @@ -740,16 +754,6 @@ async def _execute_command( raise ClusterError("TTL exhausted.") def _process_result(self, command: KeyT, res: Dict[str, Any], **kwargs) -> Any: - """ - Process the result of the executed command. - The function would return a dict or a single value. - - :type command: str - :type res: dict - - `res` should be in the following format: - Dict - """ if command in self.result_callbacks: return self.result_callbacks[command](command, res, **kwargs) elif len(res) == 1: @@ -761,15 +765,16 @@ def _process_result(self, command: KeyT, res: Dict[str, Any], **kwargs) -> Any: class ClusterNode: + """ + Create a ClusterNode. + + Each ClusterNode manages a :class:`~redis.asyncio.client.Redis` object corresponding + to the (host, port). + """ + __slots__ = ("_lock", "host", "name", "port", "redis_connection", "server_type") - def __init__( - self, - host: str, - port: int, - server_type: Optional[str] = None, - redis_connection: None = None, - ) -> None: + def __init__(self, host: str, port: int, server_type: Optional[str] = None) -> None: if host == "localhost": host = socket.gethostbyname(host) @@ -777,7 +782,7 @@ def __init__( self.port = port self.name = get_node_name(host, port) self.server_type = server_type - self.redis_connection = redis_connection + self.redis_connection = None self._lock = asyncio.Lock() def __repr__(self) -> str: @@ -789,7 +794,7 @@ def __repr__(self) -> str: f"redis_connection={self.redis_connection}]" ) - def __eq__(self, obj: _ClusterNodeT) -> bool: + def __eq__(self, obj: "ClusterNode") -> bool: return isinstance(obj, ClusterNode) and obj.name == self.name _DEL_MESSAGE = "Unclosed ClusterNode object" @@ -807,6 +812,7 @@ def __del__(self, _warnings=warnings): ... async def initialize(self, from_url: bool = False, **kwargs) -> Redis: + """Create a redis object & make connections.""" if not self.redis_connection: async with self._lock: if not self.redis_connection: @@ -822,6 +828,7 @@ async def initialize(self, from_url: bool = False, **kwargs) -> Redis: return self.redis_connection async def close(self) -> None: + """Close all redis client connections & object.""" if self.redis_connection: async with self._lock: if self.redis_connection: @@ -846,7 +853,7 @@ class NodesManager: def __init__( self, - startup_nodes: List[_ClusterNodeT], + startup_nodes: List["ClusterNode"], from_url: bool = False, require_full_coverage: bool = False, **kwargs, @@ -867,12 +874,7 @@ def get_node( host: Optional[str] = None, port: Optional[int] = None, node_name: Optional[str] = None, - ) -> Optional[_ClusterNodeT]: - """ - Get the requested node from the cluster's nodes. - nodes. - :return: ClusterNode if the node exists, else None - """ + ) -> Optional["ClusterNode"]: if host and port: # the user passed host and port if host == "localhost": @@ -889,7 +891,7 @@ def get_node( return None async def set_nodes( - self, old: Dict[str, _ClusterNodeT], new: Dict[str, _ClusterNodeT] + self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"] ) -> None: tasks = [node.close() for name, node in old.items() if name not in new] for name, node in new.items(): @@ -901,9 +903,6 @@ async def set_nodes( await asyncio.gather(*tasks) async def _update_moved_slots(self) -> None: - """ - Update the slot's node with the redirected one - """ e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) if redirected_node is not None: @@ -943,10 +942,7 @@ async def _update_moved_slots(self) -> None: async def get_node_from_slot( self, slot: int, read_from_replicas: bool = False, server_type: None = None - ) -> _ClusterNodeT: - """ - Gets a node that servers this hash slot - """ + ) -> "ClusterNode": if self._moved_exception: async with self._lock: if self._moved_exception: @@ -978,19 +974,14 @@ async def get_node_from_slot( return self.slots_cache[slot][node_idx] - def get_nodes_by_server_type(self, server_type: str) -> List[_ClusterNodeT]: - """ - Get all nodes with the specified server type - :param server_type: 'primary' or 'replica' - :return: list of ClusterNode - """ + def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: return [ node for node in self.nodes_cache.values() if node.server_type == server_type ] - def check_slots_coverage(self, slots_cache: Dict[int, List[_ClusterNodeT]]) -> bool: + def check_slots_coverage(self, slots_cache: Dict[int, List["ClusterNode"]]) -> bool: # Validate if all slots are covered or if we should try next # startup node for i in range(0, REDIS_CLUSTER_HASH_SLOTS): @@ -999,11 +990,6 @@ def check_slots_coverage(self, slots_cache: Dict[int, List[_ClusterNodeT]]) -> b return True async def initialize(self) -> None: - """ - Initializes the nodes cache, slots cache and redis connections. - :startup_nodes: - Responsible for discovering other nodes in the cluster - """ log.debug("Initializing the nodes' topology of the cluster") self.reset() tmp_nodes_cache = {} diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 7a84373c26..9bc6bc53f1 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, List, Optional, Union from redis.asyncio.client import Redis from redis.exceptions import RedisError, ResponseError @@ -6,8 +6,6 @@ if TYPE_CHECKING: from redis.asyncio.cluster import RedisCluster -_RedisClusterT = TypeVar("_RedisClusterT", bound="RedisCluster") - class CommandsParser: """ @@ -23,7 +21,7 @@ class CommandsParser: def __init__(self) -> None: self.commands = {} - async def initialize(self, r: _RedisClusterT) -> None: + async def initialize(self, r: "RedisCluster") -> None: commands = await r.execute_command("COMMAND") uppercase_commands = [] for cmd in commands: From 7e6f9f116fd1739fe88e2d3f470564065a965948 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sat, 16 Apr 2022 04:50:48 +0530 Subject: [PATCH 08/23] docs: update sphinx & add sphinx_autodoc_typehints --- docs/conf.py | 9 +++++++-- docs/requirements.txt | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index b99e46c879..618d95aea1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,6 +30,7 @@ "nbsphinx", "sphinx_gallery.load_style", "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", "sphinx.ext.doctest", "sphinx.ext.viewcode", "sphinx.ext.autosectionlabel", @@ -41,6 +42,10 @@ autosectionlabel_prefix_document = True autosectionlabel_maxdepth = 2 +# AutodocTypehints settings. +always_document_param_types = True +typehints_defaults = "comma" + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -210,7 +215,7 @@ # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ - ("index", "redis-py.tex", "redis-py Documentation", "Redis Inc", "manual"), + ("index", "redis-py.tex", "redis-py Documentation", "Redis Inc", "manual") ] # The name of an image file (relative to this directory) to place at the top of @@ -258,7 +263,7 @@ "redis-py", "One line description of project.", "Miscellaneous", - ), + ) ] # Documents to append as an appendix to all manuals. diff --git a/docs/requirements.txt b/docs/requirements.txt index bbb7dc6149..23ddc948f4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ -sphinx<2 +sphinx<5 docutils<0.18 sphinx-rtd-theme nbsphinx sphinx_gallery ipython +sphinx-autodoc-typehints From 9ce1fede4dc3ff14db5b20e8b0cf411f2d331dd6 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Tue, 19 Apr 2022 01:30:31 +0530 Subject: [PATCH 09/23] async_cluster: move TargetNodesT to cluster module --- redis/asyncio/cluster.py | 6 +++++- redis/commands/cluster.py | 41 ++++++++++++++++++--------------------- whitelist.py | 1 + 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 820a3370e4..56acbe3907 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,7 +3,7 @@ import random import socket import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from redis.asyncio.client import Redis from redis.asyncio.connection import ( @@ -47,6 +47,10 @@ log = logging.getLogger(__name__) +TargetNodesT = TypeVar( + "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] +) + class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index f85db27170..d880febe79 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -11,7 +11,6 @@ Mapping, NoReturn, Optional, - TypeVar, Union, ) @@ -44,11 +43,7 @@ from .redismodules import RedisModuleCommands if TYPE_CHECKING: - from redis.asyncio.cluster import ClusterNode - -_TargetNodesT = TypeVar( - "_TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] -) + from redis.asyncio.cluster import TargetNodesT class ClusterMultiKeyCommands(ClusterCommandsProtocol): @@ -941,7 +936,7 @@ class AsyncRedisClusterCommands( r.cluster_info(target_nodes=RedisCluster.ALL_NODES) """ - def cluster_myid(self, target_node: _TargetNodesT) -> Awaitable: + def cluster_myid(self, target_node: "TargetNodesT") -> Awaitable: """ Returns the node’s id. @@ -953,7 +948,7 @@ def cluster_myid(self, target_node: _TargetNodesT) -> Awaitable: return self.execute_command("CLUSTER MYID", target_nodes=target_node) def cluster_addslots( - self, target_node: _TargetNodesT, *slots: EncodableT + self, target_node: "TargetNodesT", *slots: EncodableT ) -> Awaitable: """ Assign new hash slots to receiving node. Sends to specified node. @@ -968,7 +963,7 @@ def cluster_addslots( ) def cluster_addslotsrange( - self, target_node: _TargetNodesT, *slots: EncodableT + self, target_node: "TargetNodesT", *slots: EncodableT ) -> Awaitable: """ Similar to the CLUSTER ADDSLOTS command. @@ -1028,7 +1023,7 @@ def cluster_delslotsrange(self, *slots: EncodableT) -> Awaitable: return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) def cluster_failover( - self, target_node: _TargetNodesT, option: Optional[str] = None + self, target_node: "TargetNodesT", option: Optional[str] = None ) -> Awaitable: """ Forces a slave to perform a manual failover of its master @@ -1051,7 +1046,7 @@ def cluster_failover( else: return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - def cluster_info(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: + def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: """ Provides info about Redis Cluster node state. The command will be sent to a random node in the cluster if no target @@ -1071,7 +1066,7 @@ def cluster_keyslot(self, key: str) -> Awaitable: return self.execute_command("CLUSTER KEYSLOT", key) def cluster_meet( - self, host: str, port: int, target_nodes: Optional[_TargetNodesT] = None + self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None ) -> Awaitable: """ Force a node cluster to handshake with another node. @@ -1092,7 +1087,9 @@ def cluster_nodes(self) -> Awaitable: """ return self.execute_command("CLUSTER NODES") - def cluster_replicate(self, target_nodes: _TargetNodesT, node_id: str) -> Awaitable: + def cluster_replicate( + self, target_nodes: "TargetNodesT", node_id: str + ) -> Awaitable: """ Reconfigure a node as a slave of the specified master node @@ -1103,7 +1100,7 @@ def cluster_replicate(self, target_nodes: _TargetNodesT, node_id: str) -> Awaita ) def cluster_reset( - self, soft: bool = True, target_nodes: Optional[_TargetNodesT] = None + self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None ) -> Awaitable: """ Reset a Redis Cluster node @@ -1118,7 +1115,7 @@ def cluster_reset( ) def cluster_save_config( - self, target_nodes: Optional[_TargetNodesT] = None + self, target_nodes: Optional["TargetNodesT"] = None ) -> Awaitable: """ Forces the node to save cluster state on disk @@ -1136,7 +1133,7 @@ def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> Awaitable: return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) def cluster_set_config_epoch( - self, epoch: int, target_nodes: Optional[_TargetNodesT] = None + self, epoch: int, target_nodes: Optional["TargetNodesT"] = None ) -> Awaitable: """ Set the configuration epoch in a new node @@ -1148,7 +1145,7 @@ def cluster_set_config_epoch( ) def cluster_setslot( - self, target_node: _TargetNodesT, node_id: str, slot_id: int, state: str + self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str ) -> Awaitable: """ Bind an hash slot to a specific node @@ -1177,7 +1174,7 @@ def cluster_setslot_stable(self, slot_id: int) -> Awaitable: return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") def cluster_replicas( - self, node_id: str, target_nodes: Optional[_TargetNodesT] = None + self, node_id: str, target_nodes: Optional["TargetNodesT"] = None ) -> Awaitable: """ Provides a list of replica nodes replicating from the specified primary @@ -1189,7 +1186,7 @@ def cluster_replicas( "CLUSTER REPLICAS", node_id, target_nodes=target_nodes ) - def cluster_slots(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: + def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: """ Get array of Cluster slot to node mappings @@ -1197,7 +1194,7 @@ def cluster_slots(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitab """ return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - def cluster_links(self, target_node: _TargetNodesT) -> Awaitable: + def cluster_links(self, target_node: "TargetNodesT") -> Awaitable: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each peer in the cluster: One for sending outbound messages towards the peer and one @@ -1209,7 +1206,7 @@ def cluster_links(self, target_node: _TargetNodesT) -> Awaitable: """ return self.execute_command("CLUSTER LINKS", target_nodes=target_node) - def readonly(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: + def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: """ Enables read queries. The command will be sent to the default cluster node if target_nodes is @@ -1223,7 +1220,7 @@ def readonly(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: self.read_from_replicas = True return self.execute_command("READONLY", target_nodes=target_nodes) - def readwrite(self, target_nodes: Optional[_TargetNodesT] = None) -> Awaitable: + def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: """ Disables read queries. The command will be sent to the default cluster node if target_nodes is diff --git a/whitelist.py b/whitelist.py index d5d9b9970a..8c9cee3c29 100644 --- a/whitelist.py +++ b/whitelist.py @@ -16,3 +16,4 @@ AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) +TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46) From 66893f62c5b0504bdd640dd16b4004697eb12137 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Tue, 19 Apr 2022 01:45:21 +0530 Subject: [PATCH 10/23] async_cluster/commands: inherit commands from sync class if possible --- redis/commands/cluster.py | 1079 ++++++++++++------------------------- 1 file changed, 342 insertions(+), 737 deletions(-) diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index d880febe79..b188540b0e 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -3,7 +3,6 @@ TYPE_CHECKING, Any, AsyncIterator, - Awaitable, Dict, Iterable, Iterator, @@ -37,6 +36,7 @@ FunctionCommands, ManagementCommands, PubSubCommands, + ResponseT, ScriptCommands, ) from .helpers import list_or_args @@ -51,7 +51,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol): A class containing commands that handle more than one key """ - def _partition_keys_by_slot(self, keys): + def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: """ Split keys into a dictionary that maps a slot to a list of keys. @@ -64,7 +64,7 @@ def _partition_keys_by_slot(self, keys): return slots_to_keys - def mget_nonatomic(self, keys, *args): + def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET for the keys of every slot. This operation will not be atomic @@ -100,7 +100,7 @@ def mget_nonatomic(self, keys, *args): vals_in_order = [all_results[key] for key in keys] return vals_in_order - def mset_nonatomic(self, mapping): + def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -129,7 +129,7 @@ def mset_nonatomic(self, mapping): return res - def _split_command_across_slots(self, command, *keys): + def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: """ Runs the given command once for the keys of each slot. Returns the sum of the return values. @@ -144,7 +144,7 @@ def _split_command_across_slots(self, command, *keys): return total - def exists(self, *keys): + def exists(self, *keys: KeyT) -> ResponseT: """ Returns the number of ``names`` that exist in the whole cluster. The keys are first split up into slots @@ -154,7 +154,7 @@ def exists(self, *keys): """ return self._split_command_across_slots("EXISTS", *keys) - def delete(self, *keys): + def delete(self, *keys: KeyT) -> ResponseT: """ Deletes the given keys in the cluster. The keys are first split up into slots @@ -167,7 +167,7 @@ def delete(self, *keys): """ return self._split_command_across_slots("DEL", *keys) - def touch(self, *keys): + def touch(self, *keys: KeyT) -> ResponseT: """ Updates the last access time of given keys across the cluster. @@ -182,7 +182,7 @@ def touch(self, *keys): """ return self._split_command_across_slots("TOUCH", *keys) - def unlink(self, *keys): + def unlink(self, *keys: KeyT) -> ResponseT: """ Remove the specified keys in a different thread. @@ -197,24 +197,11 @@ def unlink(self, *keys): return self._split_command_across_slots("UNLINK", *keys) -class AsyncClusterMultiKeyCommands(ClusterCommandsProtocol): +class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): """ A class containing commands that handle more than one key """ - def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: - """ - Split keys into a dictionary that maps a slot to - a list of keys. - """ - slots_to_keys = {} - for key in keys: - k = self.encoder.encode(key) - slot = key_slot(k) - slots_to_keys.setdefault(slot, []).append(key) - - return slots_to_keys - async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: """ Splits the keys into different slots and then calls MGET @@ -300,58 +287,6 @@ async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: ) ) - def exists(self, *keys: KeyT) -> Awaitable: - """ - Returns the number of ``names`` that exist in the - whole cluster. The keys are first split up into slots - and then an EXISTS command is sent for every slot - - For more information see https://redis.io/commands/exists - """ - return self._split_command_across_slots("EXISTS", *keys) - - def delete(self, *keys: KeyT) -> Awaitable: - """ - Deletes the given keys in the cluster. - The keys are first split up into slots - and then an DEL command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were deleted. - - For more information see https://redis.io/commands/del - """ - return self._split_command_across_slots("DEL", *keys) - - def touch(self, *keys: KeyT) -> Awaitable: - """ - Updates the last access time of given keys across the - cluster. - - The keys are first split up into slots - and then an TOUCH command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were touched. - - For more information see https://redis.io/commands/touch - """ - return self._split_command_across_slots("TOUCH", *keys) - - def unlink(self, *keys: KeyT) -> Awaitable: - """ - Remove the specified keys in a different thread. - - The keys are first split up into slots - and then an TOUCH command is sent for every slot - - Non-existant keys are ignored. - Returns the number of keys that were unlinked. - - For more information see https://redis.io/commands/unlink - """ - return self._split_command_across_slots("UNLINK", *keys) - class ClusterManagementCommands(ManagementCommands): """ @@ -361,7 +296,7 @@ class ClusterManagementCommands(ManagementCommands): required adjustments to work with cluster mode """ - def slaveof(self, *args, **kwargs): + def slaveof(self, *args, **kwargs) -> NoReturn: """ Make the server a replica of another instance, or promote it as master. @@ -369,7 +304,7 @@ def slaveof(self, *args, **kwargs): """ raise RedisClusterException("SLAVEOF is not supported in cluster mode") - def replicaof(self, *args, **kwargs): + def replicaof(self, *args, **kwargs) -> NoReturn: """ Make the server a replica of another instance, or promote it as master. @@ -377,7 +312,7 @@ def replicaof(self, *args, **kwargs): """ raise RedisClusterException("REPLICAOF is not supported in cluster" " mode") - def swapdb(self, *args, **kwargs): + def swapdb(self, *args, **kwargs) -> NoReturn: """ Swaps two Redis databases. @@ -385,681 +320,128 @@ def swapdb(self, *args, **kwargs): """ raise RedisClusterException("SWAPDB is not supported in cluster" " mode") + def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT: + """ + Returns the node’s id. -class AsyncClusterManagementCommands(AsyncManagementCommands): - """ - A class for Redis Cluster management commands - - The class inherits from Redis's core ManagementCommands class and do the - required adjustments to work with cluster mode - """ + :target_node: 'ClusterNode' + The node to execute the command on - def slaveof(self, *args, **kwargs) -> NoReturn: + For more information check https://redis.io/commands/cluster-myid/ """ - Make the server a replica of another instance, or promote it as master. + return self.execute_command("CLUSTER MYID", target_nodes=target_node) - For more information see https://redis.io/commands/slaveof + def cluster_addslots( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: """ - raise RedisClusterException("SLAVEOF is not supported in cluster mode") + Assign new hash slots to receiving node. Sends to specified node. - def replicaof(self, *args, **kwargs) -> NoReturn: + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslots """ - Make the server a replica of another instance, or promote it as master. + return self.execute_command( + "CLUSTER ADDSLOTS", *slots, target_nodes=target_node + ) - For more information see https://redis.io/commands/replicaof + def cluster_addslotsrange( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: """ - raise RedisClusterException("REPLICAOF is not supported in cluster" " mode") + Similar to the CLUSTER ADDSLOTS command. + The difference between the two commands is that ADDSLOTS takes a list of slots + to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges + (specified by start and end slots) to assign to the node. - def swapdb(self, *args, **kwargs) -> NoReturn: + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslotsrange """ - Swaps two Redis databases. + return self.execute_command( + "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node + ) - For more information see https://redis.io/commands/swapdb + def cluster_countkeysinslot(self, slot_id: int) -> ResponseT: """ - raise RedisClusterException("SWAPDB is not supported in cluster" " mode") + Return the number of local keys in the specified hash slot + Send to node based on specified slot_id + For more information see https://redis.io/commands/cluster-countkeysinslot + """ + return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) -class ClusterDataAccessCommands(DataAccessCommands): - """ - A class for Redis Cluster Data Access Commands + def cluster_count_failure_report(self, node_id: str) -> ResponseT: + """ + Return the number of failure reports active for a given node + Sends to a random node - The class inherits from Redis's core DataAccessCommand class and do the - required adjustments to work with cluster mode - """ + For more information see https://redis.io/commands/cluster-count-failure-reports + """ + return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - def stralgo( - self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, - **kwargs, - ): + def cluster_delslots(self, *slots: EncodableT) -> List[bool]: """ - Implements complex algorithms that operate on strings. - Right now the only algorithm implemented is the LCS algorithm - (longest common substring). However new algorithms could be - implemented in the future. + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there - ``algo`` Right now must be LCS - ``value1`` and ``value2`` Can be two strings or two keys - ``specific_argument`` Specifying if the arguments to the algorithm - will be keys or strings. strings is the default. - ``len`` Returns just the len of the match. - ``idx`` Returns the match positions in each string. - ``minmatchlen`` Restrict the list of matches to the ones of a given - minimal length. Can be provided only when ``idx`` set to True. - ``withmatchlen`` Returns the matches with the len of the match. - Can be provided only when ``idx`` set to True. + Returns a list of the results for each processed slot. - For more information see https://redis.io/commands/stralgo + For more information see https://redis.io/commands/cluster-delslots """ - target_nodes = kwargs.pop("target_nodes", None) - if specific_argument == "strings" and target_nodes is None: - target_nodes = "default-node" - kwargs.update({"target_nodes": target_nodes}) - return super().stralgo( - algo, - value1, - value2, - specific_argument, - len, - idx, - minmatchlen, - withmatchlen, - **kwargs, - ) + return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] - def scan_iter( - self, - match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, - **kwargs, - ) -> Iterator: - # Do the first query with cursor=0 for all nodes - cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) - yield from data + def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT: + """ + Similar to the CLUSTER DELSLOTS command. + The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove + from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove + from the node. - cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} - if cursors: - # Get nodes by name - nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + For more information see https://redis.io/commands/cluster-delslotsrange + """ + return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) - # Iterate over each node till its cursor is 0 - kwargs.pop("target_nodes", None) - while cursors: - for name, cursor in cursors.items(): - cur, data = self.scan( - cursor=cursor, - match=match, - count=count, - _type=_type, - target_nodes=nodes[name], - **kwargs, - ) - yield from data - cursors[name] = cur[name] + def cluster_failover( + self, target_node: "TargetNodesT", option: Optional[str] = None + ) -> ResponseT: + """ + Forces a slave to perform a manual failover of its master + Sends to specified node - cursors = { - name: cursor for name, cursor in cursors.items() if cursor != 0 - } + :target_node: 'ClusterNode' + The node to execute the command on + For more information see https://redis.io/commands/cluster-failover + """ + if option: + if option.upper() not in ["FORCE", "TAKEOVER"]: + raise RedisError( + f"Invalid option for CLUSTER FAILOVER command: {option}" + ) + else: + return self.execute_command( + "CLUSTER FAILOVER", option, target_nodes=target_node + ) + else: + return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) -class AsyncClusterDataAccessCommands(AsyncDataAccessCommands): - """ - A class for Redis Cluster Data Access Commands + def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: + """ + Provides info about Redis Cluster node state. + The command will be sent to a random node in the cluster if no target + node is specified. - The class inherits from Redis's core DataAccessCommand class and do the - required adjustments to work with cluster mode - """ + For more information see https://redis.io/commands/cluster-info + """ + return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - def stralgo( - self, - algo: Literal["LCS"], - value1: KeyT, - value2: KeyT, - specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", - len: bool = False, - idx: bool = False, - minmatchlen: Optional[int] = None, - withmatchlen: bool = False, - **kwargs, - ) -> Awaitable: + def cluster_keyslot(self, key: str) -> ResponseT: """ - Implements complex algorithms that operate on strings. - Right now the only algorithm implemented is the LCS algorithm - (longest common substring). However new algorithms could be - implemented in the future. - - ``algo`` Right now must be LCS - ``value1`` and ``value2`` Can be two strings or two keys - ``specific_argument`` Specifying if the arguments to the algorithm - will be keys or strings. strings is the default. - ``len`` Returns just the len of the match. - ``idx`` Returns the match positions in each string. - ``minmatchlen`` Restrict the list of matches to the ones of a given - minimal length. Can be provided only when ``idx`` set to True. - ``withmatchlen`` Returns the matches with the len of the match. - Can be provided only when ``idx`` set to True. - - For more information see https://redis.io/commands/stralgo - """ - target_nodes = kwargs.pop("target_nodes", None) - if specific_argument == "strings" and target_nodes is None: - target_nodes = "default-node" - kwargs.update({"target_nodes": target_nodes}) - return super().stralgo( - algo, - value1, - value2, - specific_argument, - len, - idx, - minmatchlen, - withmatchlen, - **kwargs, - ) - - async def scan_iter( - self, - match: Optional[PatternT] = None, - count: Optional[int] = None, - _type: Optional[str] = None, - **kwargs, - ) -> AsyncIterator: - # Do the first query with cursor=0 for all nodes - cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) - for value in data: - yield value - - cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} - if cursors: - # Get nodes by name - nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} - - # Iterate over each node till its cursor is 0 - kwargs.pop("target_nodes", None) - while cursors: - for name, cursor in cursors.items(): - cur, data = await self.scan( - cursor=cursor, - match=match, - count=count, - _type=_type, - target_nodes=nodes[name], - **kwargs, - ) - for value in data: - yield value - cursors[name] = cur[name] - - cursors = { - name: cursor for name, cursor in cursors.items() if cursor != 0 - } - - -class RedisClusterCommands( - ClusterMultiKeyCommands, - ClusterManagementCommands, - ACLCommands, - PubSubCommands, - ClusterDataAccessCommands, - ScriptCommands, - FunctionCommands, - RedisModuleCommands, -): - """ - A class for all Redis Cluster commands - - For key-based commands, the target node(s) will be internally determined - by the keys' hash slot. - Non-key-based commands can be executed with the 'target_nodes' argument to - target specific nodes. By default, if target_nodes is not specified, the - command will be executed on the default cluster node. - - - :param :target_nodes: type can be one of the followings: - - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM - - 'ClusterNode' - - 'list(ClusterNodes)' - - 'dict(any:clusterNodes)' - - for example: - r.cluster_info(target_nodes=RedisCluster.ALL_NODES) - """ - - def cluster_myid(self, target_node): - """ - Returns the node’s id. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information check https://redis.io/commands/cluster-myid/ - """ - return self.execute_command("CLUSTER MYID", target_nodes=target_node) - - def cluster_addslots(self, target_node, *slots): - """ - Assign new hash slots to receiving node. Sends to specified node. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-addslots - """ - return self.execute_command( - "CLUSTER ADDSLOTS", *slots, target_nodes=target_node - ) - - def cluster_addslotsrange(self, target_node, *slots): - """ - Similar to the CLUSTER ADDSLOTS command. - The difference between the two commands is that ADDSLOTS takes a list of slots - to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges - (specified by start and end slots) to assign to the node. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-addslotsrange - """ - return self.execute_command( - "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node - ) - - def cluster_countkeysinslot(self, slot_id): - """ - Return the number of local keys in the specified hash slot - Send to node based on specified slot_id - - For more information see https://redis.io/commands/cluster-countkeysinslot - """ - return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) - - def cluster_count_failure_report(self, node_id): - """ - Return the number of failure reports active for a given node - Sends to a random node - - For more information see https://redis.io/commands/cluster-count-failure-reports - """ - return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - - def cluster_delslots(self, *slots): - """ - Set hash slots as unbound in the cluster. - It determines by it self what node the slot is in and sends it there - - Returns a list of the results for each processed slot. - - For more information see https://redis.io/commands/cluster-delslots - """ - return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] - - def cluster_delslotsrange(self, *slots): - """ - Similar to the CLUSTER DELSLOTS command. - The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove - from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove - from the node. - - For more information see https://redis.io/commands/cluster-delslotsrange - """ - return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) - - def cluster_failover(self, target_node, option=None): - """ - Forces a slave to perform a manual failover of its master - Sends to specified node - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-failover - """ - if option: - if option.upper() not in ["FORCE", "TAKEOVER"]: - raise RedisError( - f"Invalid option for CLUSTER FAILOVER command: {option}" - ) - else: - return self.execute_command( - "CLUSTER FAILOVER", option, target_nodes=target_node - ) - else: - return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - - def cluster_info(self, target_nodes=None): - """ - Provides info about Redis Cluster node state. - The command will be sent to a random node in the cluster if no target - node is specified. - - For more information see https://redis.io/commands/cluster-info - """ - return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - - def cluster_keyslot(self, key): - """ - Returns the hash slot of the specified key - Sends to random node in the cluster - - For more information see https://redis.io/commands/cluster-keyslot - """ - return self.execute_command("CLUSTER KEYSLOT", key) - - def cluster_meet(self, host, port, target_nodes=None): - """ - Force a node cluster to handshake with another node. - Sends to specified node. - - For more information see https://redis.io/commands/cluster-meet - """ - return self.execute_command( - "CLUSTER MEET", host, port, target_nodes=target_nodes - ) - - def cluster_nodes(self): - """ - Get Cluster config for the node. - Sends to random node in the cluster - - For more information see https://redis.io/commands/cluster-nodes - """ - return self.execute_command("CLUSTER NODES") - - def cluster_replicate(self, target_nodes, node_id): - """ - Reconfigure a node as a slave of the specified master node - - For more information see https://redis.io/commands/cluster-replicate - """ - return self.execute_command( - "CLUSTER REPLICATE", node_id, target_nodes=target_nodes - ) - - def cluster_reset(self, soft=True, target_nodes=None): - """ - Reset a Redis Cluster node - - If 'soft' is True then it will send 'SOFT' argument - If 'soft' is False then it will send 'HARD' argument - - For more information see https://redis.io/commands/cluster-reset - """ - return self.execute_command( - "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes - ) - - def cluster_save_config(self, target_nodes=None): - """ - Forces the node to save cluster state on disk - - For more information see https://redis.io/commands/cluster-saveconfig - """ - return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) - - def cluster_get_keys_in_slot(self, slot, num_keys): - """ - Returns the number of keys in the specified cluster slot - - For more information see https://redis.io/commands/cluster-getkeysinslot - """ - return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) - - def cluster_set_config_epoch(self, epoch, target_nodes=None): - """ - Set the configuration epoch in a new node - - For more information see https://redis.io/commands/cluster-set-config-epoch - """ - return self.execute_command( - "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes - ) - - def cluster_setslot(self, target_node, node_id, slot_id, state): - """ - Bind an hash slot to a specific node - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-setslot - """ - if state.upper() in ("IMPORTING", "NODE", "MIGRATING"): - return self.execute_command( - "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node - ) - elif state.upper() == "STABLE": - raise RedisError('For "stable" state please use ' "cluster_setslot_stable") - else: - raise RedisError(f"Invalid slot state: {state}") - - def cluster_setslot_stable(self, slot_id): - """ - Clears migrating / importing state from the slot. - It determines by it self what node the slot is in and sends it there. - - For more information see https://redis.io/commands/cluster-setslot - """ - return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") - - def cluster_replicas(self, node_id, target_nodes=None): - """ - Provides a list of replica nodes replicating from the specified primary - target node. - - For more information see https://redis.io/commands/cluster-replicas - """ - return self.execute_command( - "CLUSTER REPLICAS", node_id, target_nodes=target_nodes - ) - - def cluster_slots(self, target_nodes=None): - """ - Get array of Cluster slot to node mappings - - For more information see https://redis.io/commands/cluster-slots - """ - return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - - def cluster_links(self, target_node): - """ - Each node in a Redis Cluster maintains a pair of long-lived TCP link with each - peer in the cluster: One for sending outbound messages towards the peer and one - for receiving inbound messages from the peer. - - This command outputs information of all such peer links as an array. - - For more information see https://redis.io/commands/cluster-links - """ - return self.execute_command("CLUSTER LINKS", target_nodes=target_node) - - def readonly(self, target_nodes=None): - """ - Enables read queries. - The command will be sent to the default cluster node if target_nodes is - not specified. - - For more information see https://redis.io/commands/readonly - """ - if target_nodes == "replicas" or target_nodes == "all": - # read_from_replicas will only be enabled if the READONLY command - # is sent to all replicas - self.read_from_replicas = True - return self.execute_command("READONLY", target_nodes=target_nodes) - - def readwrite(self, target_nodes=None): - """ - Disables read queries. - The command will be sent to the default cluster node if target_nodes is - not specified. - - For more information see https://redis.io/commands/readwrite - """ - # Reset read from replicas flag - self.read_from_replicas = False - return self.execute_command("READWRITE", target_nodes=target_nodes) - - -class AsyncRedisClusterCommands( - AsyncClusterMultiKeyCommands, - AsyncClusterManagementCommands, - AsyncACLCommands, - AsyncClusterDataAccessCommands, - AsyncScriptCommands, - AsyncFunctionCommands, -): - """ - A class for all Redis Cluster commands - - For key-based commands, the target node(s) will be internally determined - by the keys' hash slot. - Non-key-based commands can be executed with the 'target_nodes' argument to - target specific nodes. By default, if target_nodes is not specified, the - command will be executed on the default cluster node. - - - :param :target_nodes: type can be one of the followings: - - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM - - 'ClusterNode' - - 'list(ClusterNodes)' - - 'dict(any:clusterNodes)' - - for example: - r.cluster_info(target_nodes=RedisCluster.ALL_NODES) - """ - - def cluster_myid(self, target_node: "TargetNodesT") -> Awaitable: - """ - Returns the node’s id. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information check https://redis.io/commands/cluster-myid/ - """ - return self.execute_command("CLUSTER MYID", target_nodes=target_node) - - def cluster_addslots( - self, target_node: "TargetNodesT", *slots: EncodableT - ) -> Awaitable: - """ - Assign new hash slots to receiving node. Sends to specified node. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-addslots - """ - return self.execute_command( - "CLUSTER ADDSLOTS", *slots, target_nodes=target_node - ) - - def cluster_addslotsrange( - self, target_node: "TargetNodesT", *slots: EncodableT - ) -> Awaitable: - """ - Similar to the CLUSTER ADDSLOTS command. - The difference between the two commands is that ADDSLOTS takes a list of slots - to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges - (specified by start and end slots) to assign to the node. - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-addslotsrange - """ - return self.execute_command( - "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node - ) - - def cluster_countkeysinslot(self, slot_id: int) -> Awaitable: - """ - Return the number of local keys in the specified hash slot - Send to node based on specified slot_id - - For more information see https://redis.io/commands/cluster-countkeysinslot - """ - return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) - - def cluster_count_failure_report(self, node_id: str) -> Awaitable: - """ - Return the number of failure reports active for a given node - Sends to a random node - - For more information see https://redis.io/commands/cluster-count-failure-reports - """ - return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) - - async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: - """ - Set hash slots as unbound in the cluster. - It determines by it self what node the slot is in and sends it there - - Returns a list of the results for each processed slot. - - For more information see https://redis.io/commands/cluster-delslots - """ - return await asyncio.gather( - *[self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] - ) - - def cluster_delslotsrange(self, *slots: EncodableT) -> Awaitable: - """ - Similar to the CLUSTER DELSLOTS command. - The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove - from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove - from the node. - - For more information see https://redis.io/commands/cluster-delslotsrange - """ - return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) - - def cluster_failover( - self, target_node: "TargetNodesT", option: Optional[str] = None - ) -> Awaitable: - """ - Forces a slave to perform a manual failover of its master - Sends to specified node - - :target_node: 'ClusterNode' - The node to execute the command on - - For more information see https://redis.io/commands/cluster-failover - """ - if option: - if option.upper() not in ["FORCE", "TAKEOVER"]: - raise RedisError( - f"Invalid option for CLUSTER FAILOVER command: {option}" - ) - else: - return self.execute_command( - "CLUSTER FAILOVER", option, target_nodes=target_node - ) - else: - return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) - - def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: - """ - Provides info about Redis Cluster node state. - The command will be sent to a random node in the cluster if no target - node is specified. - - For more information see https://redis.io/commands/cluster-info - """ - return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) - - def cluster_keyslot(self, key: str) -> Awaitable: - """ - Returns the hash slot of the specified key - Sends to random node in the cluster + Returns the hash slot of the specified key + Sends to random node in the cluster For more information see https://redis.io/commands/cluster-keyslot """ @@ -1067,7 +449,7 @@ def cluster_keyslot(self, key: str) -> Awaitable: def cluster_meet( self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None - ) -> Awaitable: + ) -> ResponseT: """ Force a node cluster to handshake with another node. Sends to specified node. @@ -1078,7 +460,7 @@ def cluster_meet( "CLUSTER MEET", host, port, target_nodes=target_nodes ) - def cluster_nodes(self) -> Awaitable: + def cluster_nodes(self) -> ResponseT: """ Get Cluster config for the node. Sends to random node in the cluster @@ -1089,7 +471,7 @@ def cluster_nodes(self) -> Awaitable: def cluster_replicate( self, target_nodes: "TargetNodesT", node_id: str - ) -> Awaitable: + ) -> ResponseT: """ Reconfigure a node as a slave of the specified master node @@ -1101,7 +483,7 @@ def cluster_replicate( def cluster_reset( self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None - ) -> Awaitable: + ) -> ResponseT: """ Reset a Redis Cluster node @@ -1116,7 +498,7 @@ def cluster_reset( def cluster_save_config( self, target_nodes: Optional["TargetNodesT"] = None - ) -> Awaitable: + ) -> ResponseT: """ Forces the node to save cluster state on disk @@ -1124,7 +506,7 @@ def cluster_save_config( """ return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) - def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> Awaitable: + def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT: """ Returns the number of keys in the specified cluster slot @@ -1134,7 +516,7 @@ def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> Awaitable: def cluster_set_config_epoch( self, epoch: int, target_nodes: Optional["TargetNodesT"] = None - ) -> Awaitable: + ) -> ResponseT: """ Set the configuration epoch in a new node @@ -1146,7 +528,7 @@ def cluster_set_config_epoch( def cluster_setslot( self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str - ) -> Awaitable: + ) -> ResponseT: """ Bind an hash slot to a specific node @@ -1164,7 +546,7 @@ def cluster_setslot( else: raise RedisError(f"Invalid slot state: {state}") - def cluster_setslot_stable(self, slot_id: int) -> Awaitable: + def cluster_setslot_stable(self, slot_id: int) -> ResponseT: """ Clears migrating / importing state from the slot. It determines by it self what node the slot is in and sends it there. @@ -1175,7 +557,7 @@ def cluster_setslot_stable(self, slot_id: int) -> Awaitable: def cluster_replicas( self, node_id: str, target_nodes: Optional["TargetNodesT"] = None - ) -> Awaitable: + ) -> ResponseT: """ Provides a list of replica nodes replicating from the specified primary target node. @@ -1186,7 +568,7 @@ def cluster_replicas( "CLUSTER REPLICAS", node_id, target_nodes=target_nodes ) - def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: + def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Get array of Cluster slot to node mappings @@ -1194,7 +576,7 @@ def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaita """ return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) - def cluster_links(self, target_node: "TargetNodesT") -> Awaitable: + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each peer in the cluster: One for sending outbound messages towards the peer and one @@ -1206,7 +588,7 @@ def cluster_links(self, target_node: "TargetNodesT") -> Awaitable: """ return self.execute_command("CLUSTER LINKS", target_nodes=target_node) - def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: + def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Enables read queries. The command will be sent to the default cluster node if target_nodes is @@ -1220,7 +602,7 @@ def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: self.read_from_replicas = True return self.execute_command("READONLY", target_nodes=target_nodes) - def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: + def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: """ Disables read queries. The command will be sent to the default cluster node if target_nodes is @@ -1231,3 +613,226 @@ def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> Awaitable: # Reset read from replicas flag self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) + + +class AsyncClusterManagementCommands( + ClusterManagementCommands, AsyncManagementCommands +): + """ + A class for Redis Cluster management commands + + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ + + async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: + """ + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there + + Returns a list of the results for each processed slot. + + For more information see https://redis.io/commands/cluster-delslots + """ + return await asyncio.gather( + *[self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + ) + + +class ClusterDataAccessCommands(DataAccessCommands): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + def stralgo( + self, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Optional[int] = None, + withmatchlen: bool = False, + **kwargs, + ) -> ResponseT: + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + + For more information see https://redis.io/commands/stralgo + """ + target_nodes = kwargs.pop("target_nodes", None) + if specific_argument == "strings" and target_nodes is None: + target_nodes = "default-node" + kwargs.update({"target_nodes": target_nodes}) + return super().stralgo( + algo, + value1, + value2, + specific_argument, + len, + idx, + minmatchlen, + withmatchlen, + **kwargs, + ) + + def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> Iterator: + # Do the first query with cursor=0 for all nodes + cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) + yield from data + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + yield from data + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class AsyncClusterDataAccessCommands( + ClusterDataAccessCommands, AsyncDataAccessCommands +): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + async def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> AsyncIterator: + # Do the first query with cursor=0 for all nodes + cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) + for value in data: + yield value + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = await self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + for value in data: + yield value + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class RedisClusterCommands( + ClusterMultiKeyCommands, + ClusterManagementCommands, + ACLCommands, + PubSubCommands, + ClusterDataAccessCommands, + ScriptCommands, + FunctionCommands, + RedisModuleCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ + + +class AsyncRedisClusterCommands( + AsyncClusterMultiKeyCommands, + AsyncClusterManagementCommands, + AsyncACLCommands, + AsyncClusterDataAccessCommands, + AsyncScriptCommands, + AsyncFunctionCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ From 7ac208f4ff8840763e123aa9037c791a54545d29 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Wed, 20 Apr 2022 03:45:04 +0530 Subject: [PATCH 11/23] async_cluster: add benchmark script with aredis & aioredis-cluster --- benchmarks/cluster_async.py | 256 ++++++++++++++++++++++++++++++++++++ tox.ini | 3 +- 2 files changed, 257 insertions(+), 2 deletions(-) create mode 100644 benchmarks/cluster_async.py diff --git a/benchmarks/cluster_async.py b/benchmarks/cluster_async.py new file mode 100644 index 0000000000..7211078440 --- /dev/null +++ b/benchmarks/cluster_async.py @@ -0,0 +1,256 @@ +import asyncio +import functools +import time + +import aioredis_cluster +import aredis +import uvloop + +import redis.asyncio as redispy + + +def timer(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + tic = time.perf_counter() + await func(*args, **kwargs) + toc = time.perf_counter() + return f"{toc - tic:.4f}" + + return wrapper + + +@timer +async def set_str(client, gather, data): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.set(f"bench:str_{i}", data)) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.set(f"bench:str_{i}", data) + + +@timer +async def set_int(client, gather, data): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.set(f"bench:int_{i}", data)) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.set(f"bench:int_{i}", data) + + +@timer +async def get_str(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.get(f"bench:str_{i}")) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.get(f"bench:str_{i}") + + +@timer +async def get_int(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.get(f"bench:int_{i}")) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.get(f"bench:int_{i}") + + +@timer +async def hset(client, gather, data): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.hset("bench:hset", str(i), data)) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.hset("bench:hset", str(i), data) + + +@timer +async def hget(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.hget("bench:hset", str(i))) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.hget("bench:hset", str(i)) + + +@timer +async def incr(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.incr("bench:incr")) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.incr("bench:incr") + + +@timer +async def lpush(client, gather, data): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.lpush("bench:lpush", data)) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.lpush("bench:lpush", data) + + +@timer +async def lrange_300(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.lrange("bench:lpush", i, i + 300)) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.lrange("bench:lpush", i, i + 300) + + +@timer +async def lpop(client, gather): + if gather: + for _ in range(count // 100): + tasks = [] + for i in range(100): + tasks.append(client.lpop("bench:lpush")) + await asyncio.gather(*tasks) + else: + for i in range(count): + await client.lpop("bench:lpush") + + +@timer +async def warmup(client): + tasks = [] + for i in range(1000): + tasks.append(client.exists(f"bench:warmup_{i}")) + await asyncio.gather(*tasks) + + +@timer +async def run(client, gather): + data_str = "a" * size + data_int = int("1" * size) + + if gather is False: + for ret in await asyncio.gather( + set_str(client, gather, data_str), + set_int(client, gather, data_int), + hset(client, gather, data_str), + incr(client, gather), + lpush(client, gather, data_int), + ): + print(ret) + for ret in await asyncio.gather( + get_str(client, gather), + get_int(client, gather), + hget(client, gather), + lrange_300(client, gather), + lpop(client, gather), + ): + print(ret) + else: + print(await set_str(client, gather, data_str)) + print(await set_int(client, gather, data_int)) + print(await hset(client, gather, data_str)) + print(await incr(client, gather)) + print(await lpush(client, gather, data_int)) + + print(await get_str(client, gather)) + print(await get_int(client, gather)) + print(await hget(client, gather)) + print(await lrange_300(client, gather)) + print(await lpop(client, gather)) + + +async def main(loop, gather=None): + arc = aredis.StrictRedisCluster( + host=host, + port=port, + password=password, + max_connections=2 ** 31, + max_connections_per_node=2 ** 31, + readonly=False, + reinitialize_steps=count, + skip_full_coverage_check=True, + decode_responses=False, + max_idle_time=count, + idle_check_interval=count, + ) + print(f"{loop} {gather} {await warmup(arc)} aredis") + print(await run(arc, gather=gather)) + arc.connection_pool.disconnect() + + aiorc = await aioredis_cluster.create_redis_cluster( + [(host, port)], + password=password, + state_reload_interval=count, + idle_connection_timeout=count, + pool_maxsize=2 ** 31, + ) + print(f"{loop} {gather} {await warmup(aiorc)} aioredis-cluster") + print(await run(aiorc, gather=gather)) + aiorc.close() + await aiorc.wait_closed() + + async with redispy.RedisCluster( + host=host, + port=port, + password=password, + reinitialize_steps=count, + read_from_replicas=False, + decode_responses=False, + max_connections=2 ** 31, + ) as rca: + print(f"{loop} {gather} {await warmup(rca)} redispy") + print(await run(rca, gather=gather)) + + +if __name__ == "__main__": + host = "localhost" + port = 16379 + password = None + + count = 1000 + size = 16 + + asyncio.run(main("asyncio")) + asyncio.run(main("asyncio", gather=False)) + asyncio.run(main("asyncio", gather=True)) + + uvloop.install() + + asyncio.run(main("uvloop")) + asyncio.run(main("uvloop", gather=False)) + asyncio.run(main("uvloop", gather=True)) diff --git a/tox.ini b/tox.ini index 4641ec3638..0ceb008cf6 100644 --- a/tox.ini +++ b/tox.ini @@ -115,7 +115,7 @@ volumes = [docker:redismod_cluster] name = redismod_cluster image = redisfab/redis-py-modcluster:6.2.6 -ports = +ports = 46379:46379/tcp 46380:46380/tcp 46381:46381/tcp @@ -337,7 +337,6 @@ skipsdist = true skip_install = true deps = -r {toxinidir}/dev_requirements.txt docker = {[testenv]docker} -commands = /usr/bin/echo docker_up [testenv:linters] deps_files = dev_requirements.txt From 0b6fc92a10b77e3b3160af1723b8f542c41fe2ac Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Thu, 21 Apr 2022 01:06:55 +0530 Subject: [PATCH 12/23] async_cluster: remove logging --- redis/asyncio/cluster.py | 73 ++++++++---------------------- tests/test_asyncio/test_cluster.py | 8 ++-- 2 files changed, 25 insertions(+), 56 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 56acbe3907..d9bceb57ba 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,5 +1,4 @@ import asyncio -import logging import random import socket import warnings @@ -34,6 +33,7 @@ ClusterDownError, ClusterError, ConnectionError, + DataError, MasterDownError, MovedError, RedisClusterException, @@ -45,8 +45,6 @@ from redis.typing import EncodableT, KeyT from redis.utils import dict_merge, str_if_bytes -log = logging.getLogger(__name__) - TargetNodesT = TypeVar( "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) @@ -248,7 +246,6 @@ def __init__( " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," " ClusterNode('localhost', 6378)])" ) - log.debug(f"startup_nodes : {startup_nodes}") # Update the connection arguments # Whenever a new connection is established, RedisCluster's on_connect # method should be run @@ -398,17 +395,16 @@ def get_default_node(self) -> "ClusterNode": """Get the default node of the client.""" return self.nodes_manager.default_node - def set_default_node(self, node: "ClusterNode") -> bool: - """Set the default node of the client.""" + def set_default_node(self, node: "ClusterNode") -> None: + """ + Set the default node of the client. + + :raises DataError: if None is passed or node does not exist in cluster. + """ if node is None or self.get_node(node_name=node.name) is None: - log.info( - "The requested node does not exist in the cluster, so " - "the default node was not changed." - ) - return False + raise DataError("The requested node does not exist in the cluster.") + self.nodes_manager.default_node = node - log.info(f"Changed the default cluster node to {node}") - return True def set_response_callback(self, command: KeyT, callback: Callable) -> None: """Set a custom response callback.""" @@ -423,8 +419,6 @@ async def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the nodes group for this command if it was predefined command_flag = self.command_flags.get(command) - if command_flag: - log.debug(f"Target node/s for {command}: {command_flag}") if command_flag == self.__class__.RANDOM: # return a random node return [self.get_random_node()] @@ -448,7 +442,6 @@ async def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: node = await self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) - log.debug(f"Target for {args}: slot {slot}") return [node] def keyslot(self, key: EncodableT) -> int: @@ -649,10 +642,6 @@ async def _execute_command( ) moved = False - log.debug( - f"Executing command {command} on target node: " - f"{target_node.server_type} {target_node.name}" - ) redis_connection = await target_node.initialize( **self.get_connection_kwargs() ) @@ -679,12 +668,9 @@ async def _execute_command( response, **kwargs ) return response - - except (RedisClusterException, BusyLoadingError) as e: - log.exception(type(e)) + except (RedisClusterException, BusyLoadingError): raise - except (ConnectionError, TimeoutError) as e: - log.exception(type(e)) + except (ConnectionError, TimeoutError): # ConnectionError can also be raised if we couldn't get a # connection from the pool before timing out, so check that # this is an actual connection before attempting to disconnect. @@ -712,7 +698,6 @@ async def _execute_command( # the same client object is shared between multiple threads. To # reduce the frequency you can set this variable in the # RedisCluster constructor. - log.exception("MovedError") self.reinitialize_counter += 1 if ( self.reinitialize_steps @@ -725,32 +710,24 @@ async def _execute_command( self.nodes_manager._moved_exception = e moved = True except TryAgainError: - log.exception("TryAgainError") - if ttl < self.RedisClusterRequestTTL / 2: await asyncio.sleep(0.05) except AskError as e: - log.exception("AskError") - redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except ClusterDownError as e: - log.exception("ClusterDownError") + except ClusterDownError: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command await asyncio.sleep(0.25) await self.close() - raise e - except ResponseError as e: - message = e.__str__() - log.exception(f"ResponseError: {message}") - raise e - except BaseException as e: - log.exception("BaseException") + raise + except ResponseError: + raise + except BaseException: if connection: await connection.disconnect() - raise e + raise finally: if connection is not None: await redis_connection.connection_pool.release(connection) @@ -878,7 +855,7 @@ def get_node( host: Optional[str] = None, port: Optional[int] = None, node_name: Optional[str] = None, - ) -> Optional["ClusterNode"]: + ) -> "ClusterNode": if host and port: # the user passed host and port if host == "localhost": @@ -887,12 +864,11 @@ def get_node( elif node_name: return self.nodes_cache.get(node_name) else: - log.error( + raise DataError( "get_node requires one of the following: " "1. node name " "2. host and port" ) - return None async def set_nodes( self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"] @@ -994,7 +970,6 @@ def check_slots_coverage(self, slots_cache: Dict[int, List["ClusterNode"]]) -> b return True async def initialize(self) -> None: - log.debug("Initializing the nodes' topology of the cluster") self.reset() tmp_nodes_cache = {} tmp_slots = {} @@ -1016,17 +991,9 @@ async def initialize(self) -> None: await redis_connection.execute_command("CLUSTER SLOTS") ) startup_nodes_reachable = True - except (ConnectionError, TimeoutError) as e: - msg = e.__str__ - log.exception( - "An exception occurred while trying to" - " initialize the cluster using the seed node" - f" {startup_node.name}:\n{msg}" - ) + except (ConnectionError, TimeoutError): continue except ResponseError as e: - log.exception('ReseponseError sending "cluster slots" to redis server') - # Isn't a cluster connection, so it won't parse these # exceptions automatically message = e.__str__() diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 39f6abc13a..5827449d83 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -612,7 +612,7 @@ async def test_set_default_node_success(self, r: RedisCluster) -> None: if node != default_node: new_def_node = node break - assert r.set_default_node(new_def_node) is True + r.set_default_node(new_def_node) assert r.get_default_node() == new_def_node async def test_set_default_node_failure(self, r: RedisCluster) -> None: @@ -621,8 +621,10 @@ async def test_set_default_node_failure(self, r: RedisCluster) -> None: """ default_node = r.get_default_node() new_def_node = ClusterNode("1.1.1.1", 1111) - assert r.set_default_node(None) is False - assert r.set_default_node(new_def_node) is False + with pytest.raises(DataError): + r.set_default_node(None) + with pytest.raises(DataError): + r.set_default_node(new_def_node) assert r.get_default_node() == default_node async def test_get_node_from_key(self, r: RedisCluster) -> None: From 337947503d7a67d6a7207ab8d841845c8cc13896 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sat, 23 Apr 2022 03:22:15 +0530 Subject: [PATCH 13/23] async_cluster/commands: optimize parser --- redis/asyncio/cluster.py | 4 +- redis/asyncio/parser.py | 86 +++++++++++++++++----------------------- 2 files changed, 40 insertions(+), 50 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index d9bceb57ba..5a1aa1c39e 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -286,7 +286,9 @@ async def initialize(self) -> "RedisCluster": self._initialize = False try: await self.nodes_manager.initialize() - await self.commands_parser.initialize(self) + await self.commands_parser.initialize( + self.nodes_manager.default_node.redis_connection + ) except BaseException: self._initialize = True await self.nodes_manager.close() diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 9bc6bc53f1..9afd0ccfba 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,19 +1,23 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union from redis.asyncio.client import Redis from redis.exceptions import RedisError, ResponseError -if TYPE_CHECKING: - from redis.asyncio.cluster import RedisCluster - class CommandsParser: """ Parses Redis commands to get command keys. + COMMAND output is used to determine key locations. - Commands that do not have a predefined key location are flagged with - 'movablekeys', and these commands' keys are determined by the command - 'COMMAND GETKEYS'. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. """ __slots__ = ("commands",) @@ -21,7 +25,7 @@ class CommandsParser: def __init__(self) -> None: self.commands = {} - async def initialize(self, r: "RedisCluster") -> None: + async def initialize(self, r: Redis) -> None: commands = await r.execute_command("COMMAND") uppercase_commands = [] for cmd in commands: @@ -29,6 +33,14 @@ async def initialize(self, r: "RedisCluster") -> None: uppercase_commands.append(cmd) for cmd in uppercase_commands: commands[cmd.lower()] = commands.pop(cmd) + + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 self.commands = commands # As soon as this PR is merged into Redis, we should reimplement @@ -37,25 +49,16 @@ async def initialize(self, r: "RedisCluster") -> None: async def get_keys( self, redis_conn: Redis, *args ) -> Optional[Union[List[str], List[bytes]]]: - """ - Get the keys from the passed command. - - NOTE: Due to a bug in redis<7.0, this function does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this function with EVAL or EVALSHA. - """ if len(args) < 2: # The command has no keys in it return None - cmd_name = args[0].lower() - if cmd_name not in self.commands: - # try to split the command name and to take only the main command, + try: + command = self.commands[args[0].lower()] + except KeyError: + # try to split the command name and to take only the main command # e.g. 'memory' for 'memory usage' - cmd_name_split = cmd_name.split() + cmd_name_split = args[0].lower().split() cmd_name = cmd_name_split[0] if cmd_name in self.commands: # save the splitted command to args @@ -69,36 +72,21 @@ async def get_keys( f"{cmd_name.upper()} command doesn't exist in Redis commands" ) - command = self.commands.get(cmd_name) - if "movablekeys" in command["flags"]: - keys = await self._get_moveable_keys(redis_conn, *args) - else: - if ( - command["step_count"] == 0 - and command["first_key_pos"] == 0 - and command["last_key_pos"] == 0 - ): - # The command doesn't have keys in it - return None - last_key_pos = command["last_key_pos"] - if last_key_pos < 0: - last_key_pos = len(args) - abs(last_key_pos) - keys_pos = list( - range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) - ) - keys = [args[pos] for pos in keys_pos] + command = self.commands[cmd_name] - return keys + if command == 1: + return [args[1]] + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(redis_conn, *args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] async def _get_moveable_keys(self, redis_conn: Redis, *args) -> Optional[List[str]]: - """ - NOTE: Due to a bug in redis<7.0, this function does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this function with EVAL or EVALSHA. - """ pieces = [] cmd_name = args[0] # The command name should be splitted into separate arguments, From 9566a7f3e7474cbf9a14a292b556e21d633ee7de Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 22 Apr 2022 06:53:46 +0530 Subject: [PATCH 14/23] async_cluster: use ensure_future & generators for gather --- benchmarks/cluster_async.py | 115 +++++++++++++++++++----------------- redis/asyncio/cluster.py | 55 ++++++++++++----- redis/commands/cluster.py | 24 +++++--- 3 files changed, 117 insertions(+), 77 deletions(-) diff --git a/benchmarks/cluster_async.py b/benchmarks/cluster_async.py index 7211078440..aec3f1c403 100644 --- a/benchmarks/cluster_async.py +++ b/benchmarks/cluster_async.py @@ -24,10 +24,12 @@ async def wrapper(*args, **kwargs): async def set_str(client, gather, data): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.set(f"bench:str_{i}", data)) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.set(f"bench:str_{i}", data)) + for i in range(100) + ) + ) else: for i in range(count): await client.set(f"bench:str_{i}", data) @@ -37,10 +39,12 @@ async def set_str(client, gather, data): async def set_int(client, gather, data): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.set(f"bench:int_{i}", data)) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.set(f"bench:int_{i}", data)) + for i in range(100) + ) + ) else: for i in range(count): await client.set(f"bench:int_{i}", data) @@ -50,10 +54,9 @@ async def set_int(client, gather, data): async def get_str(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.get(f"bench:str_{i}")) - await asyncio.gather(*tasks) + await asyncio.gather( + *(asyncio.create_task(client.get(f"bench:str_{i}")) for i in range(100)) + ) else: for i in range(count): await client.get(f"bench:str_{i}") @@ -63,10 +66,9 @@ async def get_str(client, gather): async def get_int(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.get(f"bench:int_{i}")) - await asyncio.gather(*tasks) + await asyncio.gather( + *(asyncio.create_task(client.get(f"bench:int_{i}")) for i in range(100)) + ) else: for i in range(count): await client.get(f"bench:int_{i}") @@ -76,10 +78,12 @@ async def get_int(client, gather): async def hset(client, gather, data): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.hset("bench:hset", str(i), data)) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.hset("bench:hset", str(i), data)) + for i in range(100) + ) + ) else: for i in range(count): await client.hset("bench:hset", str(i), data) @@ -89,10 +93,12 @@ async def hset(client, gather, data): async def hget(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.hget("bench:hset", str(i))) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.hget("bench:hset", str(i))) + for i in range(100) + ) + ) else: for i in range(count): await client.hget("bench:hset", str(i)) @@ -102,10 +108,9 @@ async def hget(client, gather): async def incr(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.incr("bench:incr")) - await asyncio.gather(*tasks) + await asyncio.gather( + *(asyncio.create_task(client.incr("bench:incr")) for i in range(100)) + ) else: for i in range(count): await client.incr("bench:incr") @@ -115,10 +120,12 @@ async def incr(client, gather): async def lpush(client, gather, data): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.lpush("bench:lpush", data)) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.lpush("bench:lpush", data)) + for i in range(100) + ) + ) else: for i in range(count): await client.lpush("bench:lpush", data) @@ -128,10 +135,12 @@ async def lpush(client, gather, data): async def lrange_300(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.lrange("bench:lpush", i, i + 300)) - await asyncio.gather(*tasks) + await asyncio.gather( + *( + asyncio.create_task(client.lrange("bench:lpush", i, i + 300)) + for i in range(100) + ) + ) else: for i in range(count): await client.lrange("bench:lpush", i, i + 300) @@ -141,10 +150,9 @@ async def lrange_300(client, gather): async def lpop(client, gather): if gather: for _ in range(count // 100): - tasks = [] - for i in range(100): - tasks.append(client.lpop("bench:lpush")) - await asyncio.gather(*tasks) + await asyncio.gather( + *(asyncio.create_task(client.lpop("bench:lpush")) for i in range(100)) + ) else: for i in range(count): await client.lpop("bench:lpush") @@ -152,10 +160,9 @@ async def lpop(client, gather): @timer async def warmup(client): - tasks = [] - for i in range(1000): - tasks.append(client.exists(f"bench:warmup_{i}")) - await asyncio.gather(*tasks) + await asyncio.gather( + *(asyncio.create_task(client.exists(f"bench:warmup_{i}")) for i in range(100)) + ) @timer @@ -165,19 +172,19 @@ async def run(client, gather): if gather is False: for ret in await asyncio.gather( - set_str(client, gather, data_str), - set_int(client, gather, data_int), - hset(client, gather, data_str), - incr(client, gather), - lpush(client, gather, data_int), + asyncio.create_task(set_str(client, gather, data_str)), + asyncio.create_task(set_int(client, gather, data_int)), + asyncio.create_task(hset(client, gather, data_str)), + asyncio.create_task(incr(client, gather)), + asyncio.create_task(lpush(client, gather, data_int)), ): print(ret) for ret in await asyncio.gather( - get_str(client, gather), - get_int(client, gather), - hget(client, gather), - lrange_300(client, gather), - lpop(client, gather), + asyncio.create_task(get_str(client, gather)), + asyncio.create_task(get_int(client, gather)), + asyncio.create_task(hget(client, gather)), + asyncio.create_task(lrange_300(client, gather)), + asyncio.create_task(lpop(client, gather)), ): print(ret) else: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 5a1aa1c39e..c2e5d29296 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -596,15 +596,31 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any f"No targets were found to execute {args} command on" ) - keys = [node.name for node in target_nodes] - values = await asyncio.gather( - *[ - self._execute_command(node, *args, **kwargs) - for node in target_nodes - ] - ) - # Return the processed result - return self._process_result(args[0], dict(zip(keys, values)), **kwargs) + if len(target_nodes) == 1: + # Return the processed result + return self._process_result( + args[0], + { + target_nodes[0].name: await self._execute_command( + target_nodes[0], *args, **kwargs + ) + }, + **kwargs, + ) + else: + keys = [node.name for node in target_nodes] + values = await asyncio.gather( + *( + asyncio.ensure_future( + self._execute_command(node, *args, **kwargs) + ) + for node in target_nodes + ) + ) + # Return the processed result + return self._process_result( + args[0], dict(zip(keys, values)), **kwargs + ) except BaseException as e: if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. @@ -875,12 +891,16 @@ def get_node( async def set_nodes( self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"] ) -> None: - tasks = [node.close() for name, node in old.items() if name not in new] + tasks = [ + asyncio.ensure_future(node.close()) + for name, node in old.items() + if name not in new + ] for name, node in new.items(): if name in old: if old[name] is node: continue - tasks.append(old[name].close()) + tasks.append(asyncio.ensure_future(old[name].close())) old[name] = node await asyncio.gather(*tasks) @@ -1107,10 +1127,10 @@ async def initialize(self) -> None: # Create Redis connections to all nodes await asyncio.gather( - *[ - node.initialize(**self.connection_kwargs) + *( + asyncio.ensure_future(node.initialize(**self.connection_kwargs)) for node in self.nodes_cache.values() - ] + ) ) # Set the default node @@ -1120,7 +1140,12 @@ async def initialize(self) -> None: async def close(self, attr: str = "nodes_cache") -> None: self.default_node = None - await asyncio.gather(*[node.close() for node in getattr(self, attr).values()]) + await asyncio.gather( + *( + asyncio.ensure_future(node.close()) + for node in getattr(self, attr).values() + ) + ) def reset(self) -> None: try: diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index b188540b0e..4249184bdc 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -228,10 +228,12 @@ async def mget_nonatomic(self, keys: KeysT, *args) -> List[Optional[Any]]: # the results # We must make sure that the keys are returned in order all_values = await asyncio.gather( - *[ - self.execute_command("MGET", *slot_keys, **options) + *( + asyncio.ensure_future( + self.execute_command("MGET", *slot_keys, **options) + ) for slot_keys in slots_to_keys.values() - ] + ) ) all_results = {} @@ -266,7 +268,10 @@ async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bo # Call MSET for every slot and concatenate # the results (one result per slot) return await asyncio.gather( - *[self.execute_command("MSET", *pairs) for pairs in slots_to_pairs.values()] + *( + asyncio.ensure_future(self.execute_command("MSET", *pairs)) + for pairs in slots_to_pairs.values() + ) ) async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: @@ -280,10 +285,10 @@ async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: # Sum up the reply from each command return sum( await asyncio.gather( - *[ - self.execute_command(command, *slot_keys) + *( + asyncio.ensure_future(self.execute_command(command, *slot_keys)) for slot_keys in slots_to_keys.values() - ] + ) ) ) @@ -635,7 +640,10 @@ async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: For more information see https://redis.io/commands/cluster-delslots """ return await asyncio.gather( - *[self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + *( + asyncio.ensure_future(self.execute_command("CLUSTER DELSLOTS", slot)) + for slot in slots + ) ) From 8cbf56dede1588e5c989f5ae352e29e95fecfd37 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 22 Apr 2022 07:30:02 +0530 Subject: [PATCH 15/23] async_cluster: inline functions --- redis/asyncio/cluster.py | 113 +++++++++++++++++---------------------- 1 file changed, 49 insertions(+), 64 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index c2e5d29296..661d96f72b 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -184,6 +184,7 @@ class initializer. In the case of conflicting arguments, querystring "cluster_response_callbacks", "command_flags", "commands_parser", + "connection_kwargs", "encoder", "node_flags", "nodes_manager", @@ -205,7 +206,7 @@ def __init__( url: Optional[str] = None, **kwargs, ) -> None: - if startup_nodes is None: + if not startup_nodes: startup_nodes = [] if "db" in kwargs: @@ -216,7 +217,7 @@ def __init__( # Get the startup node/s from_url = False - if url is not None: + if url: from_url = True url_options = parse_url(url) if "path" in url_options: @@ -235,7 +236,7 @@ def __init__( startup_nodes.append(ClusterNode(host, port)) elif host is not None and port is not None: startup_nodes.append(ClusterNode(host, port)) - elif len(startup_nodes) == 0: + elif not startup_nodes: # No startup node was provided raise RedisClusterException( "RedisCluster requires at least one node to discover the " @@ -246,11 +247,12 @@ def __init__( " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," " ClusterNode('localhost', 6378)])" ) + # Update the connection arguments # Whenever a new connection is established, RedisCluster's on_connect # method should be run - kwargs.update({"redis_connect_func": self.on_connect}) - kwargs = cleanup_kwargs(**kwargs) + kwargs["redis_connect_func"] = self.on_connect + self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), @@ -267,7 +269,7 @@ def __init__( startup_nodes=startup_nodes, from_url=from_url, require_full_coverage=require_full_coverage, - **kwargs, + **self.connection_kwargs, ) self.cluster_response_callbacks = CaseInsensitiveDict( @@ -381,7 +383,7 @@ def get_node_from_key( """ slot = self.keyslot(key) slot_cache = self.nodes_manager.slots_cache.get(slot) - if slot_cache is None or len(slot_cache) == 0: + if not slot_cache: raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') if replica and len(self.nodes_manager.slots_cache[slot]) < 2: return None @@ -403,7 +405,7 @@ def set_default_node(self, node: "ClusterNode") -> None: :raises DataError: if None is passed or node does not exist in cluster. """ - if node is None or self.get_node(node_name=node.name) is None: + if not node or not self.get_node(node_name=node.name): raise DataError("The requested node does not exist in the cluster.") self.nodes_manager.default_node = node @@ -477,15 +479,16 @@ async def determine_slot(self, *args) -> int: eval_keys = args[3 : 3 + num_actual_keys] # if there are 0 keys, that means the script can be run on any node # so we can just return a random slot - if len(eval_keys) == 0: + if not eval_keys: return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) keys = eval_keys else: - redis_connection = await self.get_default_node().initialize( - **self.get_connection_kwargs() + node = self.nodes_manager.default_node + redis_connection = node.redis_connection or await node.initialize( + **self.connection_kwargs ) keys = await self.commands_parser.get_keys(redis_connection, *args) - if keys is None or len(keys) == 0: + if not keys: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot if command in ("FCALL", "FCALL_RO"): @@ -518,7 +521,7 @@ def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: Get the kwargs passed to the :class:`~redis.asyncio.client.Redis` object of each node. """ - return self.nodes_manager.connection_kwargs + return self.connection_kwargs def _is_nodes_flag( self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] @@ -565,10 +568,11 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ + command = args[0] target_nodes_specified = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) - if passed_targets is not None and not self._is_nodes_flag(passed_targets): + if passed_targets and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) target_nodes_specified = True # If an error that allows retrying was thrown, the nodes and slots @@ -584,7 +588,8 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any ) exception = None for _ in range(0, retry_attempts): - await self.initialize() + if self._initialize: + await self.initialize() try: if not target_nodes_specified: # Determine the nodes to execute the command on @@ -598,15 +603,12 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any if len(target_nodes) == 1: # Return the processed result - return self._process_result( - args[0], - { - target_nodes[0].name: await self._execute_command( - target_nodes[0], *args, **kwargs - ) - }, - **kwargs, - ) + ret = await self._execute_command(target_nodes[0], *args, **kwargs) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, {target_nodes[0].name: ret}, **kwargs + ) + return ret else: keys = [node.name for node in target_nodes] values = await asyncio.gather( @@ -617,10 +619,11 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any for node in target_nodes ) ) - # Return the processed result - return self._process_result( - args[0], dict(zip(keys, values)), **kwargs - ) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, dict(zip(keys, values)), **kwargs + ) + return dict(zip(keys, values)) except BaseException as e: if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. @@ -660,8 +663,9 @@ async def _execute_command( ) moved = False - redis_connection = await target_node.initialize( - **self.get_connection_kwargs() + redis_connection = ( + target_node.redis_connection + or await target_node.initialize(**self.connection_kwargs) ) connection = ( redis_connection.connection @@ -686,7 +690,7 @@ async def _execute_command( response, **kwargs ) return response - except (RedisClusterException, BusyLoadingError): + except BusyLoadingError: raise except (ConnectionError, TimeoutError): # ConnectionError can also be raised if we couldn't get a @@ -740,8 +744,6 @@ async def _execute_command( await asyncio.sleep(0.25) await self.close() raise - except ResponseError: - raise except BaseException: if connection: await connection.disconnect() @@ -752,16 +754,6 @@ async def _execute_command( raise ClusterError("TTL exhausted.") - def _process_result(self, command: KeyT, res: Dict[str, Any], **kwargs) -> Any: - if command in self.result_callbacks: - return self.result_callbacks[command](command, res, **kwargs) - elif len(res) == 1: - # When we execute the command on a single node, we can - # remove the dictionary and return a single response - return list(res.values())[0] - else: - return res - class ClusterNode: """ @@ -817,8 +809,11 @@ async def initialize(self, from_url: bool = False, **kwargs) -> Redis: if not self.redis_connection: if from_url: # Create a redis node with a costumed connection pool - kwargs.update(host=self.host, port=self.port) - conn = Redis(connection_pool=ConnectionPool(**kwargs)) + conn = Redis( + connection_pool=ConnectionPool( + host=self.host, port=self.port, **kwargs + ) + ) else: conn = Redis(host=self.host, port=self.port, **kwargs) @@ -843,7 +838,6 @@ class NodesManager: "_require_full_coverage", "connection_kwargs", "default_node", - "from_url", "nodes_cache", "read_load_balancer", "slots_cache", @@ -853,7 +847,6 @@ class NodesManager: def __init__( self, startup_nodes: List["ClusterNode"], - from_url: bool = False, require_full_coverage: bool = False, **kwargs, ) -> None: @@ -861,7 +854,6 @@ def __init__( self.slots_cache = {} self.startup_nodes = {node.name: node for node in startup_nodes} self.default_node = None - self.from_url = from_url self._require_full_coverage = require_full_coverage self._moved_exception = None self.connection_kwargs = kwargs @@ -907,9 +899,9 @@ async def set_nodes( async def _update_moved_slots(self) -> None: e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) - if redirected_node is not None: + if redirected_node: # The node already exists - if redirected_node.server_type is not PRIMARY: + if redirected_node.server_type != PRIMARY: # Update the node's server type redirected_node.server_type = PRIMARY else: @@ -950,20 +942,20 @@ async def get_node_from_slot( if self._moved_exception: await self._update_moved_slots() - if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: + if not self.slots_cache.get(slot): raise SlotNotCoveredError( f'Slot "{slot}" not covered by the cluster. ' f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas is True: + if read_from_replicas: # get the server index in a Round-Robin manner primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( primary_name, len(self.slots_cache[slot]) ) elif ( - server_type is None + not server_type or server_type == PRIMARY or len(self.slots_cache[slot]) == 1 ): @@ -992,7 +984,7 @@ def check_slots_coverage(self, slots_cache: Dict[int, List["ClusterNode"]]) -> b return True async def initialize(self) -> None: - self.reset() + self.read_load_balancer.reset() tmp_nodes_cache = {} tmp_slots = {} disagreements = [] @@ -1042,7 +1034,7 @@ async def initialize(self) -> None: # Fix it to the host in startup_nodes if ( len(cluster_slots) == 1 - and len(cluster_slots[0][2][0]) == 0 + and not cluster_slots[0][2][0] and len(self.startup_nodes) == 1 ): cluster_slots[0][2][0] = startup_node.host @@ -1057,7 +1049,7 @@ async def initialize(self) -> None: port = int(primary_node[1]) target_node = tmp_nodes_cache.get(get_node_name(host, port)) - if target_node is None: + if not target_node: target_node = ClusterNode(host, port, PRIMARY) # add this node to the nodes cache tmp_nodes_cache[target_node.name] = target_node @@ -1075,7 +1067,7 @@ async def initialize(self) -> None: target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) ) - if target_replica_node is None: + if not target_replica_node: target_replica_node = ClusterNode(host, port, REPLICA) tmp_slots[i].append(target_replica_node) # add this node to the nodes cache @@ -1146,10 +1138,3 @@ async def close(self, attr: str = "nodes_cache") -> None: for node in getattr(self, attr).values() ) ) - - def reset(self) -> None: - try: - self.read_load_balancer.reset() - except TypeError: - # The read_load_balancer is None, do nothing - pass From c4abee1870dc6259d37304c1cee3ff8b551a8ecd Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 22 Apr 2022 07:00:50 +0530 Subject: [PATCH 16/23] async_cluster: optimize determine_slot --- redis/asyncio/cluster.py | 78 ++++++++++++++---------------- tests/test_asyncio/test_cluster.py | 14 +++--- 2 files changed, 42 insertions(+), 50 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 661d96f72b..5dcecd5cfe 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,6 +1,7 @@ import asyncio import random import socket +import threading import warnings from typing import Any, Callable, Dict, List, Optional, TypeVar, Union @@ -443,7 +444,7 @@ async def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the node that holds the key's slot slot = await self.determine_slot(*args) - node = await self.nodes_manager.get_node_from_slot( + node = self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) return [node] @@ -658,7 +659,7 @@ async def _execute_command( # MOVED occurred and the slots cache was updated, # refresh the target node slot = await self.determine_slot(*args) - target_node = await self.nodes_manager.get_node_from_slot( + target_node = self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) moved = False @@ -858,7 +859,7 @@ def __init__( self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() - self._lock = asyncio.Lock() + self._lock = threading.Lock() def get_node( self, @@ -880,23 +881,27 @@ def get_node( "2. host and port" ) - async def set_nodes( - self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"] + def set_nodes( + self, + old: Dict[str, "ClusterNode"], + new: Dict[str, "ClusterNode"], + remove_old=False, ) -> None: - tasks = [ - asyncio.ensure_future(node.close()) - for name, node in old.items() - if name not in new - ] + tasks = [] + if remove_old: + tasks = [ + asyncio.ensure_future(node.close()) + for name, node in old.items() + if name not in new + ] for name, node in new.items(): if name in old: if old[name] is node: continue tasks.append(asyncio.ensure_future(old[name].close())) old[name] = node - await asyncio.gather(*tasks) - async def _update_moved_slots(self) -> None: + def _update_moved_slots(self) -> None: e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) if redirected_node: @@ -907,9 +912,7 @@ async def _update_moved_slots(self) -> None: else: # This is a new node, we will add it to the nodes cache redirected_node = ClusterNode(e.host, e.port, PRIMARY) - await self.set_nodes( - self.nodes_cache, {redirected_node.name: redirected_node} - ) + self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) if redirected_node in self.slots_cache[e.slot_id]: # The MOVED error resulted from a failover, and the new slot owner # had previously been a replica. @@ -934,40 +937,29 @@ async def _update_moved_slots(self) -> None: # Reset moved_exception self._moved_exception = None - async def get_node_from_slot( - self, slot: int, read_from_replicas: bool = False, server_type: None = None + def get_node_from_slot( + self, slot: int, read_from_replicas: bool = False ) -> "ClusterNode": if self._moved_exception: - async with self._lock: + with self._lock: if self._moved_exception: - await self._update_moved_slots() - - if not self.slots_cache.get(slot): + self._update_moved_slots() + + try: + if read_from_replicas: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + return self.slots_cache[slot][node_idx] + return self.slots_cache[slot][0] + except (IndexError, TypeError): raise SlotNotCoveredError( f'Slot "{slot}" not covered by the cluster. ' f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas: - # get the server index in a Round-Robin manner - primary_name = self.slots_cache[slot][0].name - node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) - ) - elif ( - not server_type - or server_type == PRIMARY - or len(self.slots_cache[slot]) == 1 - ): - # return a primary - node_idx = 0 - else: - # return a replica - # randomly choose one of the replicas - node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) - - return self.slots_cache[slot][node_idx] - def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: return [ node @@ -1113,9 +1105,9 @@ async def initialize(self) -> None: # Set the tmp variables to the real variables self.slots_cache = tmp_slots - await self.set_nodes(self.nodes_cache, tmp_nodes_cache) + self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) # Populate the startup nodes with all discovered nodes - await self.set_nodes(self.startup_nodes, self.nodes_cache) + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) # Create Redis connections to all nodes await asyncio.gather( diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 5827449d83..40a16fa3e1 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -172,7 +172,7 @@ async def moved_redirection_helper( slot = 12182 redirect_node = None # Get the current primary that holds this slot - prev_primary = await rc.nodes_manager.get_node_from_slot(slot) + prev_primary = rc.nodes_manager.get_node_from_slot(slot) if failover: if len(rc.nodes_manager.slots_cache[slot]) < 2: warnings.warn("Skipping this test since it requires to have a " "replica") @@ -793,7 +793,7 @@ async def test_cluster_addslotsrange(self, r): @skip_if_redis_enterprise() async def test_cluster_countkeysinslot(self, r: RedisCluster) -> None: - node = await r.nodes_manager.get_node_from_slot(1) + node = r.nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) assert await r.cluster_countkeysinslot(1) == 2 @@ -955,7 +955,7 @@ async def test_cluster_save_config(self, r: RedisCluster) -> None: @skip_if_redis_enterprise() async def test_cluster_get_keys_in_slot(self, r: RedisCluster) -> None: response = [b"{foo}1", b"{foo}2"] - node = await r.nodes_manager.get_node_from_slot(12182) + node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = await r.cluster_get_keys_in_slot(12182, 4) assert keys == response @@ -981,7 +981,7 @@ async def test_cluster_setslot(self, r: RedisCluster) -> None: await r.cluster_failover(node, "STATE") async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: - node = await r.nodes_manager.get_node_from_slot(12182) + node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @@ -1056,7 +1056,7 @@ async def test_info(self, r: RedisCluster) -> None: await r.set("z{1}", 3) # Get node that handles the slot slot = r.keyslot("x{1}") - node = await r.nodes_manager.get_node_from_slot(slot) + node = r.nodes_manager.get_node_from_slot(slot) # Run info on that node info = await r.info(target_nodes=node) assert isinstance(info, dict) @@ -1120,7 +1120,7 @@ async def test_slowlog_get_limit( async def test_slowlog_length(self, r: RedisCluster, slowlog: None) -> None: await r.get("foo") - node = await r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = await r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) @@ -1146,7 +1146,7 @@ async def test_memory_stats(self, r: RedisCluster) -> None: # put a key into the current db to make sure that "db." # has data await r.set("foo", "bar") - node = await r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) stats = await r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): From 8dad0a2b1322b95a65d8d9534d214d6cef8d70df Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Fri, 22 Apr 2022 17:31:15 +0530 Subject: [PATCH 17/23] async_cluster: optimize determine_nodes --- redis/asyncio/cluster.py | 119 +++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 68 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 5dcecd5cfe..7e991b90c4 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -415,39 +415,13 @@ def set_response_callback(self, command: KeyT, callback: Callable) -> None: """Set a custom response callback.""" self.cluster_response_callbacks[command] = callback - async def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: - command = args[0] - nodes_flag = kwargs.pop("nodes_flag", None) - if nodes_flag is not None: - # nodes flag passed by the user - command_flag = nodes_flag - else: - # get the nodes group for this command if it was predefined - command_flag = self.command_flags.get(command) - if command_flag == self.__class__.RANDOM: - # return a random node - return [self.get_random_node()] - elif command_flag == self.__class__.PRIMARIES: - # return all primaries - return self.get_primaries() - elif command_flag == self.__class__.REPLICAS: - # return all replicas - return self.get_replicas() - elif command_flag == self.__class__.ALL_NODES: - # return all nodes - return self.get_nodes() - elif command_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - elif command in self.__class__.SEARCH_COMMANDS[0]: - return [self.nodes_manager.default_node] - else: - # get the node that holds the key's slot - slot = await self.determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS - ) - return [node] + def get_encoder(self) -> Encoder: + """Get the encoder object of the client.""" + return self.encoder + + def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: + """Get the connection kwargs passed to :class:`~redis.asyncio.client.Redis`.""" + return self.connection_kwargs def keyslot(self, key: EncodableT) -> int: """ @@ -458,7 +432,39 @@ def keyslot(self, key: EncodableT) -> int: k = self.encoder.encode(key) return key_slot(k) - async def determine_slot(self, *args) -> int: + async def _determine_nodes( + self, *args, node_flag: Optional[str] = None + ) -> List["ClusterNode"]: + command = args[0] + if not node_flag: + # get the nodes group for this command if it was predefined + node_flag = self.command_flags.get(command) + + if node_flag in self.node_flags: + if node_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + if node_flag == self.__class__.PRIMARIES: + # return all primaries + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + if node_flag == self.__class__.REPLICAS: + # return all replicas + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + if node_flag == self.__class__.ALL_NODES: + # return all nodes + return list(self.nodes_manager.nodes_cache.values()) + if node_flag == self.__class__.RANDOM: + # return a random node + return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + + # get the node that holds the key's slot + slot = await self._determine_slot(*args) + node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) + return [node] + + async def _determine_slot(self, *args) -> int: command = args[0] if self.command_flags.get(command) == SLOT_ID: # The command contains the slot ID @@ -514,17 +520,7 @@ async def determine_slot(self, *args) -> int: return slots.pop() - def get_encoder(self) -> Encoder: - return self.encoder - - def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: - """ - Get the kwargs passed to the :class:`~redis.asyncio.client.Redis` object of - each node. - """ - return self.connection_kwargs - - def _is_nodes_flag( + def _is_node_flag( self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] ) -> bool: return isinstance(target_nodes, str) and target_nodes in self.node_flags @@ -570,24 +566,15 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any can't be mapped to a slot """ command = args[0] - target_nodes_specified = False - target_nodes = None + target_nodes_specified = target_nodes = exception = None + retry_attempts = self.cluster_error_retry_attempts + passed_targets = kwargs.pop("target_nodes", None) - if passed_targets and not self._is_nodes_flag(passed_targets): + if passed_targets and not self._is_node_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) target_nodes_specified = True - # If an error that allows retrying was thrown, the nodes and slots - # cache were reinitialized. We will retry executing the command with - # the updated cluster setup only when the target nodes can be - # determined again with the new cache tables. Therefore, when target - # nodes were passed to this function, we cannot retry the command - # execution since the nodes may not be valid anymore after the tables - # were reinitialized. So in case of passed target nodes, - # retry_attempts will be set to 1. - retry_attempts = ( - 1 if target_nodes_specified else self.cluster_error_retry_attempts - ) - exception = None + retry_attempts = 1 + for _ in range(0, retry_attempts): if self._initialize: await self.initialize() @@ -595,7 +582,7 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = await self._determine_nodes( - *args, **kwargs, nodes_flag=passed_targets + *args, node_flag=passed_targets ) if not target_nodes: raise RedisClusterException( @@ -642,12 +629,8 @@ async def _execute_command( self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs ) -> Any: command = args[0] - redis_connection = None - connection = None - redirect_addr = None - asking = False - moved = False - ttl = int(self.RedisClusterRequestTTL) + redis_connection = connection = redirect_addr = asking = moved = None + ttl = self.RedisClusterRequestTTL connection_error_retry_counter = 0 while ttl > 0: @@ -658,7 +641,7 @@ async def _execute_command( elif moved: # MOVED occurred and the slots cache was updated, # refresh the target node - slot = await self.determine_slot(*args) + slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) From cf1b3e3014046a999d7717aa1bbe2d54d24e427a Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sat, 23 Apr 2022 02:02:38 +0530 Subject: [PATCH 18/23] async_cluster: manage Connection instead of Redis Client --- redis/asyncio/cluster.py | 292 +++++++++++++++-------------- redis/asyncio/parser.py | 27 ++- redis/client.py | 17 +- redis/cluster.py | 23 +-- tests/test_asyncio/test_cluster.py | 193 +++++++++---------- tests/test_cluster.py | 6 +- 6 files changed, 262 insertions(+), 296 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7e991b90c4..3b66d675d5 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1,20 +1,14 @@ import asyncio +import collections import random import socket import threading import warnings -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union - -from redis.asyncio.client import Redis -from redis.asyncio.connection import ( - Connection, - ConnectionPool, - DefaultParser, - Encoder, - parse_url, -) +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union + +from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url from redis.asyncio.parser import CommandsParser -from redis.client import CaseInsensitiveDict from redis.cluster import ( PRIMARY, READ_COMMANDS, @@ -24,6 +18,7 @@ LoadBalancer, cleanup_kwargs, get_node_name, + parse_cluster_slots, ) from redis.commands import AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot @@ -65,7 +60,7 @@ class ClusterParser(DefaultParser): ) -class RedisCluster(AbstractRedisCluster, AsyncRedisClusterCommands): +class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ Create a new RedisCluster client. @@ -122,7 +117,7 @@ class RedisCluster(AbstractRedisCluster, AsyncRedisClusterCommands): | See :meth:`.from_url` :param kwargs: | Extra arguments that will be passed to the - :class:`~redis.asyncio.client.Redis` instance when created + :class:`~redis.asyncio.connection.Connection` instance when created :raises RedisClusterException: if any arguments are invalid. Eg: @@ -182,7 +177,6 @@ class initializer. In the case of conflicting arguments, querystring "_initialize", "_lock", "cluster_error_retry_attempts", - "cluster_response_callbacks", "command_flags", "commands_parser", "connection_kwargs", @@ -192,6 +186,7 @@ class initializer. In the case of conflicting arguments, querystring "read_from_replicas", "reinitialize_counter", "reinitialize_steps", + "response_callbacks", "result_callbacks", ) @@ -200,10 +195,10 @@ def __init__( host: Optional[str] = None, port: int = 6379, startup_nodes: Optional[List["ClusterNode"]] = None, - cluster_error_retry_attempts: int = 3, require_full_coverage: bool = False, - reinitialize_steps: int = 10, read_from_replicas: bool = False, + cluster_error_retry_attempts: int = 3, + reinitialize_steps: int = 10, url: Optional[str] = None, **kwargs, ) -> None: @@ -217,9 +212,7 @@ def __init__( ) # Get the startup node/s - from_url = False if url: - from_url = True url_options = parse_url(url) if "path" in url_options: raise RedisClusterException( @@ -234,10 +227,7 @@ def __init__( kwargs.update(url_options) host = kwargs.get("host") port = kwargs.get("port", port) - startup_nodes.append(ClusterNode(host, port)) - elif host is not None and port is not None: - startup_nodes.append(ClusterNode(host, port)) - elif not startup_nodes: + elif (not host or not port) and not startup_nodes: # No startup node was provided raise RedisClusterException( "RedisCluster requires at least one node to discover the " @@ -254,6 +244,11 @@ def __init__( # method should be run kwargs["redis_connect_func"] = self.on_connect self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) + self.response_callbacks = kwargs[ + "response_callbacks" + ] = self.__class__.RESPONSE_CALLBACKS + if host and port: + startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), @@ -268,15 +263,16 @@ def __init__( self.reinitialize_steps = reinitialize_steps self.nodes_manager = NodesManager( startup_nodes=startup_nodes, - from_url=from_url, require_full_coverage=require_full_coverage, **self.connection_kwargs, ) - self.cluster_response_callbacks = CaseInsensitiveDict( - self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS + self.result_callbacks = self.__class__.RESULT_CALLBACKS + self.result_callbacks[ + "CLUSTER SLOTS" + ] = lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs ) - self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) self.commands_parser = CommandsParser() self._initialize = True self._lock = asyncio.Lock() @@ -290,7 +286,7 @@ async def initialize(self) -> "RedisCluster": try: await self.nodes_manager.initialize() await self.commands_parser.initialize( - self.nodes_manager.default_node.redis_connection + self.nodes_manager.default_node ) except BaseException: self._initialize = True @@ -413,14 +409,14 @@ def set_default_node(self, node: "ClusterNode") -> None: def set_response_callback(self, command: KeyT, callback: Callable) -> None: """Set a custom response callback.""" - self.cluster_response_callbacks[command] = callback + self.response_callbacks[command] = callback def get_encoder(self) -> Encoder: """Get the encoder object of the client.""" return self.encoder def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: - """Get the connection kwargs passed to :class:`~redis.asyncio.client.Redis`.""" + """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" return self.connection_kwargs def keyslot(self, key: EncodableT) -> int: @@ -490,11 +486,9 @@ async def _determine_slot(self, *args) -> int: return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) keys = eval_keys else: - node = self.nodes_manager.default_node - redis_connection = node.redis_connection or await node.initialize( - **self.connection_kwargs + keys = await self.commands_parser.get_keys( + self.nodes_manager.default_node, *args ) - keys = await self.commands_parser.get_keys(redis_connection, *args) if not keys: # FCALL can call a function with 0 keys, that means the function # can be run on any node so we can just return a random slot @@ -628,8 +622,7 @@ async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any async def _execute_command( self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs ) -> Any: - command = args[0] - redis_connection = connection = redirect_addr = asking = moved = None + redirect_addr = asking = moved = None ttl = self.RedisClusterRequestTTL connection_error_retry_counter = 0 @@ -638,50 +631,21 @@ async def _execute_command( try: if asking: target_node = self.get_node(node_name=redirect_addr) + await target_node.execute_command("ASKING") + asking = False elif moved: # MOVED occurred and the slots cache was updated, # refresh the target node slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, self.read_from_replicas and args[0] in READ_COMMANDS ) moved = False - redis_connection = ( - target_node.redis_connection - or await target_node.initialize(**self.connection_kwargs) - ) - connection = ( - redis_connection.connection - or await redis_connection.connection_pool.get_connection( - command, **kwargs - ) - ) - - if asking: - await connection.send_command("ASKING") - await redis_connection.parse_response( - connection, "ASKING", **kwargs - ) - asking = False - - await connection.send_command(*args) - response = await redis_connection.parse_response( - connection, command, **kwargs - ) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - return response + return await target_node.execute_command(*args, **kwargs) except BusyLoadingError: raise except (ConnectionError, TimeoutError): - # ConnectionError can also be raised if we couldn't get a - # connection from the pool before timing out, so check that - # this is an actual connection before attempting to disconnect. - if connection is not None: - await connection.disconnect() connection_error_retry_counter += 1 # Give the node 0.25 seconds to get back up and retry again @@ -728,45 +692,63 @@ async def _execute_command( await asyncio.sleep(0.25) await self.close() raise - except BaseException: - if connection: - await connection.disconnect() - raise - finally: - if connection is not None: - await redis_connection.connection_pool.release(connection) raise ClusterError("TTL exhausted.") class ClusterNode: """ - Create a ClusterNode. + Create a new ClusterNode. - Each ClusterNode manages a :class:`~redis.asyncio.client.Redis` object corresponding - to the (host, port). + Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection` + objects for the (host, port). """ - __slots__ = ("_lock", "host", "name", "port", "redis_connection", "server_type") + __slots__ = ( + "_connections", + "_free", + "connection_class", + "connection_kwargs", + "host", + "max_connections", + "name", + "port", + "response_callbacks", + "server_type", + ) - def __init__(self, host: str, port: int, server_type: Optional[str] = None) -> None: + def __init__( + self, + host: str, + port: int, + server_type: Optional[str] = None, + max_connections: int = 2 ** 31, + connection_class: Type[Connection] = Connection, + response_callbacks: Dict = None, + **connection_kwargs, + ) -> None: if host == "localhost": host = socket.gethostbyname(host) + connection_kwargs["host"] = host + connection_kwargs["port"] = port self.host = host self.port = port self.name = get_node_name(host, port) self.server_type = server_type - self.redis_connection = None - self._lock = asyncio.Lock() + + self.max_connections = max_connections + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.response_callbacks = response_callbacks + + self._connections = [] + self._free = collections.deque(maxlen=self.max_connections) def __repr__(self) -> str: return ( - f"[host={self.host}," - f"port={self.port}," - f"name={self.name}," - f"server_type={self.server_type}," - f"redis_connection={self.redis_connection}]" + f"[host={self.host}, port={self.port}, " + f"name={self.name}, server_type={self.server_type}]" ) def __eq__(self, obj: "ClusterNode") -> bool: @@ -775,44 +757,70 @@ def __eq__(self, obj: "ClusterNode") -> bool: _DEL_MESSAGE = "Unclosed ClusterNode object" def __del__(self, _warnings=warnings): - if hasattr(self, "redis_connection") and self.redis_connection: - _warnings.warn( - f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self - ) - try: - context = {"client": self, "message": self._DEL_MESSAGE} - # TODO: Change to get_running_loop() when dropping support for py3.6 - asyncio.get_event_loop().call_exception_handler(context) - except RuntimeError: - ... - - async def initialize(self, from_url: bool = False, **kwargs) -> Redis: - """Create a redis object & make connections.""" - if not self.redis_connection: - async with self._lock: - if not self.redis_connection: - if from_url: - # Create a redis node with a costumed connection pool - conn = Redis( - connection_pool=ConnectionPool( - host=self.host, port=self.port, **kwargs - ) - ) - else: - conn = Redis(host=self.host, port=self.port, **kwargs) - - self.redis_connection = await conn.initialize() - - return self.redis_connection + for connection in self._connections: + if connection.is_connected: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + break - async def close(self) -> None: - """Close all redis client connections & object.""" - if self.redis_connection: - async with self._lock: - if self.redis_connection: - conn = self.redis_connection - self.redis_connection = None - await conn.close(True) + async def disconnect(self) -> None: + ret = await asyncio.gather( + *( + asyncio.ensure_future(connection.disconnect()) + for connection in self._connections + ), + return_exceptions=True, + ) + exc = next((res for res in ret if isinstance(res, Exception)), None) + if exc: + raise exc + + async def execute_command(self, *args, **kwargs) -> Any: + # Acquire connection + connection = None + if self._free: + for _ in range(len(self._free)): + if self._free[0].is_connected: + connection = self._free.popleft() + break + self._free.rotate(-1) + else: + connection = self._free.popleft() + else: + if len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + else: + raise ConnectionError("Too many connections") + + # Execute command + command = connection.pack_command(*args) + await connection.send_packed_command(command, False) + try: + if NEVER_DECODE in kwargs: + response = await connection.read_response(disable_decoding=True) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in kwargs: + return kwargs[EMPTY_RESPONSE] + raise + finally: + # Release connection + self._free.append(connection) + + # Return response + try: + return self.response_callbacks[args[0]](response, **kwargs) + except KeyError: + return response class NodesManager: @@ -873,7 +881,7 @@ def set_nodes( tasks = [] if remove_old: tasks = [ - asyncio.ensure_future(node.close()) + asyncio.ensure_future(node.disconnect()) for name, node in old.items() if name not in new ] @@ -881,7 +889,7 @@ def set_nodes( if name in old: if old[name] is node: continue - tasks.append(asyncio.ensure_future(old[name].close())) + tasks.append(asyncio.ensure_future(old[name].disconnect())) old[name] = node def _update_moved_slots(self) -> None: @@ -894,7 +902,9 @@ def _update_moved_slots(self) -> None: redirected_node.server_type = PRIMARY else: # This is a new node, we will add it to the nodes cache - redirected_node = ClusterNode(e.host, e.port, PRIMARY) + redirected_node = ClusterNode( + e.host, e.port, PRIMARY, **self.connection_kwargs + ) self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) if redirected_node in self.slots_cache[e.slot_id]: # The MOVED error resulted from a failover, and the new slot owner @@ -967,17 +977,15 @@ async def initialize(self) -> None: fully_covered = False for startup_node in self.startup_nodes.values(): try: - redis_connection = await startup_node.initialize( - **self.connection_kwargs - ) - # Make sure cluster mode is enabled on this node - if not (await redis_connection.info()).get("cluster_enabled"): + if not (await startup_node.execute_command("INFO")).get( + "cluster_enabled" + ): raise RedisClusterException( "Cluster mode is not enabled on this node" ) cluster_slots = str_if_bytes( - await redis_connection.execute_command("CLUSTER SLOTS") + await startup_node.execute_command("CLUSTER SLOTS") ) startup_nodes_reachable = True except (ConnectionError, TimeoutError): @@ -1025,7 +1033,9 @@ async def initialize(self) -> None: target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: - target_node = ClusterNode(host, port, PRIMARY) + target_node = ClusterNode( + host, port, PRIMARY, **self.connection_kwargs + ) # add this node to the nodes cache tmp_nodes_cache[target_node.name] = target_node @@ -1043,7 +1053,9 @@ async def initialize(self) -> None: get_node_name(host, port) ) if not target_replica_node: - target_replica_node = ClusterNode(host, port, REPLICA) + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) tmp_slots[i].append(target_replica_node) # add this node to the nodes cache tmp_nodes_cache[ @@ -1092,14 +1104,6 @@ async def initialize(self) -> None: # Populate the startup nodes with all discovered nodes self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) - # Create Redis connections to all nodes - await asyncio.gather( - *( - asyncio.ensure_future(node.initialize(**self.connection_kwargs)) - for node in self.nodes_cache.values() - ) - ) - # Set the default node self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] # If initialize was called after a MovedError, clear it @@ -1109,7 +1113,7 @@ async def close(self, attr: str = "nodes_cache") -> None: self.default_node = None await asyncio.gather( *( - asyncio.ensure_future(node.close()) + asyncio.ensure_future(node.disconnect()) for node in getattr(self, attr).values() ) ) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 9afd0ccfba..57c20e81fc 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,8 +1,10 @@ -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union -from redis.asyncio.client import Redis from redis.exceptions import RedisError, ResponseError +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + class CommandsParser: """ @@ -25,15 +27,8 @@ class CommandsParser: def __init__(self) -> None: self.commands = {} - async def initialize(self, r: Redis) -> None: + async def initialize(self, r: "ClusterNode") -> None: commands = await r.execute_command("COMMAND") - uppercase_commands = [] - for cmd in commands: - if any(x.isupper() for x in cmd): - uppercase_commands.append(cmd) - for cmd in uppercase_commands: - commands[cmd.lower()] = commands.pop(cmd) - for cmd, command in commands.items(): if "movablekeys" in command["flags"]: commands[cmd] = -1 @@ -41,24 +36,24 @@ async def initialize(self, r: Redis) -> None: commands[cmd] = 0 elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: commands[cmd] = 1 - self.commands = commands + self.commands = {cmd.upper(): command for cmd, command in commands.items()} # As soon as this PR is merged into Redis, we should reimplement # our logic to use COMMAND INFO changes to determine the key positions # https://github.com/redis/redis/pull/8324 async def get_keys( - self, redis_conn: Redis, *args + self, redis_conn: "ClusterNode", *args ) -> Optional[Union[List[str], List[bytes]]]: if len(args) < 2: # The command has no keys in it return None try: - command = self.commands[args[0].lower()] + command = self.commands[args[0]] except KeyError: # try to split the command name and to take only the main command # e.g. 'memory' for 'memory usage' - cmd_name_split = args[0].lower().split() + cmd_name_split = args[0].split() cmd_name = cmd_name_split[0] if cmd_name in self.commands: # save the splitted command to args @@ -86,7 +81,9 @@ async def get_keys( last_key_pos = len(args) + last_key_pos return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] - async def _get_moveable_keys(self, redis_conn: Redis, *args) -> Optional[List[str]]: + async def _get_moveable_keys( + self, redis_conn: "ClusterNode", *args + ) -> Optional[List[str]]: pieces = [] cmd_name = args[0] # The command name should be splitted into separate arguments, diff --git a/redis/client.py b/redis/client.py index d8d7a75ce0..1ba056c86f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -410,11 +410,7 @@ def parse_slowlog_get(response, **options): space = " " if options.get("decode_responses", False) else b" " def parse_item(item): - result = { - "id": item[0], - "start_time": int(item[1]), - "duration": int(item[2]), - } + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} # Redis Enterprise injects another entry at index [3], which has # the complexity info (i.e. the value N in case the command has # an O(N) complexity) instead of the command. @@ -690,7 +686,7 @@ class AbstractRedis: **string_keys_to_dict("SORT", sort_return_tuples), **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", bool_ok, ), @@ -740,17 +736,18 @@ class AbstractRedis: "CLUSTER DELSLOTSRANGE": bool_ok, "CLUSTER FAILOVER": bool_ok, "CLUSTER FORGET": bool_ok, + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), "CLUSTER INFO": parse_cluster_info, "CLUSTER KEYSLOT": lambda x: int(x), "CLUSTER MEET": bool_ok, "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, "CLUSTER REPLICATE": bool_ok, "CLUSTER RESET": bool_ok, "CLUSTER SAVECONFIG": bool_ok, "CLUSTER SET-CONFIG-EPOCH": bool_ok, "CLUSTER SETSLOT": bool_ok, "CLUSTER SLAVES": parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, "COMMAND": parse_command, "COMMAND COUNT": int, "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), @@ -1023,11 +1020,7 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.response_callbacks[command] = callback - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. diff --git a/redis/cluster.py b/redis/cluster.py index b61d384791..538db28500 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -354,28 +354,7 @@ class AbstractRedisCluster: ], ) - CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { - "CLUSTER ADDSLOTS": bool, - "CLUSTER ADDSLOTSRANGE": bool, - "CLUSTER COUNT-FAILURE-REPORTS": int, - "CLUSTER COUNTKEYSINSLOT": int, - "CLUSTER DELSLOTS": bool, - "CLUSTER DELSLOTSRANGE": bool, - "CLUSTER FAILOVER": bool, - "CLUSTER FORGET": bool, - "CLUSTER GETKEYSINSLOT": list, - "CLUSTER KEYSLOT": int, - "CLUSTER MEET": bool, - "CLUSTER REPLICATE": bool, - "CLUSTER RESET": bool, - "CLUSTER SAVECONFIG": bool, - "CLUSTER SET-CONFIG-EPOCH": bool, - "CLUSTER SETSLOT": bool, - "CLUSTER SLOTS": parse_cluster_slots, - "ASKING": bool, - "READONLY": bool, - "READWRITE": bool, - } + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = {"CLUSTER SLOTS": parse_cluster_slots} RESULT_CALLBACKS = dict_merge( list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 40a16fa3e1..fb301c47eb 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -17,7 +17,7 @@ from _pytest.fixtures import FixtureRequest, SubRequest -from redis.asyncio import Connection, Redis, RedisCluster +from redis.asyncio import Connection, RedisCluster from redis.asyncio.cluster import ( PRIMARY, REDIS_CLUSTER_HASH_SLOTS, @@ -88,7 +88,7 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: cluster_slots = kwargs.pop("cluster_slots", default_cluster_slots) coverage_res = kwargs.pop("coverage_result", "yes") cluster_enabled = kwargs.pop("cluster_enabled", True) - with mock.patch.object(Redis, "execute_command") as execute_command_mock: + with mock.patch.object(ClusterNode, "execute_command") as execute_command_mock: async def execute_command(*_args, **_kwargs): if _args[0] == "CLUSTER SLOTS": @@ -111,7 +111,7 @@ async def execute_command(*_args, **_kwargs): def cmd_init_mock(self, r): self.commands = { - "get": { + "GET": { "name": "get", "arity": 2, "flags": ["readonly", "fast"], @@ -133,8 +133,11 @@ def mock_node_resp( ], ) -> ClusterNode: connection = mock.AsyncMock() + connection.is_connected = True connection.read_response.return_value = response - node.redis_connection.connection = connection + while node._free: + node._free.pop() + node._free.append(connection) return node @@ -183,19 +186,21 @@ async def moved_redirection_helper( redirect_node = rc.get_primaries()[0] r_host = redirect_node.host r_port = redirect_node.port - with mock.patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: - def moved_redirect_effect(connection, *args, **options): - def ok_response(connection, *args, **options): - assert connection.host == r_host - assert connection.port == r_port + def moved_redirect_effect(self, *args, **options): + def ok_response(self, *args, **options): + assert self.host == r_host + assert self.port == r_port return "MOCK_OK" - parse_response.side_effect = ok_response + execute_command.side_effect = ok_response raise MovedError(f"{slot} {r_host}:{r_port}") - parse_response.side_effect = moved_redirect_effect + execute_command.side_effect = moved_redirect_effect assert await rc.execute_command("SET", "foo", "bar") == "MOCK_OK" slot_primary = rc.nodes_manager.slots_cache[slot][0] assert slot_primary == redirect_node @@ -282,10 +287,10 @@ async def test_execute_command_node_flag_primaries(self, r: RedisCluster) -> Non mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.PRIMARIES) is True for primary in primaries: - conn = primary.redis_connection.connection + conn = primary._free.pop() assert conn.read_response.called is True for replica in replicas: - conn = replica.redis_connection.connection + conn = replica._free.pop() assert conn.read_response.called is not True async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None: @@ -299,10 +304,10 @@ async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.REPLICAS) is True for replica in replicas: - conn = replica.redis_connection.connection + conn = replica._free.pop() assert conn.read_response.called is True for primary in primaries: - conn = primary.redis_connection.connection + conn = primary._free.pop() assert conn.read_response.called is not True await r.close() @@ -314,7 +319,7 @@ async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> Non mock_all_nodes_resp(r, "PONG") assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True for node in r.get_nodes(): - conn = node.redis_connection.connection + conn = node._free.pop() assert conn.read_response.called is True async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: @@ -325,7 +330,7 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: assert await r.ping(target_nodes=RedisCluster.RANDOM) is True called_count = 0 for node in r.get_nodes(): - conn = node.redis_connection.connection + conn = node._free.pop() if conn.read_response.called is True: called_count += 1 assert called_count == 1 @@ -338,7 +343,7 @@ async def test_execute_command_default_node(self, r: RedisCluster) -> None: def_node = r.get_default_node() mock_node_resp(def_node, "PONG") assert await r.ping() is True - conn = def_node.redis_connection.connection + conn = def_node._free.pop() assert conn.read_response.called async def test_ask_redirection(self, r: RedisCluster) -> None: @@ -351,19 +356,21 @@ async def test_ask_redirection(self, r: RedisCluster) -> None: Important thing to verify is that it tries to talk to the second node. """ redirect_node = r.get_nodes()[0] - with mock.patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: - def ask_redirect_effect(connection, *args, **options): - def ok_response(connection, *args, **options): - assert connection.host == redirect_node.host - assert connection.port == redirect_node.port + def ask_redirect_effect(self, *args, **options): + def ok_response(self, *args, **options): + assert self.host == redirect_node.host + assert self.port == redirect_node.port return "MOCK_OK" - parse_response.side_effect = ok_response + execute_command.side_effect = ok_response raise AskError(f"12182 {redirect_node.host}:{redirect_node.port}") - parse_response.side_effect = ask_redirect_effect + execute_command.side_effect = ask_redirect_effect assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" @@ -392,27 +399,29 @@ async def test_refresh_using_specific_nodes( """ node_7006 = ClusterNode(host=default_host, port=7006, server_type=PRIMARY) node_7007 = ClusterNode(host=default_host, port=7007, server_type=PRIMARY) - with mock.patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: with mock.patch.object( NodesManager, "initialize", autospec=True ) as initialize: with mock.patch.multiple( Connection, - send_command=mock.DEFAULT, + send_packed_command=mock.DEFAULT, connect=mock.DEFAULT, can_read=mock.DEFAULT, ) as mocks: # simulate 7006 as a failed node - def parse_response_mock(connection, command_name, **options): - if connection.port == 7006: - parse_response.failed_calls += 1 + def execute_command_mock(self, *args, **options): + if self.port == 7006: + execute_command.failed_calls += 1 raise ClusterDownError( "CLUSTERDOWN The cluster is " "down. Use CLUSTER INFO for " "more information" ) - elif connection.port == 7007: - parse_response.successful_calls += 1 + elif self.port == 7007: + execute_command.successful_calls += 1 def initialize_mock(self): # start with all slots mapped to 7006 @@ -436,12 +445,12 @@ def map_7007(self): # Change initialize side effect for the second call initialize.side_effect = map_7007 - parse_response.side_effect = parse_response_mock - parse_response.successful_calls = 0 - parse_response.failed_calls = 0 + execute_command.side_effect = execute_command_mock + execute_command.successful_calls = 0 + execute_command.failed_calls = 0 initialize.side_effect = initialize_mock mocks["can_read"].return_value = False - mocks["send_command"].return_value = "MOCK_OK" + mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( CommandsParser, "initialize", autospec=True @@ -449,7 +458,7 @@ def map_7007(self): def cmd_init_mock(self, r): self.commands = { - "get": { + "GET": { "name": "get", "arity": 2, "flags": ["readonly", "fast"], @@ -472,8 +481,8 @@ def cmd_init_mock(self, r): assert len(rc.get_nodes()) == 1 assert rc.get_node(node_name=node_7007.name) is not None assert rc.get_node(node_name=node_7006.name) is None - assert parse_response.failed_calls == 1 - assert parse_response.successful_calls == 1 + assert execute_command.failed_calls == 1 + assert execute_command.successful_calls == 1 async def test_reading_from_replicas_in_round_robin(self) -> None: with mock.patch.multiple( @@ -484,29 +493,32 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: can_read=mock.DEFAULT, on_connect=mock.DEFAULT, ) as mocks: - with mock.patch.object(Redis, "parse_response") as parse_response: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: - def parse_response_mock_first(connection, *args, **options): + async def execute_command_mock_first(self, *args, **options): + await self.connection_class(**self.connection_kwargs).connect() # Primary - assert connection.port == 7001 - parse_response.side_effect = parse_response_mock_second + assert self.port == 7001 + execute_command.side_effect = execute_command_mock_second return "MOCK_OK" - def parse_response_mock_second(connection, *args, **options): + def execute_command_mock_second(self, *args, **options): # Replica - assert connection.port == 7002 - parse_response.side_effect = parse_response_mock_third + assert self.port == 7002 + execute_command.side_effect = execute_command_mock_third return "MOCK_OK" - def parse_response_mock_third(connection, *args, **options): + def execute_command_mock_third(self, *args, **options): # Primary - assert connection.port == 7001 + assert self.port == 7001 return "MOCK_OK" # We don't need to create a real cluster connection but we # do want RedisCluster.on_connect function to get called, # so we'll mock some of the Connection's functions to allow it - parse_response.side_effect = parse_response_mock_first + execute_command.side_effect = execute_command_mock_first mocks["send_command"].return_value = True mocks["read_response"].return_value = "OK" mocks["_connect"].return_value = True @@ -677,12 +689,6 @@ class TestClusterRedisCommands: Tests for RedisCluster unique commands """ - async def test_case_insensitive_command_names(self, r: RedisCluster) -> None: - assert ( - r.cluster_response_callbacks["cluster addslots"] - == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] - ) - async def test_get_and_set(self, r: RedisCluster) -> None: # get and set can't be tested independently of each other assert await r.get("a") is None @@ -814,8 +820,8 @@ async def test_cluster_delslots(self) -> None: node0 = r.get_node(default_host, 7000) node1 = r.get_node(default_host, 7001) assert await r.cluster_delslots(0, 8192) == [True, True] - assert node0.redis_connection.connection.read_response.called - assert node1.redis_connection.connection.read_response.called + assert node0._free.pop().read_response.called + assert node1._free.pop().read_response.called await r.close() @@ -954,7 +960,7 @@ async def test_cluster_save_config(self, r: RedisCluster) -> None: @skip_if_redis_enterprise() async def test_cluster_get_keys_in_slot(self, r: RedisCluster) -> None: - response = [b"{foo}1", b"{foo}2"] + response = ["{foo}1", "{foo}2"] node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = await r.cluster_get_keys_in_slot(12182, 4) @@ -984,7 +990,7 @@ async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True - assert node.redis_connection.connection.read_response.called + assert node._free.pop().read_response.called @skip_if_redis_enterprise() async def test_cluster_replicas(self, r: RedisCluster) -> None: @@ -1026,7 +1032,7 @@ async def test_readonly(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica.redis_connection.connection.read_response.called + assert replica._free.pop().read_response.called await r.close() @@ -1039,7 +1045,7 @@ async def test_readwrite(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica.redis_connection.connection.read_response.called + assert replica._free.pop().read_response.called await r.close() @@ -2095,9 +2101,11 @@ async def test_init_slots_cache_slots_collision( raise an error. In this test both nodes will say that the first slots block should be bound to different servers. """ - with mock.patch.object(ClusterNode, "initialize", autospec=True) as initialize: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: - async def mocked_initialize(self, **kwargs): + async def mocked_execute_command(self, *args, **kwargs): """ Helper function to return custom slots cache data from different redis nodes @@ -2116,24 +2124,14 @@ async def mocked_initialize(self, **kwargs): else: result = [] - r_node = Redis(host=self.host, port=self.port) - - orig_execute_command = r_node.execute_command + if args[0] == "CLUSTER SLOTS": + return result + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} - async def execute_command(*args, **kwargs): - if args[0] == "CLUSTER SLOTS": - return result - elif args[0] == "INFO": - return {"cluster_enabled": True} - elif args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": "yes"} - else: - return orig_execute_command(*args, **kwargs) - - r_node.execute_command = execute_command - return r_node - - initialize.side_effect = mocked_initialize + execute_command.side_effect = mocked_execute_command with pytest.raises(RedisClusterException) as ex: node_1 = ClusterNode("127.0.0.1", 7000) @@ -2172,30 +2170,25 @@ async def test_init_with_down_node(self) -> None: If I can't connect to one of the nodes, everything should still work. But if I can't connect to any of the nodes, exception should be thrown. """ - with mock.patch.object(ClusterNode, "initialize", autospec=True) as initialize: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command: - async def mocked_initialize(self, **kwargs): + async def mocked_execute_command(self, *args, **kwargs): if self.port == 7000: raise ConnectionError("mock connection error for 7000") - r_node = Redis(host=self.host, port=self.port, decode_responses=True) - - async def execute_command(*args, **kwargs): - if args[0] == "CLUSTER SLOTS": - return [ - [0, 8191, ["127.0.0.1", 7001, "node_1"]], - [8192, 16383, ["127.0.0.1", 7002, "node_2"]], - ] - elif args[0] == "INFO": - return {"cluster_enabled": True} - elif args[1] == "cluster-require-full-coverage": - return {"cluster-require-full-coverage": "yes"} - - r_node.execute_command = execute_command - - return r_node + if args[0] == "CLUSTER SLOTS": + return [ + [0, 8191, ["127.0.0.1", 7001, "node_1"]], + [8192, 16383, ["127.0.0.1", 7002, "node_2"]], + ] + elif args[0] == "INFO": + return {"cluster_enabled": True} + elif args[1] == "cluster-require-full-coverage": + return {"cluster-require-full-coverage": "yes"} - initialize.side_effect = mocked_initialize + execute_command.side_effect = mocked_execute_command node_1 = ClusterNode("127.0.0.1", 7000) node_2 = ClusterNode("127.0.0.1", 7001) @@ -2213,7 +2206,7 @@ async def execute_command(*args, **kwargs): def cmd_init_mock(self, r): self.commands = { - "get": { + "GET": { "name": "get", "arity": 2, "flags": ["readonly", "fast"], diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 3794c31891..376e3f8aaf 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -665,8 +665,8 @@ class TestClusterRedisCommands: def test_case_insensitive_command_names(self, r): assert ( - r.cluster_response_callbacks["cluster addslots"] - == r.cluster_response_callbacks["CLUSTER ADDSLOTS"] + r.cluster_response_callbacks["cluster slots"] + == r.cluster_response_callbacks["CLUSTER SLOTS"] ) def test_get_and_set(self, r): @@ -1038,7 +1038,7 @@ def test_cluster_save_config(self, r): @skip_if_redis_enterprise() def test_cluster_get_keys_in_slot(self, r): - response = [b"{foo}1", b"{foo}2"] + response = ["{foo}1", "{foo}2"] node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = r.cluster_get_keys_in_slot(12182, 4) From ee338b1193fd07aa7bb6285ae9924060322df2c7 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Sat, 23 Apr 2022 19:29:31 +0530 Subject: [PATCH 19/23] async_conn: optimize --- redis/asyncio/connection.py | 107 ++++++++++----------- tests/test_asyncio/test_connection_pool.py | 2 +- tests/test_asyncio/test_lock.py | 2 +- 3 files changed, 51 insertions(+), 60 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index f0e6d3da4f..a77c426849 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -24,7 +24,6 @@ Type, TypeVar, Union, - cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -110,32 +109,32 @@ def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): def encode(self, value: EncodableT) -> EncodedT: """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, str): + return value.encode(self.encoding, self.encoding_errors) if isinstance(value, (bytes, memoryview)): return value - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) if isinstance(value, (int, float)): + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) return repr(value).encode() - if not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ # type: ignore[unreachable] - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - return value.encode(self.encoding, self.encoding_errors) + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" if self.decode_responses or force: - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) if isinstance(value, bytes): return value.decode(self.encoding, self.encoding_errors) + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) return value @@ -336,7 +335,7 @@ def purge(self): def close(self): try: self.purge() - self._buffer.close() # type: ignore[union-attr] + self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -466,7 +465,7 @@ def on_disconnect(self): self._next_response = False async def can_read(self, timeout: float): - if not self._reader: + if not self._stream or not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._next_response is False: @@ -480,14 +479,14 @@ async def read_from_socket( timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ): - if self._stream is None or self._reader is None: - raise RedisError("Parser already closed.") - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - async with async_timeout.timeout(timeout): + if timeout is None: buffer = await self._stream.read(self._read_size) - if not isinstance(buffer, bytes) or len(buffer) == 0: + else: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None self._reader.feed(buffer) # data was read from the socket and added to the buffer. @@ -516,9 +515,6 @@ async def read_response( self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response: Union[ - EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] - ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -541,8 +537,7 @@ async def read_response( and isinstance(response[0], ConnectionError) ): raise response[0] - # cast as there won't be a ConnectionError here. - return cast(Union[EncodableT, List[EncodableT]], response) + return response DefaultParser: Type[Union[PythonParser, HiredisParser]] @@ -637,7 +632,7 @@ def __init__( self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_timeout: - if retry is None: + if not retry: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable @@ -681,7 +676,7 @@ def __del__(self): @property def is_connected(self): - return bool(self._reader and self._writer) + return self._reader and self._writer def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -713,7 +708,7 @@ async def connect(self): raise ConnectionError(exc) from exc try: - if self.redis_connect_func is None: + if not self.redis_connect_func: # Use the default on_connect function await self.on_connect() else: @@ -745,7 +740,7 @@ async def _connect(self): self._reader = reader self._writer = writer sock = writer.transport.get_extra_info("socket") - if sock is not None: + if sock: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: # TCP_KEEPALIVE @@ -856,32 +851,29 @@ async def check_health(self): await self.retry.call_with_retry(self._send_ping, self._ping_failed) async def _send_packed_command(self, command: Iterable[bytes]) -> None: - if self._writer is None: - raise RedisError("Connection already closed.") - self._writer.writelines(command) await self._writer.drain() async def send_packed_command( - self, - command: Union[bytes, str, Iterable[bytes]], - check_health: bool = True, - ): - """Send an already packed command to the Redis server""" - if not self._writer: + self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + ) -> None: + if not self.is_connected: await self.connect() - # guard against health check recursion - if check_health: + elif check_health: await self.check_health() + try: if isinstance(command, str): command = command.encode() if isinstance(command, bytes): command = [command] - await asyncio.wait_for( - self._send_packed_command(command), - self.socket_timeout, - ) + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(command), self.socket_timeout + ) + else: + self._writer.writelines(command) + await self._writer.drain() except asyncio.TimeoutError: await self.disconnect() raise TimeoutError("Timeout writing to socket") from None @@ -901,8 +893,6 @@ async def send_packed_command( async def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - if not self.is_connected: - await self.connect() await self.send_packed_command( self.pack_command(*args), check_health=kwargs.get("check_health", True) ) @@ -917,7 +907,12 @@ async def read_response(self, disable_decoding: bool = False): """Read the response from a previously sent command""" try: async with self._lock: - async with async_timeout.timeout(self.socket_timeout): + if self.socket_timeout: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: response = await self._parser.read_response( disable_decoding=disable_decoding ) @@ -1176,10 +1171,7 @@ def __init__( self._lock = asyncio.Lock() def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: - pieces = [ - ("path", self.path), - ("db", self.db), - ] + pieces = [("path", self.path), ("db", self.db)] if self.client_name: pieces.append(("client_name", self.client_name)) return pieces @@ -1248,12 +1240,11 @@ def parse_url(url: str) -> ConnectKwargs: parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - # We can't type this. - kwargs[name] = parser(value) # type: ignore[misc] + kwargs[name] = parser(value) except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value # type: ignore[misc] + kwargs[name] = value if parsed.username: kwargs["username"] = unquote(parsed.username) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 2cd9480c08..6c56558d59 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -664,7 +664,7 @@ async def r(self, create_redis): def assert_interval_advanced(self, connection): diff = connection.next_health_check - asyncio.get_event_loop().time() - assert self.interval > diff > (self.interval - 1) + assert self.interval >= diff > (self.interval - 1) async def test_health_check_runs(self, r): if r.connection: diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index c496718a67..8ceb3bc958 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -114,7 +114,7 @@ async def test_blocking_timeout(self, r, event_loop): start = event_loop.time() assert not await lock2.acquire() # The elapsed duration should be less than the total blocking_timeout - assert bt > (event_loop.time() - start) > bt - sleep + assert bt >= (event_loop.time() - start) > bt - sleep await lock1.release() async def test_context_manager(self, r): From 0a48fff8fd0e577db8aa32624b344a45f218ba5d Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Wed, 27 Apr 2022 16:11:40 +0530 Subject: [PATCH 20/23] async_cluster/parser: optimize _get_moveable_keys --- redis/asyncio/parser.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 57c20e81fc..273fe0339f 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -53,12 +53,9 @@ async def get_keys( except KeyError: # try to split the command name and to take only the main command # e.g. 'memory' for 'memory usage' - cmd_name_split = args[0].split() - cmd_name = cmd_name_split[0] - if cmd_name in self.commands: - # save the splitted command to args - args = cmd_name_split + list(args[1:]) - else: + args = args[0].split() + list(args[1:]) + cmd_name = args[0] + if cmd_name not in self.commands: # We'll try to reinitialize the commands cache, if the engine # version has changed, the commands may not be current await self.initialize(redis_conn) @@ -84,14 +81,8 @@ async def get_keys( async def _get_moveable_keys( self, redis_conn: "ClusterNode", *args ) -> Optional[List[str]]: - pieces = [] - cmd_name = args[0] - # The command name should be splitted into separate arguments, - # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] - pieces = pieces + cmd_name.split() - pieces = pieces + list(args[1:]) try: - keys = await redis_conn.execute_command("COMMAND GETKEYS", *pieces) + keys = await redis_conn.execute_command("COMMAND GETKEYS", *args) except ResponseError as e: message = e.__str__() if ( From 9a4c741df59170a632eb75a7d300a9735fb4f1bf Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Wed, 27 Apr 2022 16:42:43 +0530 Subject: [PATCH 21/23] async_cluster: inlined check_slots_coverage --- redis/asyncio/cluster.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 3b66d675d5..6b09dc7fac 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -454,11 +454,12 @@ async def _determine_nodes( return [random.choice(list(self.nodes_manager.nodes_cache.values()))] # get the node that holds the key's slot - slot = await self._determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS - ) - return [node] + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(*args), + self.read_from_replicas and command in READ_COMMANDS, + ) + ] async def _determine_slot(self, *args) -> int: command = args[0] @@ -646,12 +647,11 @@ async def _execute_command( except BusyLoadingError: raise except (ConnectionError, TimeoutError): - connection_error_retry_counter += 1 - # Give the node 0.25 seconds to get back up and retry again # with same node and configuration. After 5 attempts then try # to reinitialize the cluster and see if the nodes # configuration has changed or not + connection_error_retry_counter += 1 if connection_error_retry_counter < 5: await asyncio.sleep(0.25) else: @@ -960,14 +960,6 @@ def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: if node.server_type == server_type ] - def check_slots_coverage(self, slots_cache: Dict[int, List["ClusterNode"]]) -> bool: - # Validate if all slots are covered or if we should try next - # startup node - for i in range(0, REDIS_CLUSTER_HASH_SLOTS): - if i not in slots_cache: - return False - return True - async def initialize(self) -> None: self.read_load_balancer.reset() tmp_nodes_cache = {} @@ -1076,10 +1068,13 @@ async def initialize(self) -> None: f'slots cache: {", ".join(disagreements)}' ) - fully_covered = self.check_slots_coverage(tmp_slots) + # Validate if all slots are covered or if we should try next startup node + fully_covered = True + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in tmp_slots: + fully_covered = False + break if fully_covered: - # Don't need to continue to the next startup node if all - # slots are covered break if not startup_nodes_reachable: From 4344145491564bcd8d26ca8e871c35f0bb7ee1cf Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Wed, 27 Apr 2022 16:41:39 +0530 Subject: [PATCH 22/23] async_cluster: update docstrings --- redis/asyncio/cluster.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 6b09dc7fac..7ce4b69961 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -67,11 +67,11 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand Pass one of parameters: - `url` - - `host` + - `host` & `port` - `startup_nodes` - | Use :meth:`initialize` to find cluster nodes & create connections. - | Use :meth:`close` to disconnect connections & close client. + | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. + | Use ``await`` :meth:`close` to disconnect connections & close client. Many commands support the target_nodes kwarg. It can be one of the :attr:`NODE_FLAGS`: @@ -82,10 +82,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand - :attr:`RANDOM` - :attr:`DEFAULT_NODE` + Note: This client is not thread/process/fork safe. + :param host: | Can be used to point to a startup node :param port: - | Port used if **host** or **url** is provided + | Port used if **host** is provided :param startup_nodes: | :class:`~.ClusterNode` to used as a startup node :param cluster_error_retry_attempts: @@ -117,7 +119,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | See :meth:`.from_url` :param kwargs: | Extra arguments that will be passed to the - :class:`~redis.asyncio.connection.Connection` instance when created + :class:`~redis.asyncio.connection.Connection` instances when created :raises RedisClusterException: if any arguments are invalid. Eg: @@ -165,9 +167,9 @@ def from_url(cls, url: str, **kwargs) -> "RedisCluster": All querystring options are cast to their appropriate Python types. Boolean arguments can be specified with string values "True"/"False" or "Yes"/"No". Values that cannot be properly cast cause a - ``ValueError`` to be raised. Once parsed, the querystring arguments - and keyword arguments are passed to the ``ConnectionPool``'s - class initializer. In the case of conflicting arguments, querystring + ``ValueError`` to be raised. Once parsed, the querystring arguments and + keyword arguments are passed to :class:`~redis.asyncio.connection.Connection` + when created. In the case of conflicting arguments, querystring arguments always win. """ From 22506e22ca0412353d96992fa372dd2205d995e7 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Wed, 27 Apr 2022 22:08:00 +0530 Subject: [PATCH 23/23] async_cluster: add concurrent test & use read_response/_update_moved_slots without lock --- redis/asyncio/cluster.py | 21 ++++++-------- redis/asyncio/connection.py | 35 ++++++++++++++++++++++++ tests/test_asyncio/test_cluster.py | 44 ++++++++++++++++++------------ 3 files changed, 70 insertions(+), 30 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7ce4b69961..10a5675e82 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2,7 +2,6 @@ import collections import random import socket -import threading import warnings from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union @@ -339,7 +338,7 @@ async def on_connect(self, connection: Connection) -> None: # regardless of the server type. If this is a primary connection, # READONLY would not affect executing write commands. await connection.send_command("READONLY") - if str_if_bytes(await connection.read_response()) != "OK": + if str_if_bytes(await connection.read_response_without_lock()) != "OK": raise ConnectionError("READONLY command failed") def get_node( @@ -789,10 +788,10 @@ async def execute_command(self, *args, **kwargs) -> Any: connection = None if self._free: for _ in range(len(self._free)): - if self._free[0].is_connected: - connection = self._free.popleft() + connection = self._free.popleft() + if connection.is_connected: break - self._free.rotate(-1) + self._free.append(connection) else: connection = self._free.popleft() else: @@ -807,9 +806,11 @@ async def execute_command(self, *args, **kwargs) -> Any: await connection.send_packed_command(command, False) try: if NEVER_DECODE in kwargs: - response = await connection.read_response(disable_decoding=True) + response = await connection.read_response_without_lock( + disable_decoding=True + ) else: - response = await connection.read_response() + response = await connection.read_response_without_lock() except ResponseError: if EMPTY_RESPONSE in kwargs: return kwargs[EMPTY_RESPONSE] @@ -827,7 +828,6 @@ async def execute_command(self, *args, **kwargs) -> Any: class NodesManager: __slots__ = ( - "_lock", "_moved_exception", "_require_full_coverage", "connection_kwargs", @@ -852,7 +852,6 @@ def __init__( self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() - self._lock = threading.Lock() def get_node( self, @@ -936,9 +935,7 @@ def get_node_from_slot( self, slot: int, read_from_replicas: bool = False ) -> "ClusterNode": if self._moved_exception: - with self._lock: - if self._moved_exception: - self._update_moved_slots() + self._update_moved_slots() try: if read_from_replicas: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index a77c426849..ec91c952c0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -939,6 +939,41 @@ async def read_response(self, disable_decoding: bool = False): raise response from None return response + async def read_response_without_lock(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + if self.socket_timeout: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + def pack_command(self, *args: EncodableT) -> List[bytes]: """Pack a series of arguments into the Redis protocol""" output = [] diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index fb301c47eb..6543e2849c 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -134,7 +134,7 @@ def mock_node_resp( ) -> ClusterNode: connection = mock.AsyncMock() connection.is_connected = True - connection.read_response.return_value = response + connection.read_response_without_lock.return_value = response while node._free: node._free.pop() node._free.append(connection) @@ -288,10 +288,10 @@ async def test_execute_command_node_flag_primaries(self, r: RedisCluster) -> Non assert await r.ping(target_nodes=RedisCluster.PRIMARIES) is True for primary in primaries: conn = primary._free.pop() - assert conn.read_response.called is True + assert conn.read_response_without_lock.called is True for replica in replicas: conn = replica._free.pop() - assert conn.read_response.called is not True + assert conn.read_response_without_lock.called is not True async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None: """ @@ -305,10 +305,10 @@ async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None assert await r.ping(target_nodes=RedisCluster.REPLICAS) is True for replica in replicas: conn = replica._free.pop() - assert conn.read_response.called is True + assert conn.read_response_without_lock.called is True for primary in primaries: conn = primary._free.pop() - assert conn.read_response.called is not True + assert conn.read_response_without_lock.called is not True await r.close() @@ -320,7 +320,7 @@ async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> Non assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True for node in r.get_nodes(): conn = node._free.pop() - assert conn.read_response.called is True + assert conn.read_response_without_lock.called is True async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: """ @@ -331,7 +331,7 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None: called_count = 0 for node in r.get_nodes(): conn = node._free.pop() - if conn.read_response.called is True: + if conn.read_response_without_lock.called is True: called_count += 1 assert called_count == 1 @@ -344,7 +344,7 @@ async def test_execute_command_default_node(self, r: RedisCluster) -> None: mock_node_resp(def_node, "PONG") assert await r.ping() is True conn = def_node._free.pop() - assert conn.read_response.called + assert conn.read_response_without_lock.called async def test_ask_redirection(self, r: RedisCluster) -> None: """ @@ -488,7 +488,7 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: with mock.patch.multiple( Connection, send_command=mock.DEFAULT, - read_response=mock.DEFAULT, + read_response_without_lock=mock.DEFAULT, _connect=mock.DEFAULT, can_read=mock.DEFAULT, on_connect=mock.DEFAULT, @@ -520,7 +520,7 @@ def execute_command_mock_third(self, *args, **options): # so we'll mock some of the Connection's functions to allow it execute_command.side_effect = execute_command_mock_first mocks["send_command"].return_value = True - mocks["read_response"].return_value = "OK" + mocks["read_response_without_lock"].return_value = "OK" mocks["_connect"].return_value = True mocks["can_read"].return_value = False mocks["on_connect"].return_value = True @@ -682,6 +682,14 @@ async def test_not_require_full_coverage_cluster_down_error( else: raise e + async def test_can_run_concurrent_commands(self, r: RedisCluster) -> None: + assert await r.ping(target_nodes=RedisCluster.ALL_NODES) is True + assert all( + await asyncio.gather( + *(r.ping(target_nodes=RedisCluster.ALL_NODES) for _ in range(100)) + ) + ) + @pytest.mark.onlycluster class TestClusterRedisCommands: @@ -792,7 +800,7 @@ async def test_cluster_addslots(self, r: RedisCluster) -> None: @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - async def test_cluster_addslotsrange(self, r): + async def test_cluster_addslotsrange(self, r: RedisCluster): node = r.get_random_node() mock_node_resp(node, "OK") assert await r.cluster_addslotsrange(node, 1, 5) @@ -820,14 +828,14 @@ async def test_cluster_delslots(self) -> None: node0 = r.get_node(default_host, 7000) node1 = r.get_node(default_host, 7001) assert await r.cluster_delslots(0, 8192) == [True, True] - assert node0._free.pop().read_response.called - assert node1._free.pop().read_response.called + assert node0._free.pop().read_response_without_lock.called + assert node1._free.pop().read_response_without_lock.called await r.close() @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - async def test_cluster_delslotsrange(self, r): + async def test_cluster_delslotsrange(self, r: RedisCluster): node = r.get_random_node() mock_node_resp(node, "OK") await r.cluster_addslots(node, 1, 2, 3, 4, 5) @@ -990,7 +998,7 @@ async def test_cluster_setslot_stable(self, r: RedisCluster) -> None: node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert await r.cluster_setslot_stable(12182) is True - assert node._free.pop().read_response.called + assert node._free.pop().read_response_without_lock.called @skip_if_redis_enterprise() async def test_cluster_replicas(self, r: RedisCluster) -> None: @@ -1014,7 +1022,7 @@ async def test_cluster_replicas(self, r: RedisCluster) -> None: ) @skip_if_server_version_lt("7.0.0") - async def test_cluster_links(self, r): + async def test_cluster_links(self, r: RedisCluster): node = r.get_random_node() res = await r.cluster_links(node) links_to = sum(x.count("to") for x in res) @@ -1032,7 +1040,7 @@ async def test_readonly(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica._free.pop().read_response.called + assert replica._free.pop().read_response_without_lock.called await r.close() @@ -1045,7 +1053,7 @@ async def test_readwrite(self) -> None: for res in all_replicas_results.values(): assert res is True for replica in r.get_replicas(): - assert replica._free.pop().read_response.called + assert replica._free.pop().read_response_without_lock.called await r.close()