diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 5de04c0f94..8e59249bef 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -46,7 +46,6 @@ class BaseParser(ABC): - EXCEPTION_CLASSES = { "ERR": { "max number of clients reached": ConnectionError, diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 1275686710..ad766a8f95 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -243,10 +243,8 @@ async def _read_response( ] res = self.push_handler_func(response) if not push_request: - return await ( - self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request ) else: return res diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index f36b4bf79b..65fa58643b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1155,7 +1155,6 @@ def __init__( queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated **connection_kwargs, ): - super().__init__( connection_class=connection_class, max_connections=max_connections, diff --git a/redis/client.py b/redis/client.py index cf6dbf1eed..4923143543 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,8 +4,9 @@ import time import warnings from itertools import chain -from typing import Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union +from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -49,7 +50,7 @@ class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." - def __init__(self, data): + def __init__(self, data: Dict[str, Any]) -> None: for k, v in data.items(): self[k.upper()] = v @@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls, url: str, **kwargs) -> "Redis": """ Return a Redis client object configured from the given URL @@ -202,7 +203,7 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - ): + ) -> None: """ Initialize a new Redis client. To specify a retry policy for specific errors, first set @@ -309,14 +310,14 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" - def get_encoder(self): + def get_encoder(self) -> "Encoder": """Get the connection pool's encoder""" return self.connection_pool.get_encoder() - def get_connection_kwargs(self): + def get_connection_kwargs(self) -> Dict: """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs @@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) - def set_response_callback(self, command, callback): + def set_response_callback(self, command: str, callback: Callable) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback - def load_external_module(self, funcname, func): + def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, and their namespaces to the redis client.
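As a quick sanity check on the client annotations above — a minimal sketch; the URL and the no-op callback are illustrative only, and none of these calls touches the server:

```python
import redis

# from_url builds and returns a configured client, hence the "Redis" return hint
r = redis.Redis.from_url("redis://localhost:6379/0")

# set_response_callback pairs a command name (str) with a Callable
r.set_response_callback("GET", lambda response, **options: response)

# get_encoder returns the pool's Encoder; encoding a str always yields bytes
encoder = r.get_encoder()
assert encoder.encode("key") == b"key"
```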
@@ -354,7 +355,7 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def pipeline(self, transaction=True, shard_hint=None): + def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": """ Return a new pipeline object that can queue multiple commands for later execution. ``transaction`` indicates whether all commands @@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None): self.connection_pool, self.response_callbacks, transaction, shard_hint ) - def transaction(self, func, *watches, **kwargs): + def transaction( + self, func: Callable[["Pipeline"], Any], *watches, **kwargs + ) -> Any: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. The 'func' callable @@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs): def lock( self, - name, - timeout=None, - sleep=0.1, - blocking=True, - blocking_timeout=None, - lock_class=None, - thread_local=True, + name: str, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Any] = None, + thread_local: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -648,9 +651,9 @@ def __init__( self, connection_pool, shard_hint=None, - ignore_subscribe_messages=False, - encoder=None, - push_handler_func=None, + ignore_subscribe_messages: bool = False, + encoder: Optional["Encoder"] = None, + push_handler_func: Optional[Callable[[str], None]] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -672,13 +675,13 @@ def __init__( _set_info_logger() self.reset() - def __enter__(self): + def __enter__(self) -> "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.reset() - def __del__(self): + def __del__(self) -> None: try: # if this object went out of scope prior to shutting down # subscriptions, close the connection manually before @@ -687,7 +690,7 @@ def __del__(self): except Exception: pass - def reset(self): + def reset(self) -> None: if self.connection: self.connection.disconnect() self.connection._deregister_connect_callback(self.on_connect) @@ -702,10 +705,10 @@ def reset(self): self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() - def close(self): + def close(self) -> None: self.reset() - def on_connect(self, connection): + def on_connect(self, connection) -> None: "Re-subscribe to any channels and patterns previously subscribed to" # NOTE: for python3, we can't pass bytestrings as keyword arguments # so we need to decode channel/pattern names back to unicode strings @@ -731,7 +734,7 @@ def on_connect(self, connection): self.ssubscribe(**shard_channels) @property - def subscribed(self): + def subscribed(self) -> bool: """Indicates if there are subscriptions to any channels or patterns""" return self.subscribed_event.is_set() @@ -757,7 +760,7 @@ def execute_command(self, *args): self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) - def clean_health_check_responses(self): + def clean_health_check_responses(self) -> None: """ If any health check responses are present, clean them """ @@ -775,7 +778,7 @@ def clean_health_check_responses(self): ) ttl -= 1 - def _disconnect_raise_connect(self, conn, error): + def _disconnect_raise_connect(self, conn, error) -> None: """ Close the connection and raise an
exception if retry_on_timeout is not set or the error @@ -826,7 +829,7 @@ def try_read(): return None return response - def is_health_check_response(self, response): + def is_health_check_response(self, response) -> bool: """ Check if the response is a health check response. If there are no subscriptions redis responds to PING command with a @@ -837,7 +840,7 @@ def is_health_check_response(self, response): self.health_check_response_b, # If there wasn't ] - def check_health(self): + def check_health(self) -> None: conn = self.connection if conn is None: raise RuntimeError( @@ -849,7 +852,7 @@ def check_health(self): conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) self.health_check_response_counter += 1 - def _normalize_keys(self, data): + def _normalize_keys(self, data) -> Dict: """ normalize channel/pattern names to be either bytes or strings based on whether responses are automatically decoded. this saves us @@ -983,7 +986,9 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0.0): + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): """ Get the next message if one is available, otherwise None. @@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): get_sharded_message = get_message - def ping(self, message=None): + def ping(self, message: Optional[str] = None) -> bool: """ Ping the Redis server """ @@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message - def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: raise PubSubError(f"Channel: '{channel}' has no handler registered") @@ -1114,7 +1124,15 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): + def __init__( + self, + pubsub, + sleep_time: float, + daemon: bool = False, + exception_handler: Optional[ + Callable[[Exception, "PubSub", "PubSubWorkerThread"], None] + ] = None, + ): super().__init__() self.daemon = daemon self.pubsub = pubsub @@ -1122,7 +1140,7 @@ def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): self.exception_handler = exception_handler self._running = threading.Event() - def run(self): + def run(self) -> None: if self._running.is_set(): return self._running.set() @@ -1137,7 +1155,7 @@ def run(self): self.exception_handler(e, pubsub, self) pubsub.close() - def stop(self): + def stop(self) -> None: # trip the flag so the run loop exits. the run loop will # close the pubsub connection, which disconnects the socket # and returns the connection to the pool.
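To see the worker-thread signatures above in use, a minimal sketch — the channel name and handlers are hypothetical, and a Redis server on localhost is assumed:

```python
import redis

r = redis.Redis()
p = r.pubsub(ignore_subscribe_messages=True)

# a handler-backed subscription; handled messages bypass get_message()
p.subscribe(**{"my-channel": lambda message: print(message["data"])})

# matches Callable[[Exception, "PubSub", "PubSubWorkerThread"], None]
def on_error(e, pubsub, thread):
    print(f"pubsub worker error: {e}")
    thread.stop()

worker = p.run_in_thread(sleep_time=0.1, daemon=True, exception_handler=on_error)
r.publish("my-channel", "hello")
worker.stop()
```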
@@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint) self.watching = False self.reset() - def __enter__(self): + def __enter__(self) -> "Pipeline": return self def __exit__(self, exc_type, exc_value, traceback): @@ -1187,14 +1205,14 @@ def __del__(self): except Exception: pass - def __len__(self): + def __len__(self) -> int: return len(self.command_stack) - def __bool__(self): + def __bool__(self) -> bool: """Pipeline instances should always evaluate to True""" return True - def reset(self): + def reset(self) -> None: self.command_stack = [] self.scripts = set() # make sure to reset the connection state in the event that we were @@ -1217,11 +1235,11 @@ def reset(self): self.connection_pool.release(self.connection) self.connection = None - def close(self): + def close(self) -> None: """Close the pipeline""" self.reset() - def multi(self): + def multi(self) -> None: """ Start a transactional block of the pipeline after WATCH commands are issued. End the transactional block with `execute`. @@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs): return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - def _disconnect_reset_raise(self, conn, error): + def _disconnect_reset_raise(self, conn, error) -> None: """ Close the connection, reset watching state and raise an exception if we were watching, @@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options): lambda error: self._disconnect_reset_raise(conn, error), ) - def pipeline_execute_command(self, *args, **options): + def pipeline_execute_command(self, *args, **options) -> "Pipeline": """ Stage a command to be executed when execute() is next called @@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options): self.command_stack.append((args, options)) return self - def _execute_transaction(self, connection, commands, raise_on_error): + def _execute_transaction(self, connection, commands, raise_on_error) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] ) @@ -1415,7 +1433,7 @@ def load_scripts(self): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_reset(self, conn, error): + def _disconnect_raise_reset(self, conn: "Connection", error: Exception) -> None: """ Close the connection, raise an exception if we were watching, and raise an exception if TimeoutError is not part of retry_on_error, @@ -1477,6 +1495,6 @@ def watch(self, *names): raise RedisError("Cannot issue a WATCH after a MULTI") return self.execute_command("WATCH", *names) - def unwatch(self): + def unwatch(self) -> bool: """Unwatches all previously specified keys""" return self.watching and self.execute_command("UNWATCH") or True diff --git a/redis/cluster.py b/redis/cluster.py index ee3e1a865d..873d586c4a 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2457,7 +2457,6 @@ def read(self): """ """ connection = self.connection for c in self.commands: - # if there is a result on this command, # it means we ran into an exception # like a connection error.
Trying to parse diff --git a/redis/commands/core.py b/redis/commands/core.py index 9d81e9772c..e73553e47e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -403,7 +403,7 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password, username=None, **kwargs): + def auth(self, password: str, username: Optional[str] = None, **kwargs): """ Authenticates the user. If you do not pass username, Redis will try to authenticate for the "default" user. If you do pass username, it will diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 3abe155796..0f92e0d6c9 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -80,7 +80,6 @@ def arrpop( path: Optional[str] = Path.root_path(), index: Optional[int] = -1, ) -> List[Union[str, None]]: - """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index e635f91e99..a2bb23b76d 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -27,7 +27,6 @@ class BatchIndexer: """ def __init__(self, client, chunk_size=1000): - self.client = client self.execute_command = client.execute_command self._pipeline = client.pipeline(transaction=False, shard_hint=None) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 93a3d9273b..50d18f476a 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,8 +1,10 @@ +from typing import List, Union + FIELDNAME = object() class Limit: - def __init__(self, offset=0, count=0): + def __init__(self, offset: int = 0, count: int = 0) -> None: self.offset = offset self.count = count @@ -22,12 +24,12 @@ class Reducer: NAME = None - def __init__(self, *args): + def __init__(self, *args: str) -> None: self._args = args self._field = None self._alias = None - def alias(self, alias): + def alias(self, alias: str) -> "Reducer": """ Set the alias for this reducer. @@ -51,7 +53,7 @@ def alias(self, alias): return self @property - def args(self): + def args(self) -> List[str]: return self._args @@ -62,7 +64,7 @@ class SortDirection: DIRSTRING = None - def __init__(self, field): + def __init__(self, field: str) -> None: self.field = field @@ -87,7 +89,7 @@ class AggregateRequest: Aggregation request which can be passed to `Client.aggregate`. """ - def __init__(self, query="*"): + def __init__(self, query: str = "*") -> None: """ Create an aggregation request. This request may then be passed to `client.aggregate()`. @@ -110,7 +112,7 @@ def __init__(self, query="*"): self._cursor = [] self._dialect = None - def load(self, *fields): + def load(self, *fields: str) -> "AggregateRequest": """ Indicate the fields to be returned in the response. These fields are returned in addition to any others implicitly specified. @@ -126,7 +128,9 @@ def load(self, *fields): self._loadall = True return self - def group_by(self, fields, *reducers): + def group_by( + self, fields: Union[str, List[str]], *reducers: Union[Reducer, List[Reducer]] + ) -> "AggregateRequest": """ Specify by which fields to group the aggregation.
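Since every builder method above now advertises `-> "AggregateRequest"`, a group/sort/limit plan reads as one chain. A sketch with made-up field names (run it against an existing index via `client.ft().aggregate(req)`):

```python
from redis.commands.search import aggregation, reducers

req = (
    aggregation.AggregateRequest("redis")
    .group_by("@parent", reducers.count_distinct("@title").alias("titles"))
    .sort_by(aggregation.Desc("@titles"))
    .limit(0, 10)
)
print(req.build_args())  # the raw FT.AGGREGATE argument list
```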
@@ -151,7 +155,7 @@ def group_by(self, fields, *reducers): self._aggregateplan.extend(ret) return self - def apply(self, **kwexpr): + def apply(self, **kwexpr) -> "AggregateRequest": """ Specify one or more projection expressions to add to each result @@ -169,7 +173,7 @@ def apply(self, **kwexpr): return self - def limit(self, offset, num): + def limit(self, offset: int, num: int) -> "AggregateRequest": """ Sets the limit for the most recent group or query. @@ -215,7 +219,7 @@ def limit(self, offset, num): self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields, **kwargs): + def sort_by(self, *fields: Union[str, SortDirection], **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -262,7 +266,7 @@ def sort_by(self, *fields, **kwargs): self._aggregateplan.extend(ret) return self - def filter(self, expressions): + def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest": """ Specify filter for post-query results using predicates relating to values in the result set. @@ -280,7 +284,7 @@ def filter(self, expressions): return self - def with_schema(self): + def with_schema(self) -> "AggregateRequest": """ If set, the `schema` property will contain a list of `[field, type]` entries in the result object. @@ -288,11 +292,11 @@ def with_schema(self): self._with_schema = True return self - def verbatim(self): + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self - def cursor(self, count=0, max_idle=0.0): + def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest": args = ["WITHCURSOR"] if count: args += ["COUNT", str(count)] @@ -301,7 +305,7 @@ def cursor(self, count=0, max_idle=0.0): self._cursor = args return self - def build_args(self): + def build_args(self) -> List[str]: # @foo:bar ... ret = [self._query] @@ -329,7 +333,7 @@ def build_args(self): return ret - def dialect(self, dialect): + def dialect(self, dialect: int) -> "AggregateRequest": """ Add a dialect field to the aggregate command. @@ -340,7 +344,7 @@ def dialect(self, dialect): class Cursor: - def __init__(self, cid): + def __init__(self, cid: int) -> None: self.cid = cid self.max_idle = 0 self.count = 0 @@ -355,12 +359,12 @@ def build_args(self): class AggregateResult: - def __init__(self, rows, cursor, schema): + def __init__(self, rows, cursor: Cursor, schema) -> None: self.rows = rows self.cursor = cursor self.schema = schema - def __repr__(self): + def __repr__(self) -> str: cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 83dea106d2..2df2b5a754 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,6 +1,6 @@ import itertools import time -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from redis.client import Pipeline from redis.utils import deprecated_function @@ -220,7 +220,7 @@ def create_index( return self.execute_command(*args) - def alter_schema_add(self, fields): + def alter_schema_add(self, fields: List): """ Alter the existing search index by adding new fields. The index must already exist. @@ -240,7 +240,7 @@ def alter_schema_add(self, fields): return self.execute_command(*args) - def dropindex(self, delete_documents=False): + def dropindex(self, delete_documents: bool = False): """ Drop the index if it exists.
Replaced `drop_index` in RediSearch 2.0. @@ -322,15 +322,15 @@ def _add_document_hash( ) def add_document( self, - doc_id, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - language=None, - no_create=False, - **fields, + doc_id: str, + nosave: bool = False, + score: float = 1.0, + payload: Optional[str] = None, + replace: bool = False, + partial: bool = False, + language: Optional[str] = None, + no_create: bool = False, + **fields, ): """ Add a single document to the index. @@ -554,7 +554,9 @@ def aggregate( AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor ) - def _get_aggregate_result(self, raw, query, has_cursor): + def _get_aggregate_result( + self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool + ): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] @@ -642,7 +644,7 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - def dict_add(self, name, *terms): + def dict_add(self, name: str, *terms: str): """Adds terms to a dictionary. ### Parameters @@ -656,7 +658,7 @@ def dict_add(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_del(self, name, *terms): + def dict_del(self, name: str, *terms: str): """Deletes terms from a dictionary. ### Parameters @@ -670,7 +672,7 @@ def dict_del(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_dump(self, name): + def dict_dump(self, name: str): """Dumps all terms in the given dictionary. ### Parameters @@ -682,7 +684,7 @@ def dict_dump(self, name): cmd = [DICT_DUMP_CMD, name] return self.execute_command(*cmd) - def config_set(self, option, value): + def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -696,7 +698,7 @@ def config_set(self, option, value): raw = self.execute_command(*cmd) return raw == "OK" - def config_get(self, option): + def config_get(self, option: str) -> Dict: """Get runtime configuration option value. ### Parameters @@ -709,7 +711,7 @@ def config_get(self, option): res = self.execute_command(*cmd) return self._parse_results(CONFIG_CMD, res) - def tagvals(self, tagfield): + def tagvals(self, tagfield: str): """ Return a list of all possible tag values @@ -722,7 +724,7 @@ def tagvals(self, tagfield): return self.execute_command(TAGVALS_CMD, self.index_name, tagfield) - def aliasadd(self, alias): + def aliasadd(self, alias: str): """ Alias a search index - will fail if alias already exists @@ -735,7 +737,7 @@ def aliasadd(self, alias): return self.execute_command(ALIAS_ADD_CMD, alias, self.index_name) - def aliasupdate(self, alias): + def aliasupdate(self, alias: str): """ Updates an alias - will fail if alias does not already exist @@ -748,7 +750,7 @@ def aliasupdate(self, alias): return self.execute_command(ALIAS_UPDATE_CMD, alias, self.index_name) - def aliasdel(self, alias): + def aliasdel(self, alias: str): """ Removes an alias to a search index @@ -783,7 +785,7 @@ def sugadd(self, key, *suggestions, **kwargs): return pipe.execute()[-1] - def suglen(self, key): + def suglen(self, key: str) -> int: """ Return the number of entries in the AutoCompleter index. @@ -791,7 +793,7 @@ def suglen(self, key): """ # noqa return self.execute_command(SUGLEN_COMMAND, key) - def sugdel(self, key, string): + def sugdel(self, key: str, string: str) -> int: """ Delete a string from the AutoCompleter index. Returns 1 if the string was found and deleted, 0 otherwise.
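A short sketch of the dictionary and config commands typed above — the dictionary name and terms are made up, and the RediSearch module must be loaded:

```python
import redis

r = redis.Redis(decode_responses=True)

# *terms is a flat series of strings, one per term
r.ft().dict_add("custom_dict", "foo", "bar", "baz")
r.ft().dict_del("custom_dict", "baz")
print(r.ft().dict_dump("custom_dict"))  # e.g. ["bar", "foo"]

# FT.CONFIG GET answers with an option-to-value mapping, hence the Dict hint
print(r.ft().config_get("TIMEOUT"))
```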
@@ -801,8 +803,14 @@ def sugdel(self, key, string): return self.execute_command(SUGDEL_COMMAND, key, string) def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List["Suggestion"]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. @@ -850,7 +858,7 @@ def sugget( parser = SuggestionParser(with_scores, with_payloads, res) return [s for s in parser] - def synupdate(self, groupid, skipinitial=False, *terms): + def synupdate(self, groupid: str, skipinitial: bool = False, *terms: str): """ Updates a synonym group. The command is used to create or update a synonym group with @@ -986,7 +994,7 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - async def config_set(self, option, value): + async def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -1000,7 +1008,7 @@ async def config_set(self, option, value): raw = await self.execute_command(*cmd) return raw == "OK" - async def config_get(self, option): + async def config_get(self, option: str) -> Dict: """Get runtime configuration option value. ### Parameters @@ -1053,8 +1061,14 @@ async def sugadd(self, key, *suggestions, **kwargs): return (await pipe.execute())[-1] async def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List["Suggestion"]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 6f31ce1fc2..76eb58c2d7 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -4,7 +4,6 @@ class Field: - NUMERIC = "NUMERIC" TEXT = "TEXT" WEIGHT = "WEIGHT" diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 362dd6c72a..113ddf9da8 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,3 +1,6 @@ +from typing import List, Optional, Union + + class Query: """ Query is used to build complex queries that have more parameters than just @@ -8,52 +11,52 @@ class Query: i.e. `Query("foo").verbatim().filter(...)` etc. """ - def __init__(self, query_string): + def __init__(self, query_string: str) -> None: """ Create a new query object. The query string is set in the constructor, and other options have setter functions.
""" - self._query_string = query_string - self._offset = 0 - self._num = 10 - self._no_content = False - self._no_stopwords = False - self._fields = None - self._verbatim = False - self._with_payloads = False - self._with_scores = False - self._scorer = False - self._filters = list() - self._ids = None - self._slop = -1 - self._timeout = None - self._in_order = False - self._sortby = None - self._return_fields = [] - self._summarize_fields = [] - self._highlight_fields = [] - self._language = None - self._expander = None - self._dialect = None - - def query_string(self): + self._query_string: str = query_string + self._offset: int = 0 + self._num: int = 10 + self._no_content: bool = False + self._no_stopwords: bool = False + self._fields: Optional[List[str]] = None + self._verbatim: bool = False + self._with_payloads: bool = False + self._with_scores: bool = False + self._scorer: Optional[str] = None + self._filters: List = list() + self._ids: Optional[List[str]] = None + self._slop: int = -1 + self._timeout: Optional[float] = None + self._in_order: bool = False + self._sortby: Optional[SortbyField] = None + self._return_fields: List = [] + self._summarize_fields: List = [] + self._highlight_fields: List = [] + self._language: Optional[str] = None + self._expander: Optional[str] = None + self._dialect: Optional[int] = None + + def query_string(self) -> str: """Return the query string of this query only.""" return self._query_string - def limit_ids(self, *ids): + def limit_ids(self, *ids) -> "Query": """Limit the results to a specific set of pre-known document ids of any length.""" self._ids = ids return self - def return_fields(self, *fields): + def return_fields(self, *fields) -> "Query": """Add fields to return fields.""" self._return_fields += fields return self - def return_field(self, field, as_field=None): + def return_field(self, field: str, as_field: Optional[str] = None) -> "Query": """Add field to return fields (Optional: add 'AS' name to the field).""" self._return_fields.append(field) @@ -61,12 +64,18 @@ def return_field(self, field, as_field=None): self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields): + def _mk_field_list(self, fields: List[str]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) - def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): + def summarize( + self, + fields: Optional[List] = None, + context_len: Optional[int] = None, + num_frags: Optional[int] = None, + sep: Optional[str] = None, + ) -> "Query": """ Return an abridged format of the field, containing only the segments of the field which contain the matching term(s). @@ -98,7 +107,9 @@ def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): self._summarize_fields = args return self - def highlight(self, fields=None, tags=None): + def highlight( + self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None + ) -> None: """ Apply specified markup to matched term(s) within the returned field(s). @@ -116,7 +127,7 @@ def highlight(self, fields=None, tags=None): self._highlight_fields = args return self - def language(self, language): + def language(self, language: str) -> "Query": """ Analyze the query as being in the specified language. 
@@ -125,19 +136,19 @@ def language(self, language): self._language = language return self - def slop(self, slop): + def slop(self, slop: int) -> "Query": """Allow a maximum of N intervening non matched terms between phrase terms (0 means exact phrase). """ self._slop = slop return self - def timeout(self, timeout): + def timeout(self, timeout: float) -> "Query": """overrides the timeout parameter of the module""" self._timeout = timeout return self - def in_order(self): + def in_order(self) -> "Query": """ Match only documents where the query terms appear in the same order in the document. @@ -146,7 +157,7 @@ def in_order(self): self._in_order = True return self - def scorer(self, scorer): + def scorer(self, scorer: str) -> "Query": """ Use a different scoring function to evaluate document relevance. Default is `TFIDF`. @@ -157,7 +168,7 @@ def scorer(self, scorer): self._scorer = scorer return self - def get_args(self): + def get_args(self) -> List[str]: """Format the redis arguments for this query and return them.""" args = [self._query_string] args += self._get_args_tags() @@ -165,7 +176,7 @@ def get_args(self): args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self): + def _get_args_tags(self) -> List[str]: args = [] if self._no_content: args.append("NOCONTENT") @@ -216,7 +227,7 @@ def _get_args_tags(self): return args - def paging(self, offset, num): + def paging(self, offset: int, num: int) -> "Query": """ Set the paging for the query (defaults to 0..10). @@ -227,19 +238,19 @@ def paging(self, offset, num): self._num = num return self - def verbatim(self): + def verbatim(self) -> "Query": """Set the query to be verbatim, i.e. use no query expansion or stemming. """ self._verbatim = True return self - def no_content(self): + def no_content(self) -> "Query": """Set the query to only return ids and not the document content.""" self._no_content = True return self - def no_stopwords(self): + def no_stopwords(self) -> "Query": """ Prevent the query from being filtered for stopwords. Only useful in very big queries that you are certain contain @@ -248,17 +259,17 @@ def no_stopwords(self): self._no_stopwords = True return self - def with_payloads(self): + def with_payloads(self) -> "Query": """Ask the engine to return document payloads.""" self._with_payloads = True return self - def with_scores(self): + def with_scores(self) -> "Query": """Ask the engine to return document search scores.""" self._with_scores = True return self - def limit_fields(self, *fields): + def limit_fields(self, *fields: str) -> "Query": """ Limit the search to specific TEXT fields only. @@ -268,7 +279,7 @@ def limit_fields(self, *fields): self._fields = fields return self - def add_filter(self, flt): + def add_filter(self, flt: "Filter") -> "Query": """ Add a numeric or geo filter to the query. **Currently only one of each filter is supported by the engine** @@ -280,7 +291,7 @@ def add_filter(self, flt): self._filters.append(flt) return self - def sort_by(self, field, asc=True): + def sort_by(self, field: str, asc: bool = True) -> "Query": """ Add a sortby field to the query. @@ -290,7 +301,7 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self - def expander(self, expander): + def expander(self, expander: str) -> "Query": """ Add a expander field to the query.
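Putting the setters above together, a full chain over a hypothetical price field:

```python
from redis.commands.search.query import NumericFilter, Query

q = (
    Query("search engine")
    .language("english")
    .slop(1)
    .scorer("TFIDF.DOCNORM")
    .add_filter(NumericFilter("price", 100, 200, minExclusive=True))
    .sort_by("price", asc=False)
    .paging(0, 5)
)
print(q.get_args())
```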
@@ -310,7 +321,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword, field, *args): + def __init__(self, keyword: str, field: str, *args: str) -> None: self.args = [keyword, field] + list(args) @@ -318,7 +329,14 @@ class NumericFilter(Filter): INF = "+inf" NEG_INF = "-inf" - def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): + def __init__( + self, + field: str, + minval: Union[int, str], + maxval: Union[int, str], + minExclusive: bool = False, + maxExclusive: bool = False, + ) -> None: args = [ minval if not minExclusive else f"({minval}", maxval if not maxExclusive else f"({maxval}", @@ -333,10 +351,12 @@ class GeoFilter(Filter): FEET = "ft" MILES = "mi" - def __init__(self, field, lon, lat, radius, unit=KILOMETERS): + def __init__( + self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS + ) -> None: Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit) class SortbyField: - def __init__(self, field, asc=True): + def __init__(self, field: str, asc: bool = True) -> None: self.args = [field, "ASC" if asc else "DESC"] diff --git a/redis/commands/search/reducers.py b/redis/commands/search/reducers.py index 41ed11a238..8b60f23283 100644 --- a/redis/commands/search/reducers.py +++ b/redis/commands/search/reducers.py @@ -1,8 +1,12 @@ -from .aggregation import Reducer, SortDirection +from typing import Union + +from .aggregation import Asc, Desc, Reducer, SortDirection class FieldOnlyReducer(Reducer): - def __init__(self, field): + """See https://redis.io/docs/interact/search-and-query/search/aggregations/""" + + def __init__(self, field: str) -> None: super().__init__(field) self._field = field @@ -14,7 +18,7 @@ class count(Reducer): NAME = "COUNT" - def __init__(self): + def __init__(self) -> None: super().__init__() @@ -25,7 +29,7 @@ class sum(FieldOnlyReducer): NAME = "SUM" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -36,7 +40,7 @@ class min(FieldOnlyReducer): NAME = "MIN" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -47,7 +51,7 @@ class max(FieldOnlyReducer): NAME = "MAX" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -58,7 +62,7 @@ class avg(FieldOnlyReducer): NAME = "AVG" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -69,7 +73,7 @@ class tolist(FieldOnlyReducer): NAME = "TOLIST" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -81,7 +85,7 @@ class count_distinct(FieldOnlyReducer): NAME = "COUNT_DISTINCT" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -103,7 +107,7 @@ class quantile(Reducer): NAME = "QUANTILE" - def __init__(self, field, pct): + def __init__(self, field: str, pct: float) -> None: super().__init__(field, str(pct)) self._field = field @@ -115,7 +119,7 @@ class stddev(FieldOnlyReducer): NAME = "STDDEV" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -126,7 +130,7 @@ class first_value(Reducer): NAME = "FIRST_VALUE" - def __init__(self, field, *byfields): + def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None: """ Selects the first value of the given field within the group.
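A sketch of the `first_value` reducer with a BY clause, whose `*byfields` take the `Asc`/`Desc` wrappers per the new hints (the @brand/@price schema is made up):

```python
from redis.commands.search.aggregation import AggregateRequest, Desc
from redis.commands.search.reducers import first_value

req = AggregateRequest("*").group_by(
    "@brand", first_value("@price", Desc("@price")).alias("top_price")
)
print(req.build_args())
```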
@@ -166,7 +170,7 @@ class random_sample(Reducer): NAME = "RANDOM_SAMPLE" - def __init__(self, field, size): + def __init__(self, field: str, size: int) -> None: """ ### Parameter diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 451bf89bb7..5b19e6faa4 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -69,5 +69,5 @@ def __init__( ) self.docs.append(doc) - def __repr__(self): + def __repr__(self) -> str: return f"Result{{{self.total} total, docs: {self.docs}}}" diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index 5d1eba64b8..499c8d917e 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -1,3 +1,5 @@ +from typing import Optional + from ._util import to_string @@ -7,12 +9,14 @@ class Suggestion: autocomplete server """ - def __init__(self, string, score=1.0, payload=None): + def __init__( + self, string: str, score: float = 1.0, payload: Optional[str] = None + ) -> None: self.string = to_string(string) self.payload = to_string(payload) self.score = score - def __repr__(self): + def __repr__(self) -> str: return self.string @@ -23,7 +27,7 @@ class SuggestionParser: the return value depending on what objects were requested """ - def __init__(self, with_scores, with_payloads, ret): + def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None: self.with_scores = with_scores self.with_payloads = with_payloads diff --git a/redis/connection.py b/redis/connection.py index f5266d7dce..b39ba28f76 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Type, Union +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser @@ -55,7 +55,7 @@ class HiredisRespSerializer: - def pack(self, *args): + def pack(self, *args: Any): """Pack a series of arguments into the Redis protocol""" output = [] @@ -128,27 +128,27 @@ class AbstractConnection: def __init__( self, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - retry_on_timeout=False, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, retry_on_error=SENTINEL, - encoding="utf-8", - encoding_errors="strict", - decode_responses=False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, parser_class=DefaultParser, - socket_read_size=65536, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Any] = None, + redis_connect_func: Optional[Callable] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - command_packer=None, + command_packer: Optional[Any] = None, ): """ Initialize a new Connection. @@ -977,7 +977,10 @@ class initializer.
In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, connection_class=Connection, max_connections=None, **connection_kwargs + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: @@ -998,13 +1001,13 @@ def __init__( self._fork_lock = threading.Lock() self.reset() - def __repr__(self): + def __repr__(self) -> str: return ( f"{type(self).__name__}" f"<{repr(self.connection_class(**self.connection_kwargs))}>" ) - def reset(self): + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] @@ -1021,7 +1024,7 @@ def reset(self): # reset() and they will immediately release _fork_lock and continue on. self.pid = os.getpid() - def _checkpid(self): + def _checkpid(self) -> None: # _checkpid() attempts to keep ConnectionPool fork-safe on modern # systems. this is called by all ConnectionPool methods that # manipulate the pool's state such as get_connection() and release(). @@ -1068,7 +1071,7 @@ def _checkpid(self): finally: self._fork_lock.release() - def get_connection(self, command_name, *keys, **options): + def get_connection(self, command_name: str, *keys, **options) -> "Connection": "Get a connection from the pool" self._checkpid() with self._lock: @@ -1101,7 +1104,7 @@ def get_connection(self, command_name, *keys, **options): return connection - def get_encoder(self): + def get_encoder(self) -> Encoder: "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( @@ -1110,14 +1113,14 @@ def get_encoder(self): decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self): + def make_connection(self) -> "Connection": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - def release(self, connection): + def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" self._checkpid() with self._lock: @@ -1138,10 +1141,10 @@ def release(self, connection): connection.disconnect() return - def owns_connection(self, connection): + def owns_connection(self, connection: "Connection") -> bool: return connection.pid == self.pid - def disconnect(self, inuse_connections=True): + def disconnect(self, inuse_connections: bool = True) -> None: """ Disconnects connections in the pool @@ -1215,7 +1218,6 @@ def __init__( queue_class=LifoQueue, **connection_kwargs, ): - self.queue_class = queue_class self.timeout = timeout super().__init__( diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 76ec2bbd26..17ed6822ac 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -99,17 +99,14 @@ async def pipe( @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) async def test_standalone(delay, master_host): - # create a tcp socket proxy that relays data to Redis and back, # inserting 0.1 seconds of delay async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: - for b in [True, False]: # note that we connect to proxy, rather than to Redis directly async with Redis( host="127.0.0.1", port=5380, single_connection_client=b ) as r: - await r.set("foo", "foo") await r.set("bar", "bar") @@
-189,7 +186,6 @@ async def op(pipe): @pytest.mark.onlycluster async def test_cluster(master_host): - delay = 0.1 cluster_port = 16379 remap_base = 7372 diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index ed651cd903..a35bd4795f 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -490,7 +490,6 @@ async def test_json_mget_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_numby_commands_dollar(decoded_r: redis.Redis): - # Test NUMINCRBY await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} @@ -546,7 +545,6 @@ async def test_numby_commands_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strappend_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -578,7 +576,6 @@ async def test_strappend_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strlen_dollar(decoded_r: redis.Redis): - # Test multi await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -713,7 +710,6 @@ async def test_arrinsert_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrlen_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -802,7 +798,6 @@ async def test_arrpop_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrtrim_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -960,7 +955,6 @@ async def test_type_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_clear_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 75484a2791..c052eae2a0 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -234,7 +234,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 0fa1204750..3d271bf1d0 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -396,7 +396,6 @@ async def test_pipeline_get(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") async def test_pipeline_discard(self, r): - # empty pipeline should raise an error async with r.pipeline() as pipe: pipe.set("key", "someval") diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8fef34d83d..19d4b1c650 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -121,7 +121,6 @@ async def test_pattern_subscribe_unsubscribe(self, pubsub): async def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert await sub_func(key) is None @@ -163,7 +162,6 @@ async def test_resubscribe_to_patterns_on_reconnection(self, pubsub): async def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False await sub_func(keys[0]) # we're now subscribed even though we haven't processed the diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index e46de39c70..efc5bf549c 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -77,7 +77,6 @@ async def 
createIndex(decoded_r, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -163,10 +162,8 @@ async def test_client(decoded_r: redis.Redis): ) ).total both_total = ( - await ( - decoded_r.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") ) ).total assert 129 == txt_total @@ -370,18 +367,14 @@ async def test_stopwords(decoded_r: redis.Redis): @pytest.mark.redismod async def test_filters(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), NumericField("num"), GeoField("loc")) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num"), GeoField("loc")) ) - await ( - decoded_r.hset( - "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} - ) + await decoded_r.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) - await ( - decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + await decoded_r.hset( + "doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"} ) await waitForIndex(decoded_r, "idx") @@ -432,10 +425,8 @@ async def test_filters(decoded_r: redis.Redis): @pytest.mark.redismod async def test_sort_by(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), NumericField("num", sortable=True)) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num", sortable=True)) ) await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) @@ -488,8 +479,8 @@ async def test_drop_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_example(decoded_r: redis.Redis): # Creating the index definition and schema - await ( - decoded_r.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + await decoded_r.ft().create_index( + (TextField("title", weight=5.0), TextField("body")) ) # Indexing a document @@ -550,8 +541,8 @@ async def test_auto_complete(decoded_r: redis.Redis): await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - sugs = await ( - decoded_r.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + sugs = await decoded_r.ft().sugget( + "ac", "pay", with_payloads=True, with_scores=True ) assert 3 == len(sugs) for sug in sugs: @@ -639,8 +630,8 @@ async def test_no_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_explain(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + await decoded_r.ft().create_index( + (TextField("f1"), TextField("f2"), TextField("f3")) ) res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @@ -903,10 +894,8 @@ async def test_alter_schema_add(decoded_r: redis.Redis): async def test_spell_check(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - await ( - decoded_r.hset( - "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} - ) + await decoded_r.hset( + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) await decoded_r.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) await waitForIndex(decoded_r, "idx") @@ -1042,8 
+1031,8 @@ async def test_scorer(decoded_r: redis.Redis): assert 1.0 == res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) assert 1.0 == res.docs[0].score - res = await ( - decoded_r.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() ) assert 0.14285714285714285 == res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index a2d52f17b7..25bd7730da 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -72,7 +72,6 @@ def client(self, host, port, **kwargs): @pytest_asyncio.fixture() async def cluster(master_ip): - cluster = SentinelTestCluster(ip=master_ip) saved_Redis = redis.asyncio.sentinel.Redis redis.asyncio.sentinel.Redis = cluster.client diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index 48ffdfd889..91c15c3db2 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -108,7 +108,6 @@ async def test_add(decoded_r: redis.Redis): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_add_duplicate_policy(r: redis.Redis): - # Test for duplicate policy BLOCK assert 1 == await r.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): diff --git a/tests/test_commands.py b/tests/test_commands.py index b538dc3038..6660c2c6b0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -516,7 +516,6 @@ def test_client_trackinginfo(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() def test_client_tracking(self, r, r2): - # simple case assert r.client_tracking_on() assert r.client_tracking_off() @@ -5011,7 +5010,6 @@ def test_module_loadex(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") def test_restore(self, r): - # standard restore key = "foo" r.set(key, "bar") diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py index 581ebfab5d..d2a1e3f39e 100644 --- a/tests/test_graph_utils/test_edge.py +++ b/tests/test_graph_utils/test_edge.py @@ -4,7 +4,6 @@ @pytest.mark.redismod def test_init(): - with pytest.raises(AssertionError): edge.Edge(None, None, None) edge.Edge(node.Node(), None, None) diff --git a/tests/test_json.py b/tests/test_json.py index be347f6677..73d72b8cc9 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -464,7 +464,6 @@ def test_json_mget_dollar(client): def test_numby_commands_dollar(client): - # Test NUMINCRBY client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) # Test multi @@ -508,7 +507,6 @@ def test_numby_commands_dollar(client): def test_strappend_dollar(client): - client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -539,7 +537,6 @@ def test_strappend_dollar(client): def test_strlen_dollar(client): - # Test multi client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -672,7 +669,6 @@ def test_arrinsert_dollar(client): def test_arrlen_dollar(client): - client.json().set( "doc1", "$", @@ -762,7 +758,6 @@ def test_arrpop_dollar(client): def test_arrtrim_dollar(client): - client.json().set( "doc1", "$", @@ -1015,7 +1010,6 @@ def test_toggle_dollar(client): def test_resp_dollar(client): - data = { "L1": { "a": { @@ -1244,7 +1238,6 @@ 
def test_resp_dollar(client): def test_arrindex_dollar(client): - client.json().set( "store", "$", diff --git a/tests/test_lock.py b/tests/test_lock.py index b34f7f0159..72af87fa81 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -247,7 +247,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e64a763bae..7f10fcad4f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -390,7 +390,6 @@ def test_pipeline_with_bitfield(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") def test_pipeline_discard(self, r): - # empty pipeline should raise an error with r.pipeline() as pipe: pipe.set("key", "someval") diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index ba097e3194..fb46772af3 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -152,7 +152,6 @@ def test_shard_channel_subscribe_unsubscribe_cluster(self, r): def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert sub_func(key) is None @@ -201,7 +200,6 @@ def test_resubscribe_to_shard_channels_on_reconnection(self, r): def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False sub_func(keys[0]) # we're now subscribed even though we haven't processed the diff --git a/tests/test_search.py b/tests/test_search.py index 7612332470..9bbfc3c696 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -86,7 +86,6 @@ def createIndex(client, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -820,7 +819,6 @@ def test_spell_check(client): waitForIndex(client, getattr(client.ft(), "index_name", "idx")) if is_resp2_connection(client): - # test spellcheck res = client.ft().spellcheck("impornant") assert "important" == res["impornant"][0]["suggestion"] @@ -2100,7 +2098,6 @@ def test_numeric_params(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_geo_params(client): - client.ft().create_index((GeoField("g"))) client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 4ab86cd56e..6b59967f3c 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -104,7 +104,6 @@ def test_add(client): @skip_ifmodversion_lt("1.4.0", "timeseries") def test_add_duplicate_policy(client): - # Test for duplicate policy BLOCK assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception):