From 155f837d3b07486773d22b1008f00785f0b9a4eb Mon Sep 17 00:00:00 2001 From: Marc Edwards Date: Mon, 18 Nov 2024 23:06:20 +0000 Subject: [PATCH 1/2] Add types to `lru_cache` by implementing descriptor protocol WIP Tidy Fix ignores --- stdlib/@tests/test_cases/check_lru_cache.py | 262 ++++++++++++++++++++ stdlib/functools.pyi | 40 ++- 2 files changed, 293 insertions(+), 9 deletions(-) create mode 100644 stdlib/@tests/test_cases/check_lru_cache.py diff --git a/stdlib/@tests/test_cases/check_lru_cache.py b/stdlib/@tests/test_cases/check_lru_cache.py new file mode 100644 index 000000000000..91c8ea828d28 --- /dev/null +++ b/stdlib/@tests/test_cases/check_lru_cache.py @@ -0,0 +1,262 @@ +from __future__ import annotations +# pyright: reportUnnecessaryTypeIgnoreComment=true + +from functools import cache +from collections.abc import Callable +from typing import ( + final, + assert_type, + TYPE_CHECKING, + overload, + override, + Any, + Self, + TypeVar, + ParamSpec, + Generic, +) +from dataclasses import dataclass + +P = ParamSpec("P") +R = TypeVar("R") + +@cache +def cached_fn(arg: int, arg2: str) -> int: + return arg + +@dataclass +class MemberVarCached(Generic[P, R]): + member_callable: Callable[P, R] + +@cache +def cached_fn_takes_t(arg: MemberVarCached[..., Any], arg2: str) -> int: + return 1 + + +vc = MemberVarCached(cached_fn) +vc.member_callable(1, "") +assert_type(vc.member_callable(1, ""), int) + +vc_t = MemberVarCached(cached_fn_takes_t) +vc_t.member_callable(vc_t, "") + +if TYPE_CHECKING: + # type errors - correct + vc_t.member_callable("") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] + vc.member_callable(1, 1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + vc.member_callable(1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + vc.member_callable("1") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] + +class CFnCls: + @cache + def fn(self, arg: int) -> int: + print("method fn called") + return arg + + @cache + def fn_bad_self_name(this_t, arg: int) -> int: + print("method fn called") + return arg + + @classmethod + @cache + def cls_fn(cls, arg: int) -> int: + print("class fn called") + return arg + + @classmethod + @cache + def cls_fn_bad_name(my_class_type, arg: int) -> int: + print("class fn called") + return arg + + @classmethod + @cache + def cls_fn_positional_only(__cls, arg: int) -> int: + print("class fn called") + return arg + + @classmethod + @cache + def cls_fn_explicit_positional_only(_cls, arg: int, /) -> int: + print("class fn called") + return arg + + + @staticmethod + @cache + def st_fn(arg: int) -> int: + print("static fn called") + return arg + + @staticmethod + @cache + def st_fn_clst_arg(arg: type[CFnCls]) -> int: + print("static fn called") + return 1 + + @staticmethod + @cache + def st_fn_strt_arg(arg: type[str]) -> int: + print("static fn called") + return 1 + + @staticmethod + @cache + def st_fn_self_t_arg(arg: CFnCls) -> int: + print("static fn called") + return 1 + + @property + @cache + def prp_fn(self) -> int: + print("property fn called") + return 1 + +class CFnSubCls(CFnCls): + pass + +cfn_inst = CFnCls() +cfn_inst.fn(1) +cfn_inst.fn.__wrapped__(cfn_inst, 1) +cfn_inst.prp_fn +CFnCls.fn.__wrapped__(CFnCls(), 1) +CFnCls.st_fn(1) +CFnCls.st_fn.__wrapped__(1) +CFnCls.cls_fn(1) +CFnCls.cls_fn.__wrapped__(CFnCls, 1) +cfn_inst.fn(1) +CFnCls.st_fn(1) +CFnCls.cls_fn(1) +CFnCls.cls_fn_bad_name(1) +CFnCls.cls_fn_positional_only(1) +CFnCls().cls_fn_positional_only(1) +CFnCls.cls_fn_explicit_positional_only(1) +CFnCls().cls_fn_explicit_positional_only(1) + +# incorrect type error. If a static method +# takes the type of the enclosing class as +# the first argument there's a false positive. +CFnCls.st_fn_clst_arg(CFnCls) +CFnSubCls.st_fn_clst_arg(CFnSubCls) +# If a static method takes an instance of the +# enclosing class as its first argument, +# there's an error if accessed via an instance. +CFnSubCls.st_fn_self_t_arg(CFnSubCls()) +CFnSubCls().st_fn_self_t_arg(CFnSubCls()) +# but a different type is fine +CFnCls.st_fn_strt_arg(str) + +assert_type(cfn_inst.fn(1), int) +assert_type(CFnCls.st_fn(1), int) +assert_type(CFnCls.cls_fn(1), int) +assert_type(cfn_inst.prp_fn, int) +CFnCls().fn.cache_clear() +CFnCls.fn.cache_clear() +CFnCls.st_fn.cache_clear() +CFnCls.cls_fn.cache_clear() +if TYPE_CHECKING: + # type errors - correct + CFnCls().fn(1, 1) # type: ignore[call-arg,misc] # pyright: ignore[reportCallIssue] + CFnCls().fn.__wrapped__(CFnCls(), 1, 1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + CFnCls.fn(arg=1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + CFnCls.st_fn(1, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn(1, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls().cls_fn_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls().cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + +@cache +def fn(arg: int) -> int: + return arg + +@cache +def df_fn(arg: int, darg: str = "default"): + print("default fn called") + return darg +df_fn(1) + +fn(1) +assert_type(fn(1), int) +if TYPE_CHECKING: + # type error - correct + fn(1, 2) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] +fn.cache_clear() + +@overload +@cache +def fn_overload(arg: int) -> int: + ... +@overload +@cache +def fn_overload(arg: str) -> str: + return arg + +@cache +def fn_overload(arg: int | str) -> int | str: + return arg + +fn_overload(1) +fn_overload("1") +# behavior varies between type checkers. +assert_type(fn_overload(1), int) +assert_type(fn_overload("1"), str) +fn_overload.cache_clear() +if TYPE_CHECKING: + # type error - correct + fn_overload(frozenset({1,2})) # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + + +class Unhashable: + @override + def __eq__(self, value: object) -> bool: + return False + +@cache +def no_cache(arg: Unhashable, arg2: int) -> None: + pass + +if TYPE_CHECKING: + # This is not correctly rejected. + no_cache(Unhashable(), 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + +class MemberVarBound(Generic[P, R]): + @cache + def equals(self, other: Self) -> bool: + return False + member_fn: Callable[P, R] + +def set_member(lhs: MemberVarBound[..., Any], rhs: MemberVarBound[..., Any]) -> None: + lhs.member_fn = rhs.equals + lhs.member_fn(rhs) + + if TYPE_CHECKING: + lhs.member_fn() # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + +from abc import ABCMeta, abstractmethod + +class CustomABC(metaclass=ABCMeta): + + @abstractmethod + def foo(self, arg: int) -> int: + ... + +@final +class ABCConcrete(CustomABC): + @override + def foo(self, arg: int) -> int: + return 1 + + @cache + def abc_fn(self, arg: str) -> str: + return arg + + @classmethod + @cache + def abc_cm(cls, arg: str) -> str: + return arg + +ABCConcrete().abc_fn("1") +ABCConcrete.abc_fn(ABCConcrete(), "1") +ABCConcrete.abc_cm("") diff --git a/stdlib/functools.pyi b/stdlib/functools.pyi index 9957fa8f1634..78407f2689b3 100644 --- a/stdlib/functools.pyi +++ b/stdlib/functools.pyi @@ -3,7 +3,7 @@ import types from _typeshed import SupportsAllComparisons, SupportsItems from collections.abc import Callable, Hashable, Iterable, Sequence, Sized from typing import Any, Generic, Literal, NamedTuple, TypedDict, TypeVar, final, overload -from typing_extensions import ParamSpec, Self, TypeAlias +from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias if sys.version_info >= (3, 9): from types import GenericAlias @@ -28,8 +28,10 @@ if sys.version_info >= (3, 9): __all__ += ["cache"] _T = TypeVar("_T") +_R = TypeVar("_R") _T_co = TypeVar("_T_co", covariant=True) _S = TypeVar("_S") +_P = ParamSpec("_P") _PWrapped = ParamSpec("_PWrapped") _RWrapped = TypeVar("_RWrapped") _PWrapper = ParamSpec("_PWrapper") @@ -51,22 +53,42 @@ if sys.version_info >= (3, 9): maxsize: int typed: bool +_C = TypeVar("_C", bound=Callable[..., Any]) +_W = TypeVar("_W", bound=Callable[..., Any]) +_CacheFunction: TypeAlias = _lru_cache_wrapper[Callable[_P, _R], _W] +_CacheMethod: TypeAlias = _lru_cache_wrapper[Callable[Concatenate[_T, _P], _R], _W] +_CacheClassmethod: TypeAlias = _lru_cache_wrapper[Callable[Concatenate[type[_T], _P], _R], _W] + @final -class _lru_cache_wrapper(Generic[_T]): - __wrapped__: Callable[..., _T] - def __call__(self, *args: Hashable, **kwargs: Hashable) -> _T: ... +class _lru_cache_wrapper(Generic[_C, _W]): + __wrapped__: _W + def __call__(_self: _CacheFunction[_P, _R, _W], *args: _P.args, **kwargs: _P.kwargs) -> _R: ... def cache_info(self) -> _CacheInfo: ... def cache_clear(self) -> None: ... if sys.version_info >= (3, 9): def cache_parameters(self) -> _CacheParameters: ... - def __copy__(self) -> _lru_cache_wrapper[_T]: ... - def __deepcopy__(self, memo: Any, /) -> _lru_cache_wrapper[_T]: ... + def __copy__(self) -> Self: ... + def __deepcopy__(self, memo: Any, /) -> Self: ... + @overload + def __get__(self: _CacheMethod[_T, _P, _R, _W], instance: None, owner: type[_T]) -> _CacheMethod[_T, _P, _R, _W]: ... + @overload + def __get__(self: _CacheMethod[_T, _P, _R, _W], instance: _T, owner: type[_T] | None = ...) -> _CacheFunction[_P, _R, _W]: ... + @overload + def __get__(self: _CacheClassmethod[_T, _P, _R, _W], instance: _T | None, owner: type[_T]) -> _CacheFunction[_P, _R, _W]: ... + @overload + def __get__( + self: _CacheClassmethod[_T, _P, _R, _W], instance: _T, owner: type[_T] | None = None + ) -> _CacheFunction[_P, _R, _W]: ... + @overload + def __get__( + self: _CacheFunction[_P, _R, _W], instance: Any | None, owner: type[Any] | None = None + ) -> _CacheFunction[_P, _R, _W]: ... @overload -def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... +def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[_C], _lru_cache_wrapper[_C, _C]]: ... @overload -def lru_cache(maxsize: Callable[..., _T], typed: bool = False) -> _lru_cache_wrapper[_T]: ... +def lru_cache(maxsize: _C, typed: bool = False) -> _lru_cache_wrapper[_C, _C]: ... if sys.version_info >= (3, 12): WRAPPER_ASSIGNMENTS: tuple[ @@ -199,7 +221,7 @@ class cached_property(Generic[_T_co]): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... if sys.version_info >= (3, 9): - def cache(user_function: Callable[..., _T], /) -> _lru_cache_wrapper[_T]: ... + def cache(user_function: _C, /) -> _lru_cache_wrapper[_C, _C]: ... def _make_key( args: tuple[Hashable, ...], From d309de479ef93869edbfc8ed73a57b1481ed9d82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 23:15:23 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stdlib/@tests/test_cases/check_lru_cache.py | 79 ++++++++++++--------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/stdlib/@tests/test_cases/check_lru_cache.py b/stdlib/@tests/test_cases/check_lru_cache.py index 91c8ea828d28..4474e8d1a29c 100644 --- a/stdlib/@tests/test_cases/check_lru_cache.py +++ b/stdlib/@tests/test_cases/check_lru_cache.py @@ -1,33 +1,26 @@ from __future__ import annotations -# pyright: reportUnnecessaryTypeIgnoreComment=true -from functools import cache from collections.abc import Callable -from typing import ( - final, - assert_type, - TYPE_CHECKING, - overload, - override, - Any, - Self, - TypeVar, - ParamSpec, - Generic, -) from dataclasses import dataclass +# pyright: reportUnnecessaryTypeIgnoreComment=true +from functools import cache +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Self, TypeVar, assert_type, final, overload, override + P = ParamSpec("P") R = TypeVar("R") + @cache def cached_fn(arg: int, arg2: str) -> int: return arg + @dataclass class MemberVarCached(Generic[P, R]): member_callable: Callable[P, R] + @cache def cached_fn_takes_t(arg: MemberVarCached[..., Any], arg2: str) -> int: return 1 @@ -42,10 +35,11 @@ def cached_fn_takes_t(arg: MemberVarCached[..., Any], arg2: str) -> int: if TYPE_CHECKING: # type errors - correct - vc_t.member_callable("") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] - vc.member_callable(1, 1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + vc_t.member_callable("") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] + vc.member_callable(1, 1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] vc.member_callable(1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] - vc.member_callable("1") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] + vc.member_callable("1") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue] + class CFnCls: @cache @@ -82,7 +76,6 @@ def cls_fn_explicit_positional_only(_cls, arg: int, /) -> int: print("class fn called") return arg - @staticmethod @cache def st_fn(arg: int) -> int: @@ -113,9 +106,11 @@ def prp_fn(self) -> int: print("property fn called") return 1 + class CFnSubCls(CFnCls): pass + cfn_inst = CFnCls() cfn_inst.fn(1) cfn_inst.fn.__wrapped__(cfn_inst, 1) @@ -157,46 +152,52 @@ class CFnSubCls(CFnCls): CFnCls.cls_fn.cache_clear() if TYPE_CHECKING: # type errors - correct - CFnCls().fn(1, 1) # type: ignore[call-arg,misc] # pyright: ignore[reportCallIssue] - CFnCls().fn.__wrapped__(CFnCls(), 1, 1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] - CFnCls.fn(arg=1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] - CFnCls.st_fn(1, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] - CFnCls.cls_fn(1, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] - CFnCls.cls_fn_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] - CFnCls().cls_fn_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] - CFnCls.cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] - CFnCls().cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + CFnCls().fn(1, 1) # type: ignore[call-arg,misc] # pyright: ignore[reportCallIssue] + CFnCls().fn.__wrapped__(CFnCls(), 1, 1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + CFnCls.fn(arg=1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + CFnCls.st_fn(1, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn(1, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls().cls_fn_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + CFnCls.cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue] + CFnCls().cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue] + @cache def fn(arg: int) -> int: return arg + @cache def df_fn(arg: int, darg: str = "default"): print("default fn called") return darg + + df_fn(1) fn(1) assert_type(fn(1), int) if TYPE_CHECKING: # type error - correct - fn(1, 2) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + fn(1, 2) # type: ignore[call-arg] # pyright: ignore[reportCallIssue] fn.cache_clear() + @overload @cache -def fn_overload(arg: int) -> int: - ... +def fn_overload(arg: int) -> int: ... @overload @cache def fn_overload(arg: str) -> str: return arg + @cache def fn_overload(arg: int | str) -> int | str: return arg + fn_overload(1) fn_overload("1") # behavior varies between type checkers. @@ -205,7 +206,7 @@ def fn_overload(arg: int | str) -> int | str: fn_overload.cache_clear() if TYPE_CHECKING: # type error - correct - fn_overload(frozenset({1,2})) # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + fn_overload(frozenset({1, 2})) # type: ignore[call-overload] # pyright: ignore[reportArgumentType] class Unhashable: @@ -213,34 +214,41 @@ class Unhashable: def __eq__(self, value: object) -> bool: return False + @cache def no_cache(arg: Unhashable, arg2: int) -> None: pass + if TYPE_CHECKING: # This is not correctly rejected. - no_cache(Unhashable(), 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + no_cache(Unhashable(), 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + class MemberVarBound(Generic[P, R]): @cache def equals(self, other: Self) -> bool: return False + member_fn: Callable[P, R] + def set_member(lhs: MemberVarBound[..., Any], rhs: MemberVarBound[..., Any]) -> None: lhs.member_fn = rhs.equals lhs.member_fn(rhs) if TYPE_CHECKING: - lhs.member_fn() # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + lhs.member_fn() # type: ignore[call-arg] # pyright: ignore[reportCallIssue] + from abc import ABCMeta, abstractmethod + class CustomABC(metaclass=ABCMeta): @abstractmethod - def foo(self, arg: int) -> int: - ... + def foo(self, arg: int) -> int: ... + @final class ABCConcrete(CustomABC): @@ -257,6 +265,7 @@ def abc_fn(self, arg: str) -> str: def abc_cm(cls, arg: str) -> str: return arg + ABCConcrete().abc_fn("1") ABCConcrete.abc_fn(ABCConcrete(), "1") ABCConcrete.abc_cm("")