Skip to content

Add type signatures to lru_cache using descriptor protocol (not ready to merge) #13043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions stdlib/@tests/test_cases/check_lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

# pyright: reportUnnecessaryTypeIgnoreComment=true
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Self, TypeVar, assert_type, final, overload, override

P = ParamSpec("P")
R = TypeVar("R")


@cache
def cached_fn(arg: int, arg2: str) -> int:
return arg


@dataclass
class MemberVarCached(Generic[P, R]):
member_callable: Callable[P, R]


@cache
def cached_fn_takes_t(arg: MemberVarCached[..., Any], arg2: str) -> int:
return 1


vc = MemberVarCached(cached_fn)
vc.member_callable(1, "")
assert_type(vc.member_callable(1, ""), int)

vc_t = MemberVarCached(cached_fn_takes_t)
vc_t.member_callable(vc_t, "")

if TYPE_CHECKING:
# type errors - correct
vc_t.member_callable("") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue]

Check failure on line 38 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)
vc.member_callable(1, 1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

Check failure on line 39 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportArgumentType" (reportUnnecessaryTypeIgnoreComment)
vc.member_callable(1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue]

Check failure on line 40 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)
vc.member_callable("1") # type: ignore[call-arg,arg-type] # pyright: ignore[reportCallIssue]

Check failure on line 41 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)


class CFnCls:
@cache
def fn(self, arg: int) -> int:
print("method fn called")
return arg

@cache
def fn_bad_self_name(this_t, arg: int) -> int:
print("method fn called")
return arg

@classmethod
@cache
def cls_fn(cls, arg: int) -> int:
print("class fn called")
return arg

@classmethod
@cache
def cls_fn_bad_name(my_class_type, arg: int) -> int:
print("class fn called")
return arg

@classmethod
@cache
def cls_fn_positional_only(__cls, arg: int) -> int:
print("class fn called")
return arg

@classmethod
@cache
def cls_fn_explicit_positional_only(_cls, arg: int, /) -> int:
print("class fn called")
return arg

@staticmethod
@cache
def st_fn(arg: int) -> int:
print("static fn called")
return arg

@staticmethod
@cache
def st_fn_clst_arg(arg: type[CFnCls]) -> int:
print("static fn called")
return 1

@staticmethod
@cache
def st_fn_strt_arg(arg: type[str]) -> int:
print("static fn called")
return 1

@staticmethod
@cache
def st_fn_self_t_arg(arg: CFnCls) -> int:
print("static fn called")
return 1

@property
@cache
def prp_fn(self) -> int:
print("property fn called")
return 1


class CFnSubCls(CFnCls):
pass


cfn_inst = CFnCls()
cfn_inst.fn(1)
cfn_inst.fn.__wrapped__(cfn_inst, 1)
cfn_inst.prp_fn
CFnCls.fn.__wrapped__(CFnCls(), 1)
CFnCls.st_fn(1)
CFnCls.st_fn.__wrapped__(1)
CFnCls.cls_fn(1)
CFnCls.cls_fn.__wrapped__(CFnCls, 1)
cfn_inst.fn(1)
CFnCls.st_fn(1)
CFnCls.cls_fn(1)
CFnCls.cls_fn_bad_name(1)
CFnCls.cls_fn_positional_only(1)
CFnCls().cls_fn_positional_only(1)
CFnCls.cls_fn_explicit_positional_only(1)
CFnCls().cls_fn_explicit_positional_only(1)

# incorrect type error. If a static method
# takes the type of the enclosing class as
# the first argument there's a false positive.
CFnCls.st_fn_clst_arg(CFnCls)

Check failure on line 135 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Expected 0 positional arguments (reportCallIssue)
CFnSubCls.st_fn_clst_arg(CFnSubCls)

Check failure on line 136 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Expected 0 positional arguments (reportCallIssue)
# If a static method takes an instance of the
# enclosing class as its first argument,
# there's an error if accessed via an instance.
CFnSubCls.st_fn_self_t_arg(CFnSubCls())
CFnSubCls().st_fn_self_t_arg(CFnSubCls())

Check failure on line 141 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Expected 0 positional arguments (reportCallIssue)
Comment on lines +132 to +141
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably isn't fixable but may be acceptable. Error in both mypy & pyright.

# but a different type is fine
CFnCls.st_fn_strt_arg(str)

assert_type(cfn_inst.fn(1), int)
assert_type(CFnCls.st_fn(1), int)
assert_type(CFnCls.cls_fn(1), int)
assert_type(cfn_inst.prp_fn, int)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy is erroring on using property with cache. pyright is fine. Not sure why. There's also cached_property but this is a valid use case too.

CFnCls().fn.cache_clear()
CFnCls.fn.cache_clear()
CFnCls.st_fn.cache_clear()
CFnCls.cls_fn.cache_clear()
if TYPE_CHECKING:
# type errors - correct
CFnCls().fn(1, 1) # type: ignore[call-arg,misc] # pyright: ignore[reportCallIssue]

Check failure on line 155 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)
CFnCls().fn.__wrapped__(CFnCls(), 1, 1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue]

Check failure on line 156 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)
CFnCls.fn(arg=1) # type: ignore[call-arg] # pyright: ignore[reportCallIssue]

Check failure on line 157 in stdlib/@tests/test_cases/check_lru_cache.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.12)

Unnecessary "# pyright: ignore" rule: "reportCallIssue" (reportUnnecessaryTypeIgnoreComment)
CFnCls.st_fn(1, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue]
CFnCls.cls_fn(1, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue]
CFnCls.cls_fn_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue]
CFnCls().cls_fn_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue]
CFnCls.cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[arg-type,call-arg] # pyright: ignore[reportCallIssue]
CFnCls().cls_fn_explicit_positional_only(CFnCls, 1) # type: ignore[call-arg,arg-type,misc] # pyright: ignore[reportCallIssue]


@cache
def fn(arg: int) -> int:
return arg


@cache
def df_fn(arg: int, darg: str = "default"):
print("default fn called")
return darg


df_fn(1)

fn(1)
assert_type(fn(1), int)
if TYPE_CHECKING:
# type error - correct
fn(1, 2) # type: ignore[call-arg] # pyright: ignore[reportCallIssue]
fn.cache_clear()


@overload
@cache
def fn_overload(arg: int) -> int: ...
@overload
@cache
def fn_overload(arg: str) -> str:
return arg


@cache
def fn_overload(arg: int | str) -> int | str:
return arg


fn_overload(1)
fn_overload("1")
# behavior varies between type checkers.
assert_type(fn_overload(1), int)
assert_type(fn_overload("1"), str)
Comment on lines +204 to +205
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overloads don't work well with this change. In pyright types are resolved as int | str but cache_clear is correctly accessible.

Mypy correctly resolves return type but cache_clear isn't accessible.

fn_overload.cache_clear()
if TYPE_CHECKING:
# type error - correct
fn_overload(frozenset({1, 2})) # type: ignore[call-overload] # pyright: ignore[reportArgumentType]


class Unhashable:
@override
def __eq__(self, value: object) -> bool:
return False


@cache
def no_cache(arg: Unhashable, arg2: int) -> None:
pass


if TYPE_CHECKING:
# This is not correctly rejected.
no_cache(Unhashable(), 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


class MemberVarBound(Generic[P, R]):
@cache
def equals(self, other: Self) -> bool:
return False

member_fn: Callable[P, R]


def set_member(lhs: MemberVarBound[..., Any], rhs: MemberVarBound[..., Any]) -> None:
lhs.member_fn = rhs.equals
lhs.member_fn(rhs)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy doesn't resolve the callable type signature but pyright does, so this raises an error in mypy. Not sure why.


if TYPE_CHECKING:
lhs.member_fn() # type: ignore[call-arg] # pyright: ignore[reportCallIssue]


from abc import ABCMeta, abstractmethod


class CustomABC(metaclass=ABCMeta):

@abstractmethod
def foo(self, arg: int) -> int: ...


@final
class ABCConcrete(CustomABC):
@override
def foo(self, arg: int) -> int:
return 1

@cache
def abc_fn(self, arg: str) -> str:
return arg

@classmethod
@cache
def abc_cm(cls, arg: str) -> str:
return arg


ABCConcrete().abc_fn("1")
ABCConcrete.abc_fn(ABCConcrete(), "1")
ABCConcrete.abc_cm("")
40 changes: 31 additions & 9 deletions stdlib/functools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import types
from _typeshed import SupportsAllComparisons, SupportsItems
from collections.abc import Callable, Hashable, Iterable, Sequence, Sized
from typing import Any, Generic, Literal, NamedTuple, TypedDict, TypeVar, final, overload
from typing_extensions import ParamSpec, Self, TypeAlias
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias

if sys.version_info >= (3, 9):
from types import GenericAlias
Expand All @@ -28,8 +28,10 @@ if sys.version_info >= (3, 9):
__all__ += ["cache"]

_T = TypeVar("_T")
_R = TypeVar("_R")
_T_co = TypeVar("_T_co", covariant=True)
_S = TypeVar("_S")
_P = ParamSpec("_P")
_PWrapped = ParamSpec("_PWrapped")
_RWrapped = TypeVar("_RWrapped")
_PWrapper = ParamSpec("_PWrapper")
Expand All @@ -51,22 +53,42 @@ if sys.version_info >= (3, 9):
maxsize: int
typed: bool

_C = TypeVar("_C", bound=Callable[..., Any])
_W = TypeVar("_W", bound=Callable[..., Any])
_CacheFunction: TypeAlias = _lru_cache_wrapper[Callable[_P, _R], _W]
_CacheMethod: TypeAlias = _lru_cache_wrapper[Callable[Concatenate[_T, _P], _R], _W]
_CacheClassmethod: TypeAlias = _lru_cache_wrapper[Callable[Concatenate[type[_T], _P], _R], _W]

@final
class _lru_cache_wrapper(Generic[_T]):
__wrapped__: Callable[..., _T]
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _T: ...
class _lru_cache_wrapper(Generic[_C, _W]):
__wrapped__: _W
def __call__(_self: _CacheFunction[_P, _R, _W], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
def cache_info(self) -> _CacheInfo: ...
def cache_clear(self) -> None: ...
if sys.version_info >= (3, 9):
def cache_parameters(self) -> _CacheParameters: ...

def __copy__(self) -> _lru_cache_wrapper[_T]: ...
def __deepcopy__(self, memo: Any, /) -> _lru_cache_wrapper[_T]: ...
def __copy__(self) -> Self: ...
def __deepcopy__(self, memo: Any, /) -> Self: ...
@overload
def __get__(self: _CacheMethod[_T, _P, _R, _W], instance: None, owner: type[_T]) -> _CacheMethod[_T, _P, _R, _W]: ...
@overload
def __get__(self: _CacheMethod[_T, _P, _R, _W], instance: _T, owner: type[_T] | None = ...) -> _CacheFunction[_P, _R, _W]: ...
@overload
def __get__(self: _CacheClassmethod[_T, _P, _R, _W], instance: _T | None, owner: type[_T]) -> _CacheFunction[_P, _R, _W]: ...
@overload
def __get__(
self: _CacheClassmethod[_T, _P, _R, _W], instance: _T, owner: type[_T] | None = None
) -> _CacheFunction[_P, _R, _W]: ...
@overload
def __get__(
self: _CacheFunction[_P, _R, _W], instance: Any | None, owner: type[Any] | None = None
) -> _CacheFunction[_P, _R, _W]: ...

@overload
def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ...
def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[_C], _lru_cache_wrapper[_C, _C]]: ...
@overload
def lru_cache(maxsize: Callable[..., _T], typed: bool = False) -> _lru_cache_wrapper[_T]: ...
def lru_cache(maxsize: _C, typed: bool = False) -> _lru_cache_wrapper[_C, _C]: ...

if sys.version_info >= (3, 12):
WRAPPER_ASSIGNMENTS: tuple[
Expand Down Expand Up @@ -199,7 +221,7 @@ class cached_property(Generic[_T_co]):
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...

if sys.version_info >= (3, 9):
def cache(user_function: Callable[..., _T], /) -> _lru_cache_wrapper[_T]: ...
def cache(user_function: _C, /) -> _lru_cache_wrapper[_C, _C]: ...

def _make_key(
args: tuple[Hashable, ...],
Expand Down
Loading