Skip to content

Refine how overload selection handles *args, **kwargs, and Any #5166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,18 +1195,27 @@ def plausible_overload_call_targets(self,
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded) -> List[CallableType]:
"""Returns all overload call targets that having matching argument counts."""
"""Returns all overload call targets that having matching argument counts.

If the given args contains a star-arg (*arg or **kwarg argument), this method
will ensure all star-arg overloads appear at the start of the list, instead
of their usual location."""
matches = [] # type: List[CallableType]
star_matches = [] # type: List[CallableType]
args_have_star = ARG_STAR in arg_kinds or ARG_STAR2 in arg_kinds
for typ in overload.items():
formal_to_actual = map_actuals_to_formals(arg_kinds, arg_names,
typ.arg_kinds, typ.arg_names,
lambda i: arg_types[i])

if self.check_argument_count(typ, arg_types, arg_kinds, arg_names,
formal_to_actual, None, None):
matches.append(typ)
if args_have_star and (typ.is_var_arg or typ.is_kw_arg):
star_matches.append(typ)
else:
matches.append(typ)

return matches
return star_matches + matches

def infer_overload_return_type(self,
plausible_targets: List[CallableType],
Expand Down Expand Up @@ -1270,15 +1279,20 @@ def infer_overload_return_type(self,
return None
elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
# An argument of type or containing the type 'Any' caused ambiguity.
# We infer a type of 'Any'
return self.check_call(callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
if all(is_subtype(ret_type, return_types[-1]) for ret_type in return_types[:-1]):
# The last match is a supertype of all the previous ones, so it's safe
# to return that inferred type.
return return_types[-1], inferred_types[-1]
else:
# We give up and return 'Any'.
return self.check_call(callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
else:
# Success! No ambiguity; return the first match.
return return_types[0], inferred_types[0]
Expand Down
137 changes: 128 additions & 9 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ def f(x: object) -> object: ...
def f(x): pass

a: Any
reveal_type(f(a)) # E: Revealed type is 'Any'
reveal_type(f(a)) # E: Revealed type is 'builtins.object'

[case testOverloadWithOverlappingItemsAndAnyArgument2]
from typing import overload, Any
Expand All @@ -1288,7 +1288,7 @@ def f(x: float) -> float: ...
def f(x): pass

a: Any
reveal_type(f(a)) # E: Revealed type is 'Any'
reveal_type(f(a)) # E: Revealed type is 'builtins.float'

[case testOverloadWithOverlappingItemsAndAnyArgument3]
from typing import overload, Any
Expand All @@ -1313,15 +1313,15 @@ def f(x): pass

a: Any
# Any causes ambiguity
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is now less clear, I would rather remove it or clarify.

reveal_type(f(a, 1, '')) # E: Revealed type is 'Any'
reveal_type(f(a, 1, '')) # E: Revealed type is 'builtins.object'
# Any causes no ambiguity
reveal_type(f(1, a, a)) # E: Revealed type is 'builtins.int'
reveal_type(f('', a, a)) # E: Revealed type is 'builtins.object'
# Like above, but use keyword arguments.
reveal_type(f(y=1, z='', x=a)) # E: Revealed type is 'Any'
reveal_type(f(y=1, z='', x=a)) # E: Revealed type is 'builtins.object'
reveal_type(f(y=a, z='', x=1)) # E: Revealed type is 'builtins.int'
reveal_type(f(z='', x=1, y=a)) # E: Revealed type is 'builtins.int'
reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'Any'
reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'builtins.object'

[case testOverloadWithOverlappingItemsAndAnyArgument5]
from typing import overload, Any, Union
Expand All @@ -1333,7 +1333,7 @@ def f(x: Union[int, float]) -> float: ...
def f(x): pass

a: Any
reveal_type(f(a)) # E: Revealed type is 'Any'
reveal_type(f(a)) # E: Revealed type is 'builtins.float'

[case testOverloadWithOverlappingItemsAndAnyArgument6]
from typing import overload, Any
Expand All @@ -1343,7 +1343,7 @@ def f(x: int, y: int) -> int: ...
@overload
def f(x: float, y: int, z: str) -> float: ...
@overload
def f(x: object, y: int, z: str, a: None) -> object: ...
def f(x: object, y: int, z: str, a: None) -> str: ...
def f(x): pass

a: Any
Expand All @@ -1352,7 +1352,7 @@ reveal_type(f(*a)) # E: Revealed type is 'Any'
reveal_type(f(a, *a)) # E: Revealed type is 'Any'
reveal_type(f(1, *a)) # E: Revealed type is 'Any'
reveal_type(f(1.1, *a)) # E: Revealed type is 'Any'
reveal_type(f('', *a)) # E: Revealed type is 'builtins.object'
reveal_type(f('', *a)) # E: Revealed type is 'builtins.str'

[case testOverloadWithOverlappingItemsAndAnyArgument7]
from typing import overload, Any
Expand All @@ -1365,7 +1365,7 @@ def f(x): pass

a: Any
# TODO: We could infer 'int' here
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this TODO looks interesting. Would it be hard to fix?

reveal_type(f(1, *a)) # E: Revealed type is 'Any'
reveal_type(f(1, *a)) # E: Revealed type is 'builtins.object'

[case testOverloadWithOverlappingItemsAndAnyArgument8]
from typing import overload, Any
Expand All @@ -1381,6 +1381,26 @@ a: Any
reveal_type(f(a, 1, 1)) # E: Revealed type is 'builtins.str'
reveal_type(f(1, *a)) # E: Revealed type is 'builtins.str'

[case testOverloadWithOverlappingItemsAndAnyArgument9]
from typing import overload, Any, List

@overload
def f(x: List[int]) -> List[int]: ...
@overload
def f(x: List[Any]) -> List[Any]: ...
def f(x): pass

a: Any
b: List[Any]
c: List[str]
d: List[int]
reveal_type(f(a)) # E: Revealed type is 'builtins.list[Any]'
reveal_type(f(b)) # E: Revealed type is 'builtins.list[Any]'
reveal_type(f(c)) # E: Revealed type is 'builtins.list[Any]'
reveal_type(f(d)) # E: Revealed type is 'builtins.list[builtins.int]'

[builtins fixtures/list.pyi]

[case testOverloadOnOverloadWithType]
from typing import Any, Type, TypeVar, overload
from mod import MyInt
Expand Down Expand Up @@ -1723,6 +1743,105 @@ def foo2(**kwargs: int) -> str: ...
def foo2(*args: int) -> int: ... # E: Overloaded function signature 2 will never be matched: function 1's parameter type(s) are the same or broader
[builtins fixtures/dict.pyi]

[case testOverloadVarargInputAndVarargDefinition]
from typing import overload, List

class A: ...
class B: ...
class C: ...

@overload
def foo(x: int) -> A: ...
@overload
def foo(x: int, y: int) -> B: ...
@overload
def foo(x: int, y: int, z: int, *args: int) -> C: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a test with overload like this:

@overload
def f(x: int) -> Tuple[int]: ...
@overload
def f(x: int, y: int) -> Tuple[int, int]: ...
@overload
def f(*xs: int) -> Tuple[int, ...]: ...

And check how it various calls with and without star.

def foo(*args): pass

reveal_type(foo(1)) # E: Revealed type is '__main__.A'
reveal_type(foo(1, 2)) # E: Revealed type is '__main__.B'
reveal_type(foo(1, 2, 3)) # E: Revealed type is '__main__.C'

reveal_type(foo(*[1])) # E: Revealed type is '__main__.C'
reveal_type(foo(*[1, 2])) # E: Revealed type is '__main__.C'
reveal_type(foo(*[1, 2, 3])) # E: Revealed type is '__main__.C'

x: List[int]
reveal_type(foo(*x)) # E: Revealed type is '__main__.C'

y: List[str]
foo(*y) # E: No overload variant of "foo" matches argument type "List[str]"
[builtins fixtures/list.pyi]

[case testOverloadMultipleVarargDefinition]
from typing import overload, List, Any

class A: ...
class B: ...
class C: ...
class D: ...

@overload
def foo(x: int) -> A: ...
@overload
def foo(x: int, y: int) -> B: ...
@overload
def foo(x: int, y: int, z: int, *args: int) -> C: ...
@overload
def foo(*x: str) -> D: ...
def foo(*args): pass

reveal_type(foo(*[1, 2])) # E: Revealed type is '__main__.C'
reveal_type(foo(*["a", "b"])) # E: Revealed type is '__main__.D'

x: List[Any]
reveal_type(foo(*x)) # E: Revealed type is 'Any'
[builtins fixtures/list.pyi]

[case testOverloadMultipleVarargDefinitionComplex]
from typing import TypeVar, overload, Any, Callable

T1 = TypeVar('T1')
T2 = TypeVar('T2')
T3 = TypeVar('T3')

@overload
def chain_call(input_value: T1,
f1: Callable[[T1], T2]) -> T2: ...
@overload
def chain_call(input_value: T1,
f1: Callable[[T1], T2],
f2: Callable[[T2], T3]) -> T3: ...
@overload
def chain_call(input_value: T1,
*f_rest: Callable[[T1], T1]) -> T1: ...
@overload
def chain_call(input_value: T1,
f1: Callable[[T1], T2],
f2: Callable[[T2], T3],
f3: Callable[[T3], Any],
*f_rest: Callable[[Any], Any]) -> Any: ...
def chain_call(input_value, *f_rest):
for function in f_rest:
input_value = function(input_value)
return input_value


class A: ...
class B: ...
class C: ...
class D: ...

def f(x: A) -> A: ...
def f1(x: A) -> B: ...
def f2(x: B) -> C: ...
def f3(x: C) -> D: ...

reveal_type(chain_call(A(), f1, f2)) # E: Revealed type is '__main__.C*'
reveal_type(chain_call(A(), f1, f2, f3)) # E: Revealed type is 'Any'
reveal_type(chain_call(A(), f, f, f, f)) # E: Revealed type is '__main__.A'
[builtins fixtures/list.pyi]

[case testOverloadWithPartiallyOverlappingUnions]
from typing import overload, Union

Expand Down