diff --git a/mypy/checker.py b/mypy/checker.py index 3734f3170790..dcbc7a845b15 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -701,50 +701,57 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None: """Get type as seen by an overload item caller.""" inner_type = get_proper_type(inner_type) - outer_type: CallableType | None = None - if inner_type is not None and not isinstance(inner_type, AnyType): - if isinstance(inner_type, TypeVarLikeType): - inner_type = get_proper_type(inner_type.upper_bound) - if isinstance(inner_type, TypeType): - inner_type = get_proper_type( - self.expr_checker.analyze_type_type_callee(inner_type.item, ctx) - ) + outer_type: FunctionLike | None = None + if inner_type is None or isinstance(inner_type, AnyType): + return None + if isinstance(inner_type, TypeVarLikeType): + inner_type = get_proper_type(inner_type.upper_bound) + if isinstance(inner_type, TypeType): + inner_type = get_proper_type( + self.expr_checker.analyze_type_type_callee(inner_type.item, ctx) + ) - if isinstance(inner_type, CallableType): - outer_type = inner_type - elif isinstance(inner_type, Instance): - inner_call = get_proper_type( - analyze_member_access( - name="__call__", - typ=inner_type, - context=ctx, - is_lvalue=False, - is_super=False, - is_operator=True, - msg=self.msg, - original_type=inner_type, - chk=self, - ) + if isinstance(inner_type, FunctionLike): + outer_type = inner_type + elif isinstance(inner_type, Instance): + inner_call = get_proper_type( + analyze_member_access( + name="__call__", + typ=inner_type, + context=ctx, + is_lvalue=False, + is_super=False, + is_operator=True, + msg=self.msg, + original_type=inner_type, + chk=self, ) - if isinstance(inner_call, CallableType): - outer_type = inner_call - elif isinstance(inner_type, UnionType): - union_type = make_simplified_union(inner_type.items) - if isinstance(union_type, UnionType): - items = [] - for item in union_type.items: - callable_item = self.extract_callable_type(item, ctx) - if callable_item is None: - break - items.append(callable_item) - else: - joined_type = get_proper_type(join.join_type_list(items)) - if isinstance(joined_type, CallableType): - outer_type = joined_type + ) + if isinstance(inner_call, FunctionLike): + outer_type = inner_call + elif isinstance(inner_type, UnionType): + union_type = make_simplified_union(inner_type.items) + if isinstance(union_type, UnionType): + items = [] + for item in union_type.items: + callable_item = self.extract_callable_type(item, ctx) + if callable_item is None: + break + items.append(callable_item) else: - return self.extract_callable_type(union_type, ctx) - if outer_type is None: - self.msg.not_callable(inner_type, ctx) + joined_type = get_proper_type(join.join_type_list(items)) + if isinstance(joined_type, FunctionLike): + outer_type = joined_type + else: + return self.extract_callable_type(union_type, ctx) + + if outer_type is None: + self.msg.not_callable(inner_type, ctx) + return None + if isinstance(outer_type, Overloaded): + return None + + assert isinstance(outer_type, CallableType) return outer_type def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index 22159580163d..5bdc3ce7a352 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -640,3 +640,20 @@ hp = partial(h, 1) reveal_type(hp(1)) # N: Revealed type is "builtins.int" hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int" [builtins fixtures/tuple.pyi] + +[case testFunctoolsPartialOverloadedCallableProtocol] +from functools import partial +from typing import Callable, Protocol, overload + +class P(Protocol): + @overload + def __call__(self, x: int) -> int: ... + @overload + def __call__(self, x: str) -> str: ... + +def f(x: P): + reveal_type(partial(x, 1)()) # N: Revealed type is "builtins.int" + + # TODO: but this is incorrect, predating the functools.partial plugin + reveal_type(partial(x, "a")()) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi]