From c98c9bf7e868036e1f30dcef1b1dab2f80e55fd7 Mon Sep 17 00:00:00 2001 From: Avasam Date: Sun, 30 Apr 2023 01:44:05 -0400 Subject: [PATCH 1/2] Fix stubtest metaclass issues with Redis --- stubs/redis/@tests/stubtest_allowlist.txt | 55 +------------- stubs/redis/redis/asyncio/cluster.pyi | 10 +-- stubs/redis/redis/asyncio/connection.pyi | 2 +- stubs/redis/redis/commands/__init__.pyi | 19 +++-- stubs/redis/redis/commands/cluster.pyi | 48 +++++++++++- stubs/redis/redis/commands/core.pyi | 91 +++++++++++++++-------- stubs/redis/redis/typing.pyi | 13 +++- 7 files changed, 134 insertions(+), 104 deletions(-) diff --git a/stubs/redis/@tests/stubtest_allowlist.txt b/stubs/redis/@tests/stubtest_allowlist.txt index 4256e9c480b5..ba8809bb7cf6 100644 --- a/stubs/redis/@tests/stubtest_allowlist.txt +++ b/stubs/redis/@tests/stubtest_allowlist.txt @@ -16,59 +16,6 @@ redis.asyncio.sentinel.Sentinel.slave_for redis.sentinel.Sentinel.master_for redis.sentinel.Sentinel.slave_for -# Metaclass differs: -redis.RedisCluster -redis.asyncio.Redis +# Metaclass differs (runtime uses redis.compat.TypedDict) redis.asyncio.client.MonitorCommandInfo -redis.asyncio.client.Pipeline -redis.asyncio.client.Redis redis.asyncio.connection.ConnectKwargs -redis.client.Pipeline -redis.client.Redis -redis.cluster.ClusterPipeline -redis.cluster.RedisCluster -redis.commands.AsyncCoreCommands -redis.commands.CoreCommands -redis.commands.RedisClusterCommands -redis.commands.cluster.ClusterDataAccessCommands -redis.commands.cluster.ClusterManagementCommands -redis.commands.cluster.ClusterMultiKeyCommands -redis.commands.cluster.RedisClusterCommands -redis.commands.core.ACLCommands -redis.commands.core.AsyncACLCommands -redis.commands.core.AsyncBasicKeyCommands -redis.commands.core.AsyncClusterCommands -redis.commands.core.AsyncCoreCommands -redis.commands.core.AsyncDataAccessCommands -redis.commands.core.AsyncGeoCommands -redis.commands.core.AsyncHashCommands -redis.commands.core.AsyncHyperlogCommands -redis.commands.core.AsyncListCommands -redis.commands.core.AsyncManagementCommands -redis.commands.core.AsyncModuleCommands -redis.commands.core.AsyncPubSubCommands -redis.commands.core.AsyncScanCommands -redis.commands.core.AsyncScriptCommands -redis.commands.core.AsyncSetCommands -redis.commands.core.AsyncSortedSetCommands -redis.commands.core.AsyncStreamCommands -redis.commands.core.BasicKeyCommands -redis.commands.core.ClusterCommands -redis.commands.core.CoreCommands -redis.commands.core.DataAccessCommands -redis.commands.core.GeoCommands -redis.commands.core.HashCommands -redis.commands.core.HyperlogCommands -redis.commands.core.ListCommands -redis.commands.core.ManagementCommands -redis.commands.core.ModuleCommands -redis.commands.core.PubSubCommands -redis.commands.core.ScanCommands -redis.commands.core.ScriptCommands -redis.commands.core.SetCommands -redis.commands.core.SortedSetCommands -redis.commands.core.StreamCommands -redis.commands.json.Pipeline -redis.commands.timeseries.Pipeline -redis.asyncio.cluster.ClusterPipeline -redis.asyncio.cluster.RedisCluster diff --git a/stubs/redis/redis/asyncio/cluster.pyi b/stubs/redis/redis/asyncio/cluster.pyi index e61313fdd0f9..00dd3e908aa2 100644 --- a/stubs/redis/redis/asyncio/cluster.pyi +++ b/stubs/redis/redis/asyncio/cluster.pyi @@ -9,9 +9,7 @@ from redis.asyncio.connection import BaseParser, Connection, Encoder from redis.asyncio.parser import CommandsParser from redis.client import AbstractRedis from redis.cluster import AbstractRedisCluster, LoadBalancer - -# TODO: add AsyncRedisClusterCommands stubs -# from redis.commands import AsyncRedisClusterCommands +from redis.commands import AsyncRedisClusterCommands from redis.commands.core import _StrType from redis.credentials import CredentialProvider from redis.retry import Retry @@ -20,7 +18,7 @@ from redis.typing import AnyKeyT, EncodableT, KeyT # It uses `DefaultParser` in real life, but it is a dynamic base class. class ClusterParser(BaseParser): ... -class RedisCluster(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): # TODO: AsyncRedisClusterCommands +class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands, Generic[_StrType]): # type:ignore[misc] retry: Retry | None connection_kwargs: dict[str, Any] nodes_manager: NodesManager @@ -145,7 +143,7 @@ class NodesManager: async def initialize(self) -> None: ... async def close(self, attr: str = "nodes_cache") -> None: ... -class ClusterPipeline(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): # TODO: AsyncRedisClusterCommands +class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands, Generic[_StrType]): # type:ignore[misc] def __init__(self, client: RedisCluster[_StrType]) -> None: ... async def initialize(self) -> Self: ... async def __aenter__(self) -> Self: ... @@ -161,7 +159,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, Generic[_StrType]): def __len__(self) -> int: ... def execute_command(self, *args: KeyT | EncodableT, **kwargs: Any) -> Self: ... async def execute(self, raise_on_error: bool = True, allow_redirections: bool = True) -> list[Any]: ... - def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> Self: ... + def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> Self: ... # type:ignore[override] class PipelineCommand: args: Any diff --git a/stubs/redis/redis/asyncio/connection.pyi b/stubs/redis/redis/asyncio/connection.pyi index ae7e0fd6aa6f..f862cce19589 100644 --- a/stubs/redis/redis/asyncio/connection.pyi +++ b/stubs/redis/redis/asyncio/connection.pyi @@ -232,7 +232,7 @@ def to_bool(value) -> bool | None: ... URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] -class ConnectKwargs(TypedDict): +class ConnectKwargs(TypedDict, total=False): username: str password: str connection_class: type[Connection] diff --git a/stubs/redis/redis/commands/__init__.pyi b/stubs/redis/redis/commands/__init__.pyi index 4959ea0fdf15..d7392f40cf6f 100644 --- a/stubs/redis/redis/commands/__init__.pyi +++ b/stubs/redis/redis/commands/__init__.pyi @@ -1,17 +1,24 @@ -from .cluster import RedisClusterCommands as RedisClusterCommands +from .cluster import ( + READ_COMMANDS as READ_COMMANDS, + AsyncRedisClusterCommands as AsyncRedisClusterCommands, + RedisClusterCommands as RedisClusterCommands, +) from .core import AsyncCoreCommands as AsyncCoreCommands, CoreCommands as CoreCommands from .helpers import list_or_args as list_or_args from .parser import CommandsParser as CommandsParser -from .redismodules import RedisModuleCommands as RedisModuleCommands +from .redismodules import RedisModuleCommands as RedisModuleCommands # , AsyncRedisModuleCommands as AsyncRedisModuleCommands from .sentinel import AsyncSentinelCommands as AsyncSentinelCommands, SentinelCommands as SentinelCommands __all__ = [ - "RedisClusterCommands", - "CommandsParser", "AsyncCoreCommands", + "AsyncRedisClusterCommands", + # "AsyncRedisModuleCommands", # incomplete + "AsyncSentinelCommands", + "CommandsParser", "CoreCommands", - "list_or_args", + "READ_COMMANDS", + "RedisClusterCommands", "RedisModuleCommands", - "AsyncSentinelCommands", "SentinelCommands", + "list_or_args", ] diff --git a/stubs/redis/redis/commands/cluster.pyi b/stubs/redis/redis/commands/cluster.pyi index 5304382400a9..718d6e0f5b05 100644 --- a/stubs/redis/redis/commands/cluster.pyi +++ b/stubs/redis/redis/commands/cluster.pyi @@ -1,9 +1,28 @@ from _typeshed import Incomplete -from typing import Generic +from collections.abc import AsyncIterator, Mapping +from typing import Any, Generic -from .core import ACLCommands, DataAccessCommands, ManagementCommands, PubSubCommands, _StrType +from ..asyncio.connection import ConnectionPool as AsyncConnectionPool, Encoder as AsyncEncoder +from ..connection import ConnectionPool, Encoder +from ..typing import AnyKeyT, ClusterCommandsProtocol, EncodableT, KeysT, KeyT, PatternT +from .core import ( + ACLCommands, + AsyncACLCommands, + AsyncDataAccessCommands, + AsyncFunctionCommands, + AsyncManagementCommands, + AsyncScriptCommands, + DataAccessCommands, + ManagementCommands, + PubSubCommands, + _StrType, +) + +READ_COMMANDS: frozenset[str] -class ClusterMultiKeyCommands: +class ClusterMultiKeyCommands(ClusterCommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool + encoder: AsyncEncoder | Encoder def mget_nonatomic(self, keys, *args): ... def mset_nonatomic(self, mapping): ... def exists(self, *keys): ... @@ -11,11 +30,18 @@ class ClusterMultiKeyCommands: def touch(self, *keys): ... def unlink(self, *keys): ... +class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): + async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> list[Any | None]: ... + async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> list[bool]: ... + class ClusterManagementCommands(ManagementCommands): def slaveof(self, *args, **kwargs) -> None: ... def replicaof(self, *args, **kwargs) -> None: ... def swapdb(self, *args, **kwargs) -> None: ... +class AsyncClusterManagementCommands(ClusterManagementCommands, AsyncManagementCommands): # type:ignore[misc] + async def cluster_delslots(self, *slots: EncodableT) -> list[bool]: ... + class ClusterDataAccessCommands(DataAccessCommands[_StrType], Generic[_StrType]): def stralgo( self, @@ -30,6 +56,13 @@ class ClusterDataAccessCommands(DataAccessCommands[_StrType], Generic[_StrType]) **kwargs, ): ... +class AsyncClusterDataAccessCommands( # type:ignore[misc] + ClusterDataAccessCommands[Incomplete], AsyncDataAccessCommands[Incomplete] +): + async def scan_iter( # type:ignore[override] + self, match: PatternT | None = None, count: int | None = None, _type: str | None = None, **kwargs + ) -> AsyncIterator[Incomplete]: ... + class RedisClusterCommands( ClusterMultiKeyCommands, ClusterManagementCommands, @@ -59,3 +92,12 @@ class RedisClusterCommands( read_from_replicas: bool def readonly(self, target_nodes: Incomplete | None = None): ... def readwrite(self, target_nodes: Incomplete | None = None): ... + +class AsyncRedisClusterCommands( # type:ignore[misc] + AsyncClusterMultiKeyCommands, + AsyncClusterManagementCommands, + AsyncACLCommands[Incomplete], + AsyncClusterDataAccessCommands, + AsyncScriptCommands[Incomplete], + AsyncFunctionCommands, +): ... diff --git a/stubs/redis/redis/commands/core.pyi b/stubs/redis/redis/commands/core.pyi index 1651634dc891..212d9ea2da34 100644 --- a/stubs/redis/redis/commands/core.pyi +++ b/stubs/redis/redis/commands/core.pyi @@ -6,13 +6,16 @@ from typing import Any, Generic, TypeVar, overload from typing_extensions import Literal from ..asyncio.client import Redis as AsyncRedis +from ..asyncio.connection import ConnectionPool as AsyncConnectionPool from ..client import _CommandOptions, _Key, _Value -from ..typing import ChannelT, EncodableT, KeyT, PatternT, ScriptTextT, StreamIdT +from ..connection import ConnectionPool +from ..typing import ChannelT, CommandsProtocol, EncodableT, KeyT, PatternT, ScriptTextT, StreamIdT _ScoreCastFuncReturn = TypeVar("_ScoreCastFuncReturn") _StrType = TypeVar("_StrType", bound=str | bytes) -class ACLCommands(Generic[_StrType]): +class ACLCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def acl_cat(self, category: str | None = None, **kwargs: _CommandOptions) -> list[str]: ... def acl_deluser(self, *username: str, **kwargs: _CommandOptions) -> int: ... def acl_genpass(self, bits: int | None = None, **kwargs: _CommandOptions) -> str: ... @@ -44,7 +47,8 @@ class ACLCommands(Generic[_StrType]): def acl_users(self, **kwargs: _CommandOptions) -> list[str]: ... def acl_whoami(self, **kwargs: _CommandOptions) -> str: ... -class AsyncACLCommands(Generic[_StrType]): +class AsyncACLCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def acl_cat(self, category: str | None = None, **kwargs: _CommandOptions) -> list[str]: ... async def acl_deluser(self, *username: str, **kwargs: _CommandOptions) -> int: ... async def acl_genpass(self, bits: int | None = None, **kwargs: _CommandOptions) -> str: ... @@ -76,7 +80,8 @@ class AsyncACLCommands(Generic[_StrType]): async def acl_users(self, **kwargs: _CommandOptions) -> list[str]: ... async def acl_whoami(self, **kwargs: _CommandOptions) -> str: ... -class ManagementCommands: +class ManagementCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def bgrewriteaof(self, **kwargs: _CommandOptions): ... def bgsave(self, schedule: bool = True, **kwargs: _CommandOptions): ... def role(self): ... @@ -193,7 +198,8 @@ class ManagementCommands: def time(self, **kwargs: _CommandOptions): ... def wait(self, num_replicas, timeout, **kwargs: _CommandOptions): ... -class AsyncManagementCommands: +class AsyncManagementCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def bgrewriteaof(self, **kwargs: _CommandOptions): ... async def bgsave(self, schedule: bool = True, **kwargs: _CommandOptions): ... async def role(self): ... @@ -310,7 +316,8 @@ class AsyncManagementCommands: async def time(self, **kwargs: _CommandOptions): ... async def wait(self, num_replicas, timeout, **kwargs: _CommandOptions): ... -class BasicKeyCommands(Generic[_StrType]): +class BasicKeyCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def append(self, key, value): ... def bitcount(self, key: _Key, start: int | None = None, end: int | None = None, mode: str | None = None) -> int: ... def bitfield(self, key, default_overflow: Incomplete | None = None): ... @@ -428,7 +435,8 @@ class BasicKeyCommands(Generic[_StrType]): def unwatch(self): ... def unlink(self, *names: _Key) -> int: ... -class AsyncBasicKeyCommands(Generic[_StrType]): +class AsyncBasicKeyCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def append(self, key, value): ... async def bitcount(self, key: _Key, start: int | None = None, end: int | None = None, mode: str | None = None) -> int: ... async def bitfield(self, key, default_overflow: Incomplete | None = None): ... @@ -546,7 +554,8 @@ class AsyncBasicKeyCommands(Generic[_StrType]): def __delitem__(self, name: _Key) -> None: ... def __contains__(self, name: _Key) -> None: ... -class ListCommands(Generic[_StrType]): +class ListCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool @overload def blpop(self, keys: _Value | Iterable[_Value], timeout: Literal[0] | None = 0) -> tuple[_StrType, _StrType]: ... @overload @@ -616,7 +625,8 @@ class ListCommands(Generic[_StrType]): groups: bool = False, ) -> int: ... -class AsyncListCommands(Generic[_StrType]): +class AsyncListCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool @overload async def blpop(self, keys: _Value | Iterable[_Value], timeout: Literal[0] | None = 0) -> tuple[_StrType, _StrType]: ... @overload @@ -686,7 +696,8 @@ class AsyncListCommands(Generic[_StrType]): groups: bool = False, ) -> int: ... -class ScanCommands(Generic[_StrType]): +class ScanCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def scan( self, cursor: int = 0, @@ -747,7 +758,8 @@ class ScanCommands(Generic[_StrType]): self, name: _Key, match: _Key | None, count: int | None, score_cast_func: Callable[[_StrType], _ScoreCastFuncReturn] ) -> Iterator[tuple[_StrType, _ScoreCastFuncReturn]]: ... -class AsyncScanCommands(Generic[_StrType]): +class AsyncScanCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def scan( self, cursor: int = 0, @@ -810,7 +822,8 @@ class AsyncScanCommands(Generic[_StrType]): self, name: _Key, match: _Key | None, count: int | None, score_cast_func: Callable[[_StrType], _ScoreCastFuncReturn] ) -> AsyncIterator[tuple[_StrType, _ScoreCastFuncReturn]]: ... -class SetCommands(Generic[_StrType]): +class SetCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def sadd(self, name: _Key, *values: _Value) -> int: ... def scard(self, name: _Key) -> int: ... def sdiff(self, keys: _Key | Iterable[_Key], *args: _Key) -> builtins.set[_Value]: ... @@ -833,7 +846,8 @@ class SetCommands(Generic[_StrType]): def sunion(self, keys: _Key | Iterable[_Key], *args: _Key) -> builtins.set[_Value]: ... def sunionstore(self, dest: _Key, keys: _Key | Iterable[_Key], *args: _Key) -> int: ... -class AsyncSetCommands(Generic[_StrType]): +class AsyncSetCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def sadd(self, name: _Key, *values: _Value) -> int: ... async def scard(self, name: _Key) -> int: ... async def sdiff(self, keys: _Key | Iterable[_Key], *args: _Key) -> builtins.set[_Value]: ... @@ -856,7 +870,8 @@ class AsyncSetCommands(Generic[_StrType]): async def sunion(self, keys: _Key | Iterable[_Key], *args: _Key) -> builtins.set[_Value]: ... async def sunionstore(self, dest: _Key, keys: _Key | Iterable[_Key], *args: _Key) -> int: ... -class StreamCommands: +class StreamCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def xack(self, name, groupname, *ids): ... def xadd( self, @@ -922,7 +937,8 @@ class StreamCommands: self, name, maxlen: int | None = None, approximate: bool = True, minid: Incomplete | None = None, limit: int | None = None ): ... -class AsyncStreamCommands: +class AsyncStreamCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def xack(self, name, groupname, *ids): ... async def xadd( self, @@ -988,7 +1004,8 @@ class AsyncStreamCommands: self, name, maxlen: int | None = None, approximate: bool = True, minid: Incomplete | None = None, limit: int | None = None ): ... -class SortedSetCommands(Generic[_StrType]): +class SortedSetCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def zadd( self, name: _Key, @@ -1187,7 +1204,8 @@ class SortedSetCommands(Generic[_StrType]): def zunionstore(self, dest: _Key, keys: Iterable[_Key], aggregate: Literal["SUM", "MIN", "MAX"] | None = None) -> int: ... def zmscore(self, key, members): ... -class AsyncSortedSetCommands(Generic[_StrType]): +class AsyncSortedSetCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def zadd( self, name: _Key, @@ -1392,17 +1410,20 @@ class AsyncSortedSetCommands(Generic[_StrType]): ) -> int: ... async def zmscore(self, key, members): ... -class HyperlogCommands: +class HyperlogCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def pfadd(self, name: _Key, *values: _Value) -> int: ... def pfcount(self, name: _Key) -> int: ... def pfmerge(self, dest: _Key, *sources: _Key) -> bool: ... -class AsyncHyperlogCommands: +class AsyncHyperlogCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def pfadd(self, name: _Key, *values: _Value) -> int: ... async def pfcount(self, name: _Key) -> int: ... async def pfmerge(self, dest: _Key, *sources: _Key) -> bool: ... -class HashCommands(Generic[_StrType]): +class HashCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def hdel(self, name: _Key, *keys: _Key) -> int: ... def hexists(self, name: _Key, key: _Key) -> bool: ... def hget(self, name: _Key, key: _Key) -> _StrType | None: ... @@ -1427,7 +1448,8 @@ class HashCommands(Generic[_StrType]): def hvals(self, name: _Key) -> list[_StrType]: ... def hstrlen(self, name, key): ... -class AsyncHashCommands(Generic[_StrType]): +class AsyncHashCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool async def hdel(self, name: _Key, *keys: _Key) -> int: ... async def hexists(self, name: _Key, key: _Key) -> bool: ... async def hget(self, name: _Key, key: _Key) -> _StrType | None: ... @@ -1458,19 +1480,22 @@ class AsyncScript: self, keys: Sequence[KeyT] | None = None, args: Iterable[EncodableT] | None = None, client: AsyncRedis[Any] | None = None ): ... -class PubSubCommands: +class PubSubCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def publish(self, channel: _Key, message: _Key, **kwargs: _CommandOptions) -> int: ... def pubsub_channels(self, pattern: _Key = "*", **kwargs: _CommandOptions) -> list[str]: ... def pubsub_numpat(self, **kwargs: _CommandOptions) -> int: ... def pubsub_numsub(self, *args: _Key, **kwargs: _CommandOptions) -> list[tuple[str, int]]: ... -class AsyncPubSubCommands: +class AsyncPubSubCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def publish(self, channel: _Key, message: _Key, **kwargs: _CommandOptions) -> int: ... async def pubsub_channels(self, pattern: _Key = "*", **kwargs: _CommandOptions) -> list[str]: ... async def pubsub_numpat(self, **kwargs: _CommandOptions) -> int: ... async def pubsub_numsub(self, *args: _Key, **kwargs: _CommandOptions) -> list[tuple[str, int]]: ... -class ScriptCommands(Generic[_StrType]): +class ScriptCommands(CommandsProtocol, Generic[_StrType]): + connection_pool: AsyncConnectionPool | ConnectionPool def eval(self, script, numkeys, *keys_and_args): ... def evalsha(self, sha, numkeys, *keys_and_args): ... def script_exists(self, *args): ... @@ -1480,7 +1505,7 @@ class ScriptCommands(Generic[_StrType]): def script_load(self, script): ... def register_script(self, script: str | _StrType) -> Script: ... -class AsyncScriptCommands(Generic[_StrType]): +class AsyncScriptCommands(ScriptCommands[_StrType], Generic[_StrType]): async def eval(self, script, numkeys, *keys_and_args): ... async def evalsha(self, sha, numkeys, *keys_and_args): ... async def script_exists(self, *args): ... @@ -1490,7 +1515,8 @@ class AsyncScriptCommands(Generic[_StrType]): async def script_load(self, script): ... def register_script(self, script: ScriptTextT) -> AsyncScript: ... # type: ignore[override] -class GeoCommands: +class GeoCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def geoadd(self, name, values, nx: bool = False, xx: bool = False, ch: bool = False): ... def geodist(self, name, place1, place2, unit: Incomplete | None = None): ... def geohash(self, name, *values): ... @@ -1560,7 +1586,8 @@ class GeoCommands: storedist: bool = False, ): ... -class AsyncGeoCommands: +class AsyncGeoCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def geoadd(self, name, values, nx: bool = False, xx: bool = False, ch: bool = False): ... async def geodist(self, name, place1, place2, unit: Incomplete | None = None): ... async def geohash(self, name, *values): ... @@ -1630,7 +1657,8 @@ class AsyncGeoCommands: storedist: bool = False, ): ... -class ModuleCommands: +class ModuleCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def module_load(self, path, *args): ... def module_unload(self, name): ... def module_list(self): ... @@ -1657,12 +1685,14 @@ class BitFieldOperation: class AsyncModuleCommands(ModuleCommands): async def command_info(self) -> None: ... -class ClusterCommands: +class ClusterCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool def cluster(self, cluster_arg: str, *args, **kwargs: _CommandOptions): ... def readwrite(self, **kwargs: _CommandOptions) -> bool: ... def readonly(self, **kwargs: _CommandOptions) -> bool: ... -class AsyncClusterCommands: +class AsyncClusterCommands(CommandsProtocol): + connection_pool: AsyncConnectionPool | ConnectionPool async def cluster(self, cluster_arg: str, *args, **kwargs: _CommandOptions): ... async def readwrite(self, **kwargs: _CommandOptions) -> bool: ... async def readonly(self, **kwargs: _CommandOptions) -> bool: ... @@ -1725,6 +1755,7 @@ class CoreCommands( ModuleCommands, PubSubCommands, ScriptCommands[_StrType], + FunctionCommands, Generic[_StrType], ): ... class AsyncCoreCommands( diff --git a/stubs/redis/redis/typing.pyi b/stubs/redis/redis/typing.pyi index f351ed45ac76..e88fbd39c263 100644 --- a/stubs/redis/redis/typing.pyi +++ b/stubs/redis/redis/typing.pyi @@ -1,10 +1,11 @@ -from collections.abc import Iterable +from _typeshed import Incomplete +from collections.abc import Awaitable, Iterable from datetime import datetime, timedelta -from typing import Protocol, TypeVar +from typing import Any, Protocol, TypeVar from typing_extensions import TypeAlias -from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool -from redis.connection import ConnectionPool +from .asyncio.connection import ConnectionPool as AsyncConnectionPool, Encoder as AsyncEncoder +from .connection import ConnectionPool, Encoder # The following type aliases exist at runtime. EncodedT: TypeAlias = bytes | memoryview @@ -32,3 +33,7 @@ AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) # noqa: Y001 class CommandsProtocol(Protocol): connection_pool: AsyncConnectionPool | ConnectionPool def execute_command(self, *args, **options): ... + +class ClusterCommandsProtocol(CommandsProtocol, Protocol): + encoder: AsyncEncoder | Encoder + def execute_command(self, *args, **options) -> Any | Awaitable[Incomplete]: ... From 730eaf065c9cf2ee265fffdd0cca9e96b4ee1db9 Mon Sep 17 00:00:00 2001 From: Avasam Date: Mon, 8 May 2023 13:14:35 -0400 Subject: [PATCH 2/2] Update stubs/redis/@tests/stubtest_allowlist.txt --- stubs/redis/@tests/stubtest_allowlist.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/stubs/redis/@tests/stubtest_allowlist.txt b/stubs/redis/@tests/stubtest_allowlist.txt index ca2f6e62bad1..bde45aad6345 100644 --- a/stubs/redis/@tests/stubtest_allowlist.txt +++ b/stubs/redis/@tests/stubtest_allowlist.txt @@ -17,9 +17,6 @@ redis.sentinel.Sentinel.master_for redis.sentinel.Sentinel.slave_for # Metaclass differs: -redis.RedisCluster -redis.asyncio.Redis -redis.asyncio.RedisCluster # (runtime uses redis.compat.TypedDict) redis.asyncio.client.MonitorCommandInfo redis.asyncio.connection.ConnectKwargs