diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 0f9db8fb1a..f49a4fcd46 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -51,6 +51,7 @@ jobs: timeout-minutes: 30 strategy: max-parallel: 15 + fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] test-type: ['standalone', 'cluster'] @@ -108,6 +109,7 @@ jobs: name: Install package from commit hash runs-on: ubuntu-latest strategy: + fail-fast: false matrix: python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] steps: diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 3427956ced..544c733178 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -1,12 +1,12 @@ from base import Benchmark -from redis.connection import HiredisParser, PythonParser +from redis.connection import PythonParser, _HiredisParser class SocketReadBenchmark(Benchmark): ARGUMENTS = ( - {"name": "parser", "values": [PythonParser, HiredisParser]}, + {"name": "parser", "values": [PythonParser, _HiredisParser]}, { "name": "value_size", "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index bf90dde555..7b9508334d 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -7,7 +7,6 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -38,7 +37,6 @@ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", - "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9e16ee08de..9d84e5a61e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -253,6 +253,9 @@ def __init__( self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + if self.connection_pool.connection_kwargs.get("protocol") == "3": + self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + # If using a single connection client, we need to lock creation-of and use-of # the client in order to avoid race conditions such as using asyncio.gather # on a set of redis commands diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 569a0765f8..525c17b22d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -17,15 +17,8 @@ ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import ( - Connection, - DefaultParser, - Encoder, - SSLConnection, - parse_url, -) +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis @@ -60,6 +53,7 @@ TimeoutError, TryAgainError, ) +from redis.parsers import AsyncCommandsParser, Encoder from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import dict_merge, safe_str, str_if_bytes @@ -250,6 +244,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + protocol: Optional[int] = 2, ) -> None: if db: raise RedisClusterException( @@ -290,6 +285,7 @@ def __init__( "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, "retry": retry, + "protocol": protocol, } if ssl: @@ -344,7 +340,7 @@ def __init__( self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 - self.commands_parser = CommandsParser() + self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 057067a83e..d9c95834d5 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -38,26 +38,23 @@ from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.typing import EncodableT, EncodedT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -hiredis = None -if HIREDIS_AVAILABLE: - import hiredis +from ..parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -65,371 +62,19 @@ SYM_LF = b"\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - class _Sentinel(enum.Enum): sentinel = object() SENTINEL = _Sentinel.sentinel -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class _HiredisReaderArgs(TypedDict, total=False): - protocolError: Callable[[str], Exception] - replyError: Callable[[str], Exception] - encoding: Optional[str] - errors: Optional[str] - - -class Encoder: - """Encode strings to bytes-like and decode bytes-like to strings""" - - __slots__ = "encoding", "encoding_errors", "decode_responses" - - def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value: EncodableT) -> EncodedT: - """Return a bytestring or bytes-like representation of the value""" - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - if isinstance(value, (bytes, memoryview)): - return value - if isinstance(value, (int, float)): - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) - return repr(value).encode() - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - - def decode(self, value: EncodableT, force=False) -> EncodableT: - """Return a unicode string from the bytes-like representation""" - if self.decode_responses or force: - if isinstance(value, bytes): - return value.decode(self.encoding, self.encoding_errors) - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) - return value - - -ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - - -class BaseParser: - """Plain Python parsing class""" - - __slots__ = "_stream", "_read_size", "_connected" - - EXCEPTION_CLASSES: ExceptionMappingT = { - "ERR": { - "max number of clients reached": ConnectionError, - "Client sent AUTH, but no password is set": AuthenticationError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def __init__(self, socket_read_size: int): - self._stream: Optional[asyncio.StreamReader] = None - self._read_size = socket_read_size - self._connected = False - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def parse_error(self, response: str) -> ResponseError: - """Parse an error response""" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - def on_disconnect(self): - raise NotImplementedError() - - def on_connect(self, connection: "Connection"): - raise NotImplementedError() - - async def can_read_destructive(self) -> bool: - raise NotImplementedError() - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: - raise NotImplementedError() - - -class PythonParser(BaseParser): - """Plain Python parsing class""" - - __slots__ = ("encoder", "_buffer", "_pos", "_chunks") - - def __init__(self, socket_read_size: int): - super().__init__(socket_read_size) - self.encoder: Optional[Encoder] = None - self._buffer = b"" - self._chunks = [] - self._pos = 0 - - def _clear(self): - self._buffer = b"" - self._chunks.clear() - - def on_connect(self, connection: "Connection"): - """Called when the stream connects""" - self._stream = connection._reader - if self._stream is None: - raise RedisError("Buffer is closed.") - self.encoder = connection.encoder - self._clear() - self._connected = True - - def on_disconnect(self): - """Called when the stream disconnects""" - self._connected = False - - async def can_read_destructive(self) -> bool: - if not self._connected: - raise RedisError("Buffer is closed.") - if self._buffer: - return True - try: - async with async_timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read_response(self, disable_decoding: bool = False): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response - async def _read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None]: - raw = await self._readline() - response: Any - byte, response = raw[:1], raw[1:] - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - self._clear() # Successful parse - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - return int(response) - # bulk response - elif byte == b"$" and response == b"-1": - return None - elif byte == b"$": - response = await self._read(int(response)) - # multi-bulk response - elif byte == b"*" and response == b"-1": - return None - elif byte == b"*": - response = [ - (await self._read_response(disable_decoding)) - for _ in range(int(response)) # noqa - ] - else: - raise InvalidResponse(f"Protocol Error: {raw!r}") - - if disable_decoding is False: - response = self.encoder.decode(response) - return response - - async def _read(self, length: int) -> bytes: - """ - Read `length` bytes of data. These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want - return result - - async def _readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. - """ - found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 - return result - - -class HiredisParser(BaseParser): - """Parser class for connections using Hiredis""" - - __slots__ = ("_reader",) - - def __init__(self, socket_read_size: int): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not available.") - super().__init__(socket_read_size=socket_read_size) - self._reader: Optional[hiredis.Reader] = None - - def on_connect(self, connection: "Connection"): - self._stream = connection._reader - kwargs: _HiredisReaderArgs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - } - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - kwargs["errors"] = connection.encoder.encoding_errors - - self._reader = hiredis.Reader(**kwargs) - self._connected = True - - def on_disconnect(self): - self._connected = False - async def can_read_destructive(self): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets(): - return True - try: - async with async_timeout(0): - return await self.read_from_socket() - except asyncio.TimeoutError: - return False - - async def read_from_socket(self): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, List[EncodableT]]: - # If `on_disconnect()` has been called, prohibit any more reads - # even if they could happen because data might be present. - # We still allow reads in progress to finish - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - - response = self._reader.gets() - while response is False: - await self.read_from_socket() - response = self._reader.gets() - - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -DefaultParser: Type[Union[PythonParser, HiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _AsyncHiredisParser else: - DefaultParser = PythonParser + DefaultParser = _AsyncRESP2Parser class ConnectCallbackProtocol(Protocol): @@ -470,6 +115,7 @@ class Connection: "last_active_at", "encoder", "ssl_context", + "protocol", "_reader", "_writer", "_parser", @@ -506,6 +152,7 @@ def __init__( redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): if (username or password) and credential_provider is not None: raise DataError( @@ -556,6 +203,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + self.protocol = protocol def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) @@ -710,6 +358,18 @@ async def on_connect(self) -> None: if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + if self.protocol != 2: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: await self.send_command("CLIENT", "SETNAME", self.client_name) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py deleted file mode 100644 index 5faf8f8c57..0000000000 --- a/redis/asyncio/parser.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union - -from redis.exceptions import RedisError, ResponseError - -if TYPE_CHECKING: - from redis.asyncio.cluster import ClusterNode - - -class CommandsParser: - """ - Parses Redis commands to get command keys. - - COMMAND output is used to determine key locations. - Commands that do not have a predefined key location are flagged with 'movablekeys', - and these commands' keys are determined by the command 'COMMAND GETKEYS'. - - NOTE: Due to a bug in redis<7.0, this does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this with EVAL or EVALSHA. - """ - - __slots__ = ("commands", "node") - - def __init__(self) -> None: - self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} - - async def initialize(self, node: Optional["ClusterNode"] = None) -> None: - if node: - self.node = node - - commands = await self.node.execute_command("COMMAND") - for cmd, command in commands.items(): - if "movablekeys" in command["flags"]: - commands[cmd] = -1 - elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: - commands[cmd] = 0 - elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: - commands[cmd] = 1 - self.commands = {cmd.upper(): command for cmd, command in commands.items()} - - # As soon as this PR is merged into Redis, we should reimplement - # our logic to use COMMAND INFO changes to determine the key positions - # https://github.com/redis/redis/pull/8324 - async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - if len(args) < 2: - # The command has no keys in it - return None - - try: - command = self.commands[args[0]] - except KeyError: - # try to split the command name and to take only the main command - # e.g. 'memory' for 'memory usage' - args = args[0].split() + list(args[1:]) - cmd_name = args[0].upper() - if cmd_name not in self.commands: - # We'll try to reinitialize the commands cache, if the engine - # version has changed, the commands may not be current - await self.initialize() - if cmd_name not in self.commands: - raise RedisError( - f"{cmd_name} command doesn't exist in Redis commands" - ) - - command = self.commands[cmd_name] - - if command == 1: - return (args[1],) - if command == 0: - return None - if command == -1: - return await self._get_moveable_keys(*args) - - last_key_pos = command["last_key_pos"] - if last_key_pos < 0: - last_key_pos = len(args) + last_key_pos - return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] - - async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - try: - keys = await self.node.execute_command("COMMAND GETKEYS", *args) - except ResponseError as e: - message = e.__str__() - if ( - "Invalid arguments" in message - or "The command has no key arguments" in message - ): - return None - else: - raise e - return keys diff --git a/redis/client.py b/redis/client.py index 1a9b96b83d..15dddc9bd7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -318,7 +318,10 @@ def parse_xautoclaim(response, **options): def parse_xinfo_stream(response, **options): - data = pairs_to_dict(response, decode_keys=True) + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} if not options.get("full", False): first = data["first-entry"] if first is not None: @@ -340,6 +343,12 @@ def parse_xread(response): return [[r[0], parse_stream_list(r[1])] for r in response] +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + def parse_xpending(response, **options): if options.get("parse_detail", False): return parse_xpending_range(response) @@ -578,7 +587,10 @@ def parse_client_kill(response, **options): def parse_acl_getuser(response, **options): if response is None: return None - data = pairs_to_dict(response, decode_keys=True) + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} # convert everything but user-defined data in 'keys' to native strings data["flags"] = list(map(str_if_bytes, data["flags"])) @@ -841,6 +853,43 @@ class AbstractRedis: "ZMSCORE": parse_zmscore, } + RESP3_RESPONSE_CALLBACKS = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} + for x in r + ] + if isinstance(r, list) + else bool_ok(r), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "MEMORY STATS": lambda r: { + str_if_bytes(key): value for key, value in r.items() + }, + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} for d in r + ], + } + class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): """ @@ -942,6 +991,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. @@ -990,6 +1040,7 @@ def __init__( "client_name": client_name, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1037,6 +1088,9 @@ def __init__( self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + if self.connection_pool.connection_kwargs.get("protocol") == "3": + self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" diff --git a/redis/cluster.py b/redis/cluster.py index 5e6e7da546..182ec6d733 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -8,8 +8,8 @@ from redis.backoff import default_backoff from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -29,6 +29,7 @@ TryAgainError, ) from redis.lock import Lock +from redis.parsers import CommandsParser, Encoder from redis.retry import Retry from redis.utils import ( dict_merge, @@ -138,6 +139,7 @@ def parse_cluster_shards(resp, **options): "queue_class", "retry", "retry_on_timeout", + "protocol", "socket_connect_timeout", "socket_keepalive", "socket_keepalive_options", diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index f3f08286c8..a94d9764a6 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,7 +1,6 @@ from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args -from .parser import CommandsParser from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands from .sentinel import AsyncSentinelCommands, SentinelCommands @@ -10,7 +9,6 @@ "AsyncRedisClusterCommands", "AsyncRedisModuleCommands", "AsyncSentinelCommands", - "CommandsParser", "CoreCommands", "READ_COMMANDS", "RedisClusterCommands", diff --git a/redis/connection.py b/redis/connection.py index faea7683f7..85509f7ef7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,64 +1,39 @@ import copy -import errno -import io import os import socket +import ssl import sys import threading import weakref from abc import abstractmethod -from io import SEEK_END from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Union +from typing import Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from redis.backoff import NoBackoff -from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider -from redis.exceptions import ( +from .backoff import NoBackoff +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.retry import Retry -from redis.utils import ( +from .parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser +from .retry import Retry +from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, + SSL_AVAILABLE, str_if_bytes, ) -try: - import ssl - - ssl_available = True -except ImportError: - ssl_available = False - -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} - -if ssl_available: - if hasattr(ssl, "SSLWantReadError"): - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 - else: - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - if HIREDIS_AVAILABLE: import hiredis @@ -67,452 +42,13 @@ SYM_CRLF = b"\r\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - SENTINEL = object() -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class Encoder: - "Encode strings to bytes-like and decode bytes-like to strings" - - def __init__(self, encoding, encoding_errors, decode_responses): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value): - "Return a bytestring or bytes-like representation of the value" - if isinstance(value, (bytes, memoryview)): - return value - elif isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first." - ) - elif isinstance(value, (int, float)): - value = repr(value).encode() - elif not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = type(value).__name__ - raise DataError( - f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first." - ) - if isinstance(value, str): - value = value.encode(self.encoding, self.encoding_errors) - return value - - def decode(self, value, force=False): - "Return a unicode string from the bytes-like representation" - if self.decode_responses or force: - if isinstance(value, memoryview): - value = value.tobytes() - if isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - return value - - -class BaseParser: - EXCEPTION_CLASSES = { - "ERR": { - "max number of clients reached": ConnectionError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments " - "for 'auth' command": AuthenticationWrongNumberOfArgsError, - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments " - "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def parse_error(self, response): - "Parse an error response" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - -class SocketBuffer: - def __init__( - self, socket: socket.socket, socket_read_size: int, socket_timeout: float - ): - self._sock = socket - self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() - - def unread_bytes(self) -> int: - """ - Remaining unread length of buffer - """ - pos = self._buffer.tell() - end = self._buffer.seek(0, SEEK_END) - self._buffer.seek(pos) - return end - pos - - def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, object] = SENTINEL, - raise_on_timeout: Optional[bool] = True, - ) -> bool: - sock = self._sock - socket_read_size = self.socket_read_size - marker = 0 - custom_timeout = timeout is not SENTINEL - - buf = self._buffer - current_pos = buf.tell() - buf.seek(0, SEEK_END) - if custom_timeout: - sock.settimeout(timeout) - try: - while True: - data = self._sock.recv(socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - marker += data_length - - if length is not None and length > marker: - continue - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - buf.seek(current_pos) - if custom_timeout: - sock.settimeout(self.socket_timeout) - def can_read(self, timeout: float) -> bool: - return bool(self.unread_bytes()) or self._read_from_socket( - timeout=timeout, raise_on_timeout=False - ) - - def read(self, length: int) -> bytes: - length = length + 2 # make sure to read the \r\n terminator - # BufferIO will return less than requested if buffer is short - data = self._buffer.read(length) - missing = length - len(data) - if missing: - # fill up the buffer and read the remainder - self._read_from_socket(missing) - data += self._buffer.read(missing) - return data[:-2] - - def readline(self) -> bytes: - buf = self._buffer - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - self._read_from_socket() - data += buf.readline() - - return data[:-2] - - def get_pos(self) -> int: - """ - Get current read position - """ - return self._buffer.tell() - - def rewind(self, pos: int) -> None: - """ - Rewind the buffer to a specific position, to re-start reading - """ - self._buffer.seek(pos) - - def purge(self) -> None: - """ - After a successful read, purge the read part of buffer - """ - unread = self.unread_bytes() - - # Only if we have read all of the buffer do we truncate, to - # reduce the amount of memory thrashing. This heuristic - # can be changed or removed later. - if unread > 0: - return - - if unread > 0: - # move unread data to the front - view = self._buffer.getbuffer() - view[:unread] = view[-unread:] - self._buffer.truncate(unread) - self._buffer.seek(0) - - def close(self) -> None: - try: - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. - pass - self._buffer = None - self._sock = None - - -class PythonParser(BaseParser): - "Plain Python parsing class" - - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - self.encoder = None - self._sock = None - self._buffer = None - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection): - "Called when the socket connects" - self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) - self.encoder = connection.encoder - - def on_disconnect(self): - "Called when the socket disconnects" - self._sock = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None - self.encoder = None - - def can_read(self, timeout): - return self._buffer and self._buffer.can_read(timeout) - - def read_response(self, disable_decoding=False): - pos = self._buffer.get_pos() if self._buffer else None - try: - result = self._read_response(disable_decoding=disable_decoding) - except BaseException: - if self._buffer: - self._buffer.rewind(pos) - raise - else: - self._buffer.purge() - return result - - def _read_response(self, disable_decoding=False): - raw = self._buffer.readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - byte, response = raw[:1], raw[1:] - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - return int(response) - # bulk response - elif byte == b"$" and response == b"-1": - return None - elif byte == b"$": - response = self._buffer.read(int(response)) - # multi-bulk response - elif byte == b"*" and response == b"-1": - return None - elif byte == b"*": - response = [ - self._read_response(disable_decoding=disable_decoding) - for i in range(int(response)) - ] - else: - raise InvalidResponse(f"Protocol Error: {raw!r}") - - if disable_decoding is False: - response = self.encoder.decode(response) - return response - - -class HiredisParser(BaseParser): - "Parser class for connections using Hiredis" - - def __init__(self, socket_read_size): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not installed") - self.socket_read_size = socket_read_size - self._buffer = bytearray(socket_read_size) - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection, **kwargs): - self._sock = connection._sock - self._socket_timeout = connection.socket_timeout - kwargs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - "errors": connection.encoder.encoding_errors, - } - - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - self._reader = hiredis.Reader(**kwargs) - self._next_response = False - - def on_disconnect(self): - self._sock = None - self._reader = None - self._next_response = False - - def can_read(self, timeout): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - if self._next_response is False: - self._next_response = self._reader.gets() - if self._next_response is False: - return self.read_from_socket(timeout=timeout, raise_on_timeout=False) - return True - - def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): - sock = self._sock - custom_timeout = timeout is not SENTINEL - try: - if custom_timeout: - sock.settimeout(timeout) - bufflen = self._sock.recv_into(self._buffer) - if bufflen == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - self._reader.feed(self._buffer, 0, bufflen) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - if custom_timeout: - sock.settimeout(self._socket_timeout) - - def read_response(self, disable_decoding=False): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - # _next_response might be cached from a can_read() call - if self._next_response is not False: - response = self._next_response - self._next_response = False - return response - - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - - while response is False: - self.read_from_socket() - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -DefaultParser: BaseParser +DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _HiredisParser else: - DefaultParser = PythonParser + DefaultParser = _RESP2Parser class HiredisRespSerializer: @@ -604,6 +140,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, command_packer=None, ): """ @@ -652,6 +189,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + self.protocol = protocol self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): @@ -763,6 +301,18 @@ def on_connect(self): if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + if self.protocol != 2: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol) + response = self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: self.send_command("CLIENT", "SETNAME", self.client_name) @@ -1054,7 +604,7 @@ def __init__( Raises: RedisError """ # noqa - if not ssl_available: + if not SSL_AVAILABLE: raise RedisError("Python wasn't built with SSL support") self.keyfile = ssl_keyfile diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py new file mode 100644 index 0000000000..0586016a61 --- /dev/null +++ b/redis/parsers/__init__.py @@ -0,0 +1,19 @@ +from .base import BaseParser +from .commands import AsyncCommandsParser, CommandsParser +from .encoders import Encoder +from .hiredis import _AsyncHiredisParser, _HiredisParser +from .resp2 import _AsyncRESP2Parser, _RESP2Parser +from .resp3 import _AsyncRESP3Parser, _RESP3Parser + +__all__ = [ + "AsyncCommandsParser", + "_AsyncHiredisParser", + "_AsyncRESP2Parser", + "_AsyncRESP3Parser", + "CommandsParser", + "Encoder", + "BaseParser", + "_HiredisParser", + "_RESP2Parser", + "_RESP3Parser", +] diff --git a/redis/parsers/base.py b/redis/parsers/base.py new file mode 100644 index 0000000000..b98a44ef2f --- /dev/null +++ b/redis/parsers/base.py @@ -0,0 +1,229 @@ +import sys +from abc import ABC +from asyncio import IncompleteReadError, StreamReader, TimeoutError +from typing import List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from ..exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ConnectionError, + ExecAbortError, + ModuleError, + NoPermissionError, + NoScriptError, + ReadOnlyError, + RedisError, + ResponseError, +) +from ..typing import EncodableT +from .encoders import Encoder +from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer + +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) +# user send an AUTH cmd to a server without authorization configured +NO_AUTH_SET_ERROR = { + # Redis >= 6.0 + "AUTH called without any password " + "configured for the default user. Are you sure " + "your configuration is correct?": AuthenticationError, + # Redis < 6.0 + "Client sent AUTH, but no password is set": AuthenticationError, +} + + +class BaseParser(ABC): + + EXCEPTION_CLASSES = { + "ERR": { + "max number of clients reached": ConnectionError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + **NO_AUTH_SET_ERROR, + }, + "WRONGPASS": AuthenticationError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def parse_error(self, response): + "Parse an error response" + error_code = response.split(" ")[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection): + raise NotImplementedError() + + +class _RESPBase(BaseParser): + """Base class for sync-based resp parsing""" + + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + self.encoder = None + self._sock = None + self._buffer = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._sock = connection._sock + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + "Called when the socket disconnects" + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + def can_read(self, timeout): + return self._buffer and self._buffer.can_read(timeout) + + +class AsyncBaseParser(BaseParser): + """Base parsing class for the python-backed async parser""" + + __slots__ = "_stream", "_read_size" + + def __init__(self, socket_read_size: int): + self._stream: Optional[StreamReader] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + async def can_read_destructive(self) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class _AsyncRESPBase(AsyncBaseParser): + """Base class for async resp parsing""" + + __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() + + def on_connect(self, connection): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + self.encoder = connection.encoder + self._clear() + self._connected = True + + def on_disconnect(self): + """Called when the stream disconnects""" + self._connected = False + + async def can_read_destructive(self) -> bool: + if not self._connected: + raise RedisError("Buffer is closed.") + if self._buffer: + return True + try: + async with async_timeout(0): + return await self._stream.read(1) + except TimeoutError: + return False + + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result diff --git a/redis/commands/parser.py b/redis/parsers/commands.py similarity index 63% rename from redis/commands/parser.py rename to redis/parsers/commands.py index 115230a9d2..2ea29a75ae 100644 --- a/redis/commands/parser.py +++ b/redis/parsers/commands.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + from redis.exceptions import RedisError, ResponseError from redis.utils import str_if_bytes +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + class CommandsParser: """ @@ -16,7 +21,7 @@ def __init__(self, redis_connection): self.initialize(redis_connection) def initialize(self, r): - commands = r.execute_command("COMMAND") + commands = r.command() uppercase_commands = [] for cmd in commands: if any(x.isupper() for x in cmd): @@ -117,12 +122,9 @@ def _get_moveable_keys(self, redis_conn, *args): So, don't use this function with EVAL or EVALSHA. """ - pieces = [] - cmd_name = args[0] # The command name should be splitted into separate arguments, # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] - pieces = pieces + cmd_name.split() - pieces = pieces + list(args[1:]) + pieces = args[0].split() + list(args[1:]) try: keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) except ResponseError as e: @@ -164,3 +166,91 @@ def _get_pubsub_keys(self, *args): # PUBLISH channel message keys = [args[1]] return keys + + +class AsyncCommandsParser: + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands", "node") + + def __init__(self) -> None: + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} + + async def initialize(self, node: Optional["ClusterNode"] = None) -> None: + if node: + self.node = node + + commands = await self.node.execute_command("COMMAND") + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 + self.commands = {cmd.upper(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + if len(args) < 2: + # The command has no keys in it + return None + + try: + command = self.commands[args[0]] + except KeyError: + # try to split the command name and to take only the main command + # e.g. 'memory' for 'memory usage' + args = args[0].split() + list(args[1:]) + cmd_name = args[0].upper() + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize() + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name} command doesn't exist in Redis commands" + ) + + command = self.commands[cmd_name] + + if command == 1: + return (args[1],) + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(*args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] + + async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + try: + keys = await self.node.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys diff --git a/redis/parsers/encoders.py b/redis/parsers/encoders.py new file mode 100644 index 0000000000..6fdf0ad882 --- /dev/null +++ b/redis/parsers/encoders.py @@ -0,0 +1,44 @@ +from ..exceptions import DataError + + +class Encoder: + "Encode strings to bytes-like and decode bytes-like to strings" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) + elif isinstance(value, (int, float)): + value = repr(value).encode() + elif not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) + if isinstance(value, str): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value diff --git a/redis/parsers/hiredis.py b/redis/parsers/hiredis.py new file mode 100644 index 0000000000..b3247b71ec --- /dev/null +++ b/redis/parsers/hiredis.py @@ -0,0 +1,217 @@ +import asyncio +import socket +import sys +from typing import Callable, List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from redis.compat import TypedDict + +from ..exceptions import ConnectionError, InvalidResponse, RedisError +from ..typing import EncodableT +from ..utils import HIREDIS_AVAILABLE +from .base import AsyncBaseParser, BaseParser +from .socket import ( + NONBLOCKING_EXCEPTION_ERROR_NUMBERS, + NONBLOCKING_EXCEPTIONS, + SENTINEL, + SERVER_CLOSED_CONNECTION_ERROR, +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class _HiredisParser(BaseParser): + "Parser class for connections using Hiredis" + + def __init__(self, socket_read_size): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + self.socket_read_size = socket_read_size + self._buffer = bytearray(socket_read_size) + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection, **kwargs): + import hiredis + + self._sock = connection._sock + self._socket_timeout = connection.socket_timeout + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + "errors": connection.encoder.encoding_errors, + } + + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + + def on_disconnect(self): + self._sock = None + self._reader = None + self._next_response = False + + def can_read(self, timeout): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + sock = self._sock + custom_timeout = timeout is not SENTINEL + try: + if custom_timeout: + sock.settimeout(timeout) + bufflen = self._sock.recv_into(self._buffer) + if bufflen == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + self._reader.feed(self._buffer, 0, bufflen) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + if custom_timeout: + sock.settimeout(self._socket_timeout) + + def read_response(self, disable_decoding=False): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + + while response is False: + self.read_from_socket() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +class _AsyncHiredisParser(AsyncBaseParser): + """Async implementation of parser class for connections using Hiredis""" + + __slots__ = ("_reader",) + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader = None + + def on_connect(self, connection): + import hiredis + + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._connected = True + + def on_disconnect(self): + self._connected = False + + async def can_read_destructive(self): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._reader.gets(): + return True + try: + async with async_timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: + return False + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. + # We still allow reads in progress to finish + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response diff --git a/redis/parsers/resp2.py b/redis/parsers/resp2.py new file mode 100644 index 0000000000..0acd21164f --- /dev/null +++ b/redis/parsers/resp2.py @@ -0,0 +1,131 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP2Parser(_RESPBase): + """RESP2 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = self._buffer.read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP2Parser(_AsyncRESPBase): + """Async class for the RESP2 protocol""" + + async def read_response(self, disable_decoding: bool = False): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = await self._read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py new file mode 100644 index 0000000000..2753d39f1a --- /dev/null +++ b/redis/parsers/resp3.py @@ -0,0 +1,174 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP3Parser(_RESPBase): + """RESP3 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = self._buffer.read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response and verbatim strings + elif byte in (b"$", b"="): + response = self._buffer.read(int(response)) + # array response + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + response = { + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + } + # map response + elif byte == b"%": + response = { + self._read_response( + disable_decoding=disable_decoding + ): self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + } + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP3Parser(_AsyncRESPBase): + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # if byte not in (b"-", b"+", b":", b"$", b"*"): + # raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = await self._read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response and verbatim strings + elif byte in (b"$", b"="): + response = await self._read(int(response)) + # array response + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + response = { + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + } + # map response + elif byte == b"%": + response = { + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) + ) + for _ in range(int(response)) + } + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py new file mode 100644 index 0000000000..8147243bba --- /dev/null +++ b/redis/parsers/socket.py @@ -0,0 +1,162 @@ +import errno +import io +import socket +from io import SEEK_END +from typing import Optional, Union + +from ..exceptions import ConnectionError, TimeoutError +from ..utils import SSL_AVAILABLE + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} + +if SSL_AVAILABLE: + import ssl + + if hasattr(ssl, "SSLWantReadError"): + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 + else: + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." +SENTINEL = object() + +SYM_CRLF = b"\r\n" + + +class SocketBuffer: + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): + self._sock = socket + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos + + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: + sock = self._sock + socket_read_size = self.socket_read_size + marker = 0 + custom_timeout = timeout is not SENTINEL + + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) + try: + while True: + data = self._sock.recv(socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + marker += data_length + + if length is not None and length > marker: + continue + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + buf.seek(current_pos) + if custom_timeout: + sock.settimeout(self.socket_timeout) + + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # BufferIO will return less than requested if buffer is short + data = self._buffer.read(length) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) + return data[:-2] + + def readline(self) -> bytes: + buf = self._buffer + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + data += buf.readline() + + return data[:-2] + + def get_pos(self) -> int: + """ + Get current read position + """ + return self._buffer.tell() + + def rewind(self, pos: int) -> None: + """ + Rewind the buffer to a specific position, to re-start reading + """ + self._buffer.seek(pos) + + def purge(self) -> None: + """ + After a successful read, purge the read part of buffer + """ + unread = self.unread_bytes() + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. + if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self._buffer.seek(0) + + def close(self) -> None: + try: + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass + self._buffer = None + self._sock = None diff --git a/redis/typing.py b/redis/typing.py index 8504c7de0c..7c5908ff0c 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,14 +1,23 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Iterable, + Mapping, + Type, + TypeVar, + Union, +) from redis.compat import Protocol if TYPE_CHECKING: from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.asyncio.connection import Encoder as AsyncEncoder - from redis.connection import ConnectionPool, Encoder + from redis.connection import ConnectionPool + from redis.parsers import Encoder Number = Union[int, float] @@ -39,6 +48,8 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] @@ -48,7 +59,7 @@ def execute_command(self, *args, **options): class ClusterCommandsProtocol(CommandsProtocol): - encoder: Union["AsyncEncoder", "Encoder"] + encoder: "Encoder" def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... diff --git a/redis/utils.py b/redis/utils.py index d95e62c042..a6e620088b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -12,6 +12,13 @@ HIREDIS_AVAILABLE = False HIREDIS_PACK_AVAILABLE = False +try: + import ssl # noqa + + SSL_AVAILABLE = True +except ImportError: + SSL_AVAILABLE = False + try: import cryptography # noqa diff --git a/setup.py b/setup.py index 3003c59420..f37e77df67 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="4.5.3", + version="5.0.0b1", packages=find_packages( include=[ "redis", @@ -19,6 +19,7 @@ "redis.commands.search", "redis.commands.timeseries", "redis.commands.graph", + "redis.parsers", ] ), url="https://github.com/redis/redis-py", diff --git a/tests/conftest.py b/tests/conftest.py index 27dcc741a7..035dbc85cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/9" +default_redis_url = "redis://localhost:6379/0" default_redismod_url = "redis://localhost:36379" default_redis_unstable_url = "redis://localhost:6378" @@ -472,3 +472,11 @@ def wait_for_command(client, monitor, command, key=None): return monitor_response if key in monitor_response["command"]: return None + + +def is_resp2_connection(r): + if isinstance(r, redis.Redis): + protocol = r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.RedisCluster): + protocol = r.nodes_manager.connection_kwargs.get("protocol") + return protocol == "2" or protocol is None diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6982cc840a..e8ab6b297f 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -9,14 +9,11 @@ import redis.asyncio as redis from redis.asyncio.client import Monitor -from redis.asyncio.connection import ( - HIREDIS_AVAILABLE, - HiredisParser, - PythonParser, - parse_url, -) +from redis.asyncio.connection import parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff +from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import REDIS_INFO from .compat import mock @@ -32,14 +29,14 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ pytest.param( - (True, PythonParser), + (True, _AsyncRESP2Parser), marks=pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', reason="cluster mode enabled" ), ), - (False, PythonParser), + (False, _AsyncRESP2Parser), pytest.param( - (True, HiredisParser), + (True, _AsyncHiredisParser), marks=[ pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', @@ -51,7 +48,7 @@ async def _get_info(redis_url): ], ), pytest.param( - (False, HiredisParser), + (False, _AsyncHiredisParser), marks=pytest.mark.skipif( not HIREDIS_AVAILABLE, reason="hiredis is not installed" ), @@ -239,6 +236,29 @@ async def wait_for_command( return None +def get_protocol_version(r): + if isinstance(r, redis.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.RedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response == resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response in resp2_expected + else: + assert response in resp3_expected + + # python 3.6 doesn't have the asynccontextmanager decorator. Provide it here. class AsyncContextManager: def __init__(self, async_generator): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 0857c056c2..a80fa30cb9 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -12,7 +12,6 @@ from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster from redis.asyncio.connection import Connection, SSLConnection -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name @@ -29,6 +28,7 @@ RedisError, ResponseError, ) +from redis.parsers import AsyncCommandsParser from redis.utils import str_if_bytes from tests.conftest import ( skip_if_redis_enterprise, @@ -99,7 +99,7 @@ async def execute_command(*_args, **_kwargs): execute_command_mock.side_effect = execute_command with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -566,7 +566,7 @@ def map_7007(self): mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -2358,7 +2358,7 @@ async def mocked_execute_command(self, *args, **kwargs): assert "Redis Cluster cannot be connected" in str(e.value) with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7c6fd45ab9..866929b2e4 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -18,6 +18,8 @@ skip_unless_arch_bits, ) +from .conftest import assert_resp_response, assert_resp_response_in + REDIS_6_VERSION = "5.9.0" @@ -264,7 +266,8 @@ async def test_acl_log(self, r_teardown, create_redis): assert len(await r.acl_log()) == 2 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) - assert "client-info" in (await r.acl_log(count=1))[0] + expected = (await r.acl_log(count=1))[0] + assert_resp_response_in(r, "client-info", expected, expected.keys()) assert await r.acl_log_reset() @skip_if_server_version_lt(REDIS_6_VERSION) @@ -915,6 +918,19 @@ async def test_pttl_no_key(self, r: redis.Redis): """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @skip_if_server_version_lt("6.2.0") + async def test_hrandfield(self, r): + assert await r.hrandfield("key") is None + await r.hset("key", mapping={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + assert await r.hrandfield("key") is not None + assert len(await r.hrandfield("key", 2)) == 2 + # with values + assert_resp_response(r, len(await r.hrandfield("key", 2, True)), 4, 2) + # without duplications + assert len(await r.hrandfield("key", 10)) == 5 + # with duplications + assert len(await r.hrandfield("key", -10)) == 10 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None @@ -1374,7 +1390,10 @@ async def test_spop_multi_value(self, r: redis.Redis): for value in values: assert value in s - assert await r.spop("a", 1) == list(set(s) - set(values)) + response = await r.spop("a", 1) + assert_resp_response( + r, response, list(set(s) - set(values)), set(s) - set(values) + ) async def test_srandmember(self, r: redis.Redis): s = [b"1", b"2", b"3"] @@ -1412,11 +1431,13 @@ async def test_sunionstore(self, r: redis.Redis): async def test_zadd(self, r: redis.Redis): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} await r.zadd("a", mapping) - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -1433,23 +1454,24 @@ async def test_zadd(self, r: redis.Redis): async def test_zadd_nx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) async def test_zadd_xx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert await r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 99.0)], [[b"a1", 99.0]]) async def test_zadd_ch(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 99.0)], [[b"a2", 2.0], [b"a1", 99.0]] + ) async def test_zadd_incr(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 @@ -1473,6 +1495,25 @@ async def test_zcount(self, r: redis.Redis): assert await r.zcount("a", 1, "(" + str(2)) == 1 assert await r.zcount("a", 10, 20) == 0 + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiff(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["a", "b"]) == [b"a3"] + response = await r.zdiff(["a", "b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiffstore(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("out", ["a", "b"]) + assert await r.zrange("out", 0, -1) == [b"a3"] + response = await r.zrange("out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) + async def test_zincrby(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zincrby("a", 1, "a2") == 3.0 @@ -1492,7 +1533,10 @@ async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"]) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): @@ -1500,7 +1544,10 @@ async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): @@ -1508,7 +1555,10 @@ async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): @@ -1516,23 +1566,34 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmax("a") == [(b"a3", 3)] + response = await r.zpopmax("a") + assert_resp_response(r, response, [(b"a3", 3)], [b"a3", 3.0]) # with count - assert await r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + response = await r.zpopmax("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmin("a") == [(b"a1", 1)] + response = await r.zpopmin("a") + assert_resp_response(r, response, [(b"a1", 1)], [b"a1", 1.0]) # with count - assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + response = await r.zpopmin("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster @@ -1566,20 +1627,20 @@ async def test_zrange(self, r: redis.Redis): assert await r.zrange("a", 1, 2) == [b"a2", b"a3"] # withscores - assert await r.zrange("a", 0, 1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] - assert await r.zrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) + response = await r.zrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]] + ) # custom score function - assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] @skip_if_server_version_lt("2.8.9") async def test_zrangebylex(self, r: redis.Redis): @@ -1613,16 +1674,24 @@ async def test_zrangebyscore(self, r: redis.Redis): assert await r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert await r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] + response = await r.zrangebyscore("a", 2, 4, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) # custom score function - assert await r.zrangebyscore( + response = await r.zrangebyscore( "a", 2, 4, withscores=True, score_cast_func=int - ) == [(b"a2", 2), (b"a3", 3), (b"a4", 4)] + ) + assert_resp_response( + r, + response, + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) async def test_zrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1670,20 +1739,20 @@ async def test_zrevrange(self, r: redis.Redis): assert await r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert await r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert await r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]] + ) + response = await r.zrevrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]] + ) # custom score function - assert await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) + assert_resp_response( + r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]] + ) async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1693,16 +1762,24 @@ async def test_zrevrangebyscore(self, r: redis.Redis): assert await r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] # withscores - assert await r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrangebyscore("a", 4, 2, withscores=True) + assert_resp_response( + r, + response, + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - assert await r.zrevrangebyscore( + response = await r.zrevrangebyscore( "a", 4, 2, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + ) + assert_resp_response( + r, + response, + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) async def test_zrevrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1722,12 +1799,13 @@ async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"]) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): @@ -1735,12 +1813,13 @@ async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + respponse = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + respponse, + [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): @@ -1748,12 +1827,13 @@ async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): @@ -1761,12 +1841,13 @@ async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) # HYPERLOGLOG TESTS @skip_if_server_version_lt("2.8.9") @@ -2761,28 +2842,30 @@ async def test_xread(self, r: redis.Redis): m1 = await r.xadd(stream, {"foo": "bar"}) m2 = await r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert await r.xread(streams={stream: 0}) == expected + res = await r.xread(streams={stream: 0}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert await r.xread(streams={stream: 0}, count=1) == expected + res = await r.xread(streams={stream: 0}, count=1) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] + expected_entries = [await get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert await r.xread(streams={stream: m1}) == expected - - # xread starting at the last message returns an empty list - assert await r.xread(streams={stream: m2}) == [] + res = await r.xread(streams={stream: m1}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) @skip_if_server_version_lt("5.0.0") async def test_xreadgroup(self, r: redis.Redis): @@ -2793,26 +2876,27 @@ async def test_xreadgroup(self, r: redis.Redis): m2 = await r.xadd(stream, {"bing": "baz"}) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert ( - await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} ) await r.xgroup_destroy(stream, group) @@ -2821,35 +2905,34 @@ async def test_xreadgroup(self, r: redis.Redis): # will only find messages added after this await r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response(r, res, [], {}) # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - assert ( - len( - ( - await r.xreadgroup( - group, consumer, streams={stream: ">"}, noack=True - ) - )[0][1] - ) - == 2 - ) - # now there should be nothing pending - assert ( - len((await r.xreadgroup(group, consumer, streams={stream: "0"}))[0][1]) == 0 - ) + # res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + # empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + # if is_resp2_connection(r): + # assert len(res[0][1]) == 2 + # # now there should be nothing pending + # assert len(empty_res[0][1]) == 0 + # else: + # assert len(res[strem_name][0]) == 2 + # # now there should be nothing pending + # assert len(empty_res[strem_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) - assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) @skip_if_server_version_lt("5.0.0") async def test_xrevrange(self, r: redis.Redis): diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index d3b6285cfb..3a8cf8d9c2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -7,16 +7,11 @@ import redis from redis.asyncio import Redis -from redis.asyncio.connection import ( - BaseParser, - Connection, - HiredisParser, - PythonParser, - UnixDomainSocketConnection, -) +from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser, _AsyncRESP3Parser from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt @@ -31,11 +26,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: BaseParser = r.connection._parser + parser: _AsyncRESP2Parser = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, PythonParser): + if isinstance(parser, _AsyncRESP2Parser): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( @@ -218,7 +213,9 @@ async def test_connection_parse_response_resume(r: redis.Redis): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser], + ids=["AsyncRESP2Parser", "AsyncRESP3Parser", "AsyncHiredisParser"], ) async def test_connection_disconect_race(parser_class): """ @@ -232,7 +229,7 @@ async def test_connection_disconect_race(parser_class): This test verifies that a read in progress can finish even if the `disconnect()` method is called. """ - if parser_class == HiredisParser and not HIREDIS_AVAILABLE: + if parser_class == _AsyncHiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available") args = {} diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 0df7847e66..0c0b7dbca6 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -995,9 +995,9 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.asyncio.connection.PythonParser.read_response") as mock1: + with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1: mock1.side_effect = BaseException("boom") - with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2: + with patch("redis.parsers._AsyncHiredisParser.read_response") as mock2: mock2.side_effect = BaseException("boom") with pytest.raises(BaseException): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 58f9b77d7d..4a43eaea21 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -18,7 +18,6 @@ RedisCluster, get_node_name, ) -from redis.commands import CommandsParser from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot from redis.exceptions import ( @@ -33,12 +32,14 @@ ResponseError, TimeoutError, ) +from redis.parsers import CommandsParser from redis.retry import Retry from redis.utils import str_if_bytes from tests.test_pubsub import wait_for_message from .conftest import ( _get_client, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1724,7 +1725,10 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + if is_resp2_connection(r): + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + else: + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [[b"a3", 3.0]] @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1732,7 +1736,10 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + if is_resp2_connection(r): + assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + else: + assert r.zrange("{foo}out", 0, -1, withscores=True) == [[b"a3", 3.0]] @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): @@ -1743,24 +1750,42 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + # aggregate with SUM + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a3", 5), (b"a1", 6)] + # aggregate with MIN + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a3", 1)] + # with weights + assert r.zinter( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [(b"a3", 20), (b"a1", 23)] + else: + # aggregate with SUM + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + [b"a3", 8], + [b"a1", 9], + ] + # aggregate with MAX + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [[b"a3", 5], [b"a1", 6]] + # aggregate with MIN + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [[b"a1", 1], [b"a3", 1]] + # with weights + assert r.zinter( + {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True + ) == [[b"a3", 2], [b"a1", 2]] def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 6c3ede9cdf..b2a2268f85 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,6 +1,6 @@ import pytest -from redis.commands import CommandsParser +from redis.parsers import CommandsParser from .conftest import skip_if_redis_enterprise, skip_if_server_version_lt diff --git a/tests/test_commands.py b/tests/test_commands.py index 94249e9419..1af69c83c0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -13,6 +13,7 @@ from .conftest import ( _get_client, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, skip_if_server_version_lt, @@ -380,7 +381,10 @@ def teardown(): assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - assert "client-info" in r.acl_log(count=1)[0] + if is_resp2_connection(r): + assert "client-info" in r.acl_log(count=1)[0] + else: + assert "client-info" in r.acl_log(count=1)[0].keys() assert r.acl_log_reset() @skip_if_server_version_lt("6.0.0") @@ -1535,7 +1539,10 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - assert len(r.hrandfield("key", 2, True)) == 4 + if is_resp2_connection(r): + assert len(r.hrandfield("key", 2, True)) == 4 + else: + assert len(r.hrandfield("key", 2, True)) == 2 # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1688,17 +1695,30 @@ def test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + if is_resp2_connection(r): + assert r.stralgo("LCS", value1, value2, idx=True) == { + "len": len(res), + "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], + } + assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { + "len": len(res), + "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], + } + assert r.stralgo( + "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True + ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + else: + assert r.stralgo("LCS", value1, value2, idx=True) == { + "len": len(res), + "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]], + } + assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { + "len": len(res), + "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]], + } + assert r.stralgo( + "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True + ) == {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]} @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -2147,8 +2167,10 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - - assert r.spop("a", 1) == list(set(s) - set(values)) + if is_resp2_connection(r): + assert r.spop("a", 1) == list(set(s) - set(values)) + else: + assert r.spop("a", 1) == set(s) - set(values) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2199,11 +2221,18 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [ + (b"a1", 1.0), + (b"a2", 2.0), + (b"a3", 3.0), + ] + else: + assert r.zrange("a", 0, -1, withscores=True) == [ + [b"a1", 1.0], + [b"a2", 2.0], + [b"a3", 3.0], + ] # error cases with pytest.raises(exceptions.DataError): @@ -2220,17 +2249,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + else: + assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + else: + assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 99.0]] def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a2", 2.0), (b"a1", 99.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [ + (b"a2", 2.0), + (b"a1", 99.0), + ] + else: + assert r.zrange("a", 0, -1, withscores=True) == [ + [b"a2", 2.0], + [b"a1", 99.0], + ] def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2278,7 +2322,10 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + if is_resp2_connection(r): + assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + else: + assert r.zdiff(["a", "b"], withscores=True) == [[b"a3", 3.0]] @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2287,7 +2334,10 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + if is_resp2_connection(r): + assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + else: + assert r.zrange("out", 0, -1, withscores=True) == [[b"a3", 3.0]] def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2312,23 +2362,48 @@ def test_zinter(self, r): # invalid aggregation with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + # aggregate with SUM + assert r.zinter(["a", "b", "c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a3", 5), + (b"a1", 6), + ] + # aggregate with MIN + assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a3", 1), + ] + # with weights + assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] + else: + # aggregate with SUM + assert r.zinter(["a", "b", "c"], withscores=True) == [ + [b"a3", 8], + [b"a1", 9], + ] + # aggregate with MAX + assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + [b"a3", 5], + [b"a1", 6], + ] + # aggregate with MIN + assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + [b"a1", 1], + [b"a3", 1], + ] + # with weights + assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + [b"a3", 20], + [b"a1", 23], + ] @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2345,7 +2420,10 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 8], [b"a1", 9]] @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2353,7 +2431,10 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 5], [b"a1", 6]] @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2361,7 +2442,10 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a1", 1], [b"a3", 3]] @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2369,23 +2453,34 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 20], [b"a1", 23]] @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmax("a") == [(b"a3", 3)] - - # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + if is_resp2_connection(r): + assert r.zpopmax("a") == [(b"a3", 3)] + # with count + assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + else: + assert r.zpopmax("a") == [b"a3", 3.0] + # with count + assert r.zpopmax("a", count=2) == [[b"a2", 2], [b"a1", 1]] @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmin("a") == [(b"a1", 1)] - - # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zpopmin("a") == [(b"a1", 1)] + # with count + assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + else: + assert r.zpopmin("a") == [b"a1", 1.0] + # with count + assert r.zpopmin("a", count=2) == [[b"a2", 2], [b"a3", 3]] @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2393,7 +2488,10 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - assert len(r.zrandmember("a", 2, True)) == 4 + if is_resp2_connection(r): + assert len(r.zrandmember("a", 2, True)) == 4 + else: + assert len(r.zrandmember("a", 2, True)) == 2 # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2457,14 +2555,18 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] - - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + if is_resp2_connection(r): + assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + + # custom score function + assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a1", 1), + (b"a2", 2), + ] + else: + assert r.zrange("a", 0, 1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] + assert r.zrange("a", 1, 2, withscores=True) == [[b"a2", 2.0], [b"a3", 3.0]] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2496,14 +2598,25 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + if is_resp2_connection(r): + assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + + else: + assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ + [b"a2", 2.0], + [b"a3", 3.0], + [b"a4", 4.0], + ] + assert r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ) == [[b"a4", 4], [b"a3", 3], [b"a2", 2]] # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2516,7 +2629,10 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + else: + assert r.zrange("b", 0, -1, withscores=True) == [[b"a2", 2], [b"a3", 3]] # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2551,16 +2667,28 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + if is_resp2_connection(r): + assert r.zrangebyscore("a", 2, 4, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + else: + assert r.zrangebyscore("a", 2, 4, withscores=True) == [ + [b"a2", 2.0], + [b"a3", 3.0], + [b"a4", 4.0], + ] + assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ + [b"a2", 2], + [b"a3", 3], + [b"a4", 4], + ] def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2607,33 +2735,61 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 0, 1) == [b"a3", b"a2"] assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [(b"a3", 3.0), (b"a2", 2.0)] - assert r.zrevrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a1", 1.0)] + if is_resp2_connection(r): + # withscores + assert r.zrevrange("a", 0, 1, withscores=True) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + assert r.zrevrange("a", 1, 2, withscores=True) == [ + (b"a2", 2.0), + (b"a1", 1.0), + ] - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + # custom score function + assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + else: + # withscores + assert r.zrevrange("a", 0, 1, withscores=True) == [ + [b"a3", 3.0], + [b"a2", 2.0], + ] + assert r.zrevrange("a", 1, 2, withscores=True) == [ + [b"a2", 2.0], + [b"a1", 1.0], + ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] - # custom score function - assert r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] + + if is_resp2_connection(r): + # withscores + assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + (b"a4", 4.0), + (b"a3", 3.0), + (b"a2", 2.0), + ] + # custom score function + assert r.zrevrangebyscore( + "a", 4, 2, withscores=True, score_cast_func=int + ) == [ + (b"a4", 4), + (b"a3", 3), + (b"a2", 2), + ] + else: + # withscores + assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + [b"a4", 4.0], + [b"a3", 3.0], + [b"a2", 2.0], + ] def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2655,33 +2811,63 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + + if is_resp2_connection(r): + assert r.zunion(["a", "b", "c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + # max + assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + # min + assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a2", 1), + (b"a3", 1), + (b"a4", 4), + ] + # with weight + assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + else: + assert r.zunion(["a", "b", "c"], withscores=True) == [ + [b"a2", 3], + [b"a4", 4], + [b"a3", 8], + [b"a1", 9], + ] + # max + assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + [b"a2", 2], + [b"a4", 4], + [b"a3", 5], + [b"a1", 6], + ] + # min + assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + [b"a1", 1], + [b"a2", 1], + [b"a3", 1], + [b"a4", 4], + ] + # with weight + assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + [b"a2", 5], + [b"a4", 12], + [b"a3", 20], + [b"a1", 23], + ] @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2689,12 +2875,21 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 3], + [b"a4", 4], + [b"a3", 8], + [b"a1", 9], + ] @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): @@ -2702,12 +2897,20 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 2], + [b"a4", 4], + [b"a3", 5], + [b"a1", 6], + ] @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2715,12 +2918,20 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a1", 1], + [b"a2", 2], + [b"a3", 3], + [b"a4", 4], + ] @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2728,12 +2939,20 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 5], + [b"a4", 12], + [b"a3", 20], + [b"a1", 23], + ] @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -4108,7 +4327,10 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - assert m1 in info["entries"] + if is_resp2_connection(r): + assert m1 in info["entries"] + else: + assert m1 in info["entries"][0] assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4249,25 +4471,40 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + strem_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert r.xread(streams={stream: 0}) == expected + res = r.xread(streams={stream: 0}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert r.xread(streams={stream: 0}, count=1) == expected + res = r.xread(streams={stream: 0}, count=1) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} - expected = [[stream.encode(), [get_stream_message(r, stream, m2)]]] + expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert r.xread(streams={stream: m1}) == expected + res = r.xread(streams={stream: m1}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} # xread starting at the last message returns an empty list - assert r.xread(streams={stream: m2}) == [] + res = r.xread(streams={stream: m2}) + if is_resp2_connection(r): + assert res == [] + else: + assert res == {} @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4278,21 +4515,30 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + strem_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = r.xreadgroup(group, consumer, streams={stream: ">"}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) == expected + res = r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} r.xgroup_destroy(stream, group) @@ -4300,27 +4546,37 @@ def test_xreadgroup(self, r): # will only find messages added after this r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + if is_resp2_connection(r): + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == [] + else: + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == {} # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") - assert ( - len(r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True)[0][1]) - == 2 - ) - # now there should be nothing pending - assert len(r.xreadgroup(group, consumer, streams={stream: "0"})[0][1]) == 0 + res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[strem_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[strem_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - assert r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): diff --git a/tests/test_connection.py b/tests/test_connection.py index 25b4118b2c..facd425061 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,14 +7,9 @@ import redis from redis.backoff import NoBackoff -from redis.connection import ( - Connection, - HiredisParser, - PythonParser, - SSLConnection, - UnixDomainSocketConnection, -) +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -134,7 +129,9 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_RESP2Parser, _RESP3Parser, _HiredisParser], + ids=["RESP2Parser", "RESP3Parser", "HiredisParser"], ) def test_connection_parse_response_resume(r: redis.Redis, parser_class): """ @@ -142,7 +139,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): be that PythonParser or HiredisParser, can be interrupted at IO time and then resume parsing. """ - if parser_class is HiredisParser and not HIREDIS_AVAILABLE: + if parser_class is _HiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available)") args = dict(r.connection_pool.connection_kwargs) args["parser_class"] = parser_class @@ -154,7 +151,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): ) mock_socket = MockSocket(message, interrupt_every=2) - if isinstance(conn._parser, PythonParser): + if isinstance(conn._parser, _RESP2Parser) or isinstance(conn._parser, _RESP3Parser): conn._parser._buffer._sock = mock_socket else: conn._parser._sock = mock_socket diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index e8a42692a1..ba9fef3089 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,7 +7,8 @@ import pytest import redis -from redis.connection import ssl_available, to_bool +from redis.connection import to_bool +from redis.utils import SSL_AVAILABLE from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt from .test_pubsub import wait_for_message @@ -425,7 +426,7 @@ class MyConnection(redis.UnixDomainSocketConnection): assert pool.connection_class == MyConnection -@pytest.mark.skipif(not ssl_available, reason="SSL not installed") +@pytest.mark.skipif(not SSL_AVAILABLE, reason="SSL not installed") class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 716cd0fbf6..7b98ece692 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -19,7 +19,6 @@ def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert pipe.execute() == [ True, @@ -27,7 +26,6 @@ def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] def test_pipeline_memoryview(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5d86934de6..48c0f3ac47 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -767,9 +767,9 @@ def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert is_connected() - with patch("redis.connection.PythonParser.read_response") as mock1: + with patch("redis.parsers._RESP2Parser.read_response") as mock1: mock1.side_effect = BaseException("boom") - with patch("redis.connection.HiredisParser.read_response") as mock2: + with patch("redis.parsers._HiredisParser.read_response") as mock2: mock2.side_effect = BaseException("boom") with pytest.raises(BaseException): diff --git a/whitelist.py b/whitelist.py index 8c9cee3c29..29cd529e4d 100644 --- a/whitelist.py +++ b/whitelist.py @@ -14,6 +14,5 @@ exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) -AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46)