Skip to content

Improve support for functools.partial of overloaded callable protocol #18639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 48 additions & 41 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,50 +701,57 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
"""Get type as seen by an overload item caller."""
inner_type = get_proper_type(inner_type)
outer_type: CallableType | None = None
if inner_type is not None and not isinstance(inner_type, AnyType):
if isinstance(inner_type, TypeVarLikeType):
inner_type = get_proper_type(inner_type.upper_bound)
if isinstance(inner_type, TypeType):
inner_type = get_proper_type(
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
)
outer_type: FunctionLike | None = None
if inner_type is None or isinstance(inner_type, AnyType):
return None
if isinstance(inner_type, TypeVarLikeType):
inner_type = get_proper_type(inner_type.upper_bound)
if isinstance(inner_type, TypeType):
inner_type = get_proper_type(
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
)

if isinstance(inner_type, CallableType):
outer_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=ctx,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
if isinstance(inner_type, FunctionLike):
outer_type = inner_type
elif isinstance(inner_type, Instance):
inner_call = get_proper_type(
analyze_member_access(
name="__call__",
typ=inner_type,
context=ctx,
is_lvalue=False,
is_super=False,
is_operator=True,
msg=self.msg,
original_type=inner_type,
chk=self,
)
if isinstance(inner_call, CallableType):
outer_type = inner_call
elif isinstance(inner_type, UnionType):
union_type = make_simplified_union(inner_type.items)
if isinstance(union_type, UnionType):
items = []
for item in union_type.items:
callable_item = self.extract_callable_type(item, ctx)
if callable_item is None:
break
items.append(callable_item)
else:
joined_type = get_proper_type(join.join_type_list(items))
if isinstance(joined_type, CallableType):
outer_type = joined_type
)
if isinstance(inner_call, FunctionLike):
outer_type = inner_call
elif isinstance(inner_type, UnionType):
union_type = make_simplified_union(inner_type.items)
if isinstance(union_type, UnionType):
items = []
for item in union_type.items:
callable_item = self.extract_callable_type(item, ctx)
if callable_item is None:
break
items.append(callable_item)
else:
return self.extract_callable_type(union_type, ctx)
if outer_type is None:
self.msg.not_callable(inner_type, ctx)
joined_type = get_proper_type(join.join_type_list(items))
if isinstance(joined_type, FunctionLike):
outer_type = joined_type
else:
return self.extract_callable_type(union_type, ctx)

if outer_type is None:
self.msg.not_callable(inner_type, ctx)
return None
if isinstance(outer_type, Overloaded):
return None

assert isinstance(outer_type, CallableType)
return outer_type

def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,20 @@ hp = partial(h, 1)
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialOverloadedCallableProtocol]
from functools import partial
from typing import Callable, Protocol, overload

class P(Protocol):
@overload
def __call__(self, x: int) -> int: ...
@overload
def __call__(self, x: str) -> str: ...

def f(x: P):
reveal_type(partial(x, 1)()) # N: Revealed type is "builtins.int"

# TODO: but this is incorrect, predating the functools.partial plugin
reveal_type(partial(x, "a")()) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]