|
12 | 12 | import mypy.errorcodes as codes
|
13 | 13 | from mypy import applytype, erasetype, join, message_registry, nodes, operators, types
|
14 | 14 | from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
|
15 |
| -from mypy.checkmember import analyze_member_access, type_object_type |
| 15 | +from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type |
16 | 16 | from mypy.checkstrformat import StringFormatterChecker
|
17 | 17 | from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
|
18 | 18 | from mypy.errors import ErrorWatcher, report_internal_error
|
|
98 | 98 | )
|
99 | 99 | from mypy.semanal_enum import ENUM_BASES
|
100 | 100 | from mypy.state import state
|
101 |
| -from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members |
| 101 | +from mypy.subtypes import ( |
| 102 | + find_member, |
| 103 | + is_equivalent, |
| 104 | + is_same_type, |
| 105 | + is_subtype, |
| 106 | + non_method_protocol_members, |
| 107 | +) |
102 | 108 | from mypy.traverser import has_await_expression
|
| 109 | +from mypy.type_visitor import TypeTranslator |
103 | 110 | from mypy.typeanal import (
|
104 | 111 | check_for_explicit_any,
|
105 | 112 | has_any_from_unimported_type,
|
|
114 | 121 | false_only,
|
115 | 122 | fixup_partial_type,
|
116 | 123 | function_type,
|
| 124 | + get_type_vars, |
117 | 125 | is_literal_type_like,
|
118 | 126 | make_simplified_union,
|
119 | 127 | simple_literal_type,
|
|
146 | 154 | TypedDictType,
|
147 | 155 | TypeOfAny,
|
148 | 156 | TypeType,
|
| 157 | + TypeVarLikeType, |
149 | 158 | TypeVarTupleType,
|
150 | 159 | TypeVarType,
|
151 | 160 | UninhabitedType,
|
@@ -300,6 +309,7 @@ def __init__(
|
300 | 309 | # on whether current expression is a callee, to give better error messages
|
301 | 310 | # related to type context.
|
302 | 311 | self.is_callee = False
|
| 312 | + type_state.infer_polymorphic = self.chk.options.new_type_inference |
303 | 313 |
|
304 | 314 | def reset(self) -> None:
|
305 | 315 | self.resolved_type = {}
|
@@ -1791,6 +1801,51 @@ def infer_function_type_arguments(
|
1791 | 1801 | inferred_args[0] = self.named_type("builtins.str")
|
1792 | 1802 | elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
|
1793 | 1803 | self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)
|
| 1804 | + |
| 1805 | + if self.chk.options.new_type_inference and any( |
| 1806 | + a is None |
| 1807 | + or isinstance(get_proper_type(a), UninhabitedType) |
| 1808 | + or set(get_type_vars(a)) & set(callee_type.variables) |
| 1809 | + for a in inferred_args |
| 1810 | + ): |
| 1811 | + # If the regular two-phase inference didn't work, try inferring type |
| 1812 | + # variables while allowing for polymorphic solutions, i.e. for solutions |
| 1813 | + # potentially involving free variables. |
| 1814 | + # TODO: support the similar inference for return type context. |
| 1815 | + poly_inferred_args = infer_function_type_arguments( |
| 1816 | + callee_type, |
| 1817 | + arg_types, |
| 1818 | + arg_kinds, |
| 1819 | + formal_to_actual, |
| 1820 | + context=self.argument_infer_context(), |
| 1821 | + strict=self.chk.in_checked_function(), |
| 1822 | + allow_polymorphic=True, |
| 1823 | + ) |
| 1824 | + for i, pa in enumerate(get_proper_types(poly_inferred_args)): |
| 1825 | + if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): |
| 1826 | + # Indicate that free variables should not be applied in the call below. |
| 1827 | + poly_inferred_args[i] = None |
| 1828 | + poly_callee_type = self.apply_generic_arguments( |
| 1829 | + callee_type, poly_inferred_args, context |
| 1830 | + ) |
| 1831 | + yes_vars = poly_callee_type.variables |
| 1832 | + no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} |
| 1833 | + if not set(get_type_vars(poly_callee_type)) & no_vars: |
| 1834 | + # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can |
| 1835 | + # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. |
| 1836 | + applied = apply_poly(poly_callee_type, yes_vars) |
| 1837 | + if applied is not None and poly_inferred_args != [UninhabitedType()] * len( |
| 1838 | + poly_inferred_args |
| 1839 | + ): |
| 1840 | + freeze_all_type_vars(applied) |
| 1841 | + return applied |
| 1842 | + # If it didn't work, erase free variables as <nothing>, to avoid confusing errors. |
| 1843 | + inferred_args = [ |
| 1844 | + expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) |
| 1845 | + if a is not None |
| 1846 | + else None |
| 1847 | + for a in inferred_args |
| 1848 | + ] |
1794 | 1849 | else:
|
1795 | 1850 | # In dynamically typed functions use implicit 'Any' types for
|
1796 | 1851 | # type variables.
|
@@ -5393,6 +5448,92 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
|
5393 | 5448 | return c.copy_modified(ret_type=new_ret_type)
|
5394 | 5449 |
|
5395 | 5450 |
|
| 5451 | +def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]: |
| 5452 | + """Make free type variables generic in the type if possible. |
| 5453 | +
|
| 5454 | + This will translate the type `tp` while trying to create valid bindings for |
| 5455 | + type variables `poly_tvars` while traversing the type. This follows the same rules |
| 5456 | + as we do during semantic analysis phase, examples: |
| 5457 | + * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T |
| 5458 | + * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T) |
| 5459 | + * List[T] -> None (not possible) |
| 5460 | + """ |
| 5461 | + try: |
| 5462 | + return tp.copy_modified( |
| 5463 | + arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types], |
| 5464 | + ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)), |
| 5465 | + variables=[], |
| 5466 | + ) |
| 5467 | + except PolyTranslationError: |
| 5468 | + return None |
| 5469 | + |
| 5470 | + |
| 5471 | +class PolyTranslationError(Exception): |
| 5472 | + pass |
| 5473 | + |
| 5474 | + |
| 5475 | +class PolyTranslator(TypeTranslator): |
| 5476 | + """Make free type variables generic in the type if possible. |
| 5477 | +
|
| 5478 | + See docstring for apply_poly() for details. |
| 5479 | + """ |
| 5480 | + |
| 5481 | + def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: |
| 5482 | + self.poly_tvars = set(poly_tvars) |
| 5483 | + # This is a simplified version of TypeVarScope used during semantic analysis. |
| 5484 | + self.bound_tvars: set[TypeVarLikeType] = set() |
| 5485 | + self.seen_aliases: set[TypeInfo] = set() |
| 5486 | + |
| 5487 | + def visit_callable_type(self, t: CallableType) -> Type: |
| 5488 | + found_vars = set() |
| 5489 | + for arg in t.arg_types: |
| 5490 | + found_vars |= set(get_type_vars(arg)) & self.poly_tvars |
| 5491 | + |
| 5492 | + found_vars -= self.bound_tvars |
| 5493 | + self.bound_tvars |= found_vars |
| 5494 | + result = super().visit_callable_type(t) |
| 5495 | + self.bound_tvars -= found_vars |
| 5496 | + |
| 5497 | + assert isinstance(result, ProperType) and isinstance(result, CallableType) |
| 5498 | + result.variables = list(result.variables) + list(found_vars) |
| 5499 | + return result |
| 5500 | + |
| 5501 | + def visit_type_var(self, t: TypeVarType) -> Type: |
| 5502 | + if t in self.poly_tvars and t not in self.bound_tvars: |
| 5503 | + raise PolyTranslationError() |
| 5504 | + return super().visit_type_var(t) |
| 5505 | + |
| 5506 | + def visit_param_spec(self, t: ParamSpecType) -> Type: |
| 5507 | + # TODO: Support polymorphic apply for ParamSpec. |
| 5508 | + raise PolyTranslationError() |
| 5509 | + |
| 5510 | + def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: |
| 5511 | + # TODO: Support polymorphic apply for TypeVarTuple. |
| 5512 | + raise PolyTranslationError() |
| 5513 | + |
| 5514 | + def visit_type_alias_type(self, t: TypeAliasType) -> Type: |
| 5515 | + if not t.args: |
| 5516 | + return t.copy_modified() |
| 5517 | + if not t.is_recursive: |
| 5518 | + return get_proper_type(t).accept(self) |
| 5519 | + # We can't handle polymorphic application for recursive generic aliases |
| 5520 | + # without risking an infinite recursion, just give up for now. |
| 5521 | + raise PolyTranslationError() |
| 5522 | + |
| 5523 | + def visit_instance(self, t: Instance) -> Type: |
| 5524 | + # There is the same problem with callback protocols as with aliases |
| 5525 | + # (callback protocols are essentially more flexible aliases to callables). |
| 5526 | + # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T]. |
| 5527 | + if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]: |
| 5528 | + if t.type in self.seen_aliases: |
| 5529 | + raise PolyTranslationError() |
| 5530 | + self.seen_aliases.add(t.type) |
| 5531 | + call = find_member("__call__", t, t, is_operator=True) |
| 5532 | + assert call is not None |
| 5533 | + return call.accept(self) |
| 5534 | + return super().visit_instance(t) |
| 5535 | + |
| 5536 | + |
5396 | 5537 | class ArgInferSecondPassQuery(types.BoolTypeQuery):
|
5397 | 5538 | """Query whether an argument type should be inferred in the second pass.
|
5398 | 5539 |
|
|
0 commit comments