|
5 | 5 | since these may assume that MROs are ready. |
6 | 6 | """ |
7 | 7 |
|
8 | | -from typing import cast, Optional, List, Sequence, Set |
| 8 | +from typing import cast, Optional, List, Sequence, Set, Iterable |
9 | 9 | import sys |
10 | 10 |
|
11 | 11 | from mypy.types import ( |
12 | 12 | TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, |
13 | 13 | TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, |
14 | 14 | AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, |
15 | | - copy_type, TypeAliasType |
| 15 | + copy_type, TypeAliasType, TypeQuery |
16 | 16 | ) |
17 | 17 | from mypy.nodes import ( |
18 | 18 | FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS, |
@@ -215,23 +215,29 @@ class B(A): pass |
215 | 215 | original_type = erase_to_bound(self_param_type) |
216 | 216 | original_type = get_proper_type(original_type) |
217 | 217 |
|
218 | | - ids = [x.id for x in func.variables] |
219 | | - typearg = get_proper_type(infer_type_arguments(ids, self_param_type, |
220 | | - original_type, is_supertype=True)[0]) |
221 | | - if (is_classmethod and isinstance(typearg, UninhabitedType) |
| 218 | + all_ids = [x.id for x in func.variables] |
| 219 | + typeargs = infer_type_arguments(all_ids, self_param_type, original_type, |
| 220 | + is_supertype=True) |
| 221 | + if (is_classmethod |
| 222 | + # TODO: why do we need the extra guards here? |
| 223 | + and any(isinstance(get_proper_type(t), UninhabitedType) for t in typeargs) |
222 | 224 | and isinstance(original_type, (Instance, TypeVarType, TupleType))): |
223 | 225 | # In case we call a classmethod through an instance x, fallback to type(x) |
224 | | - typearg = get_proper_type(infer_type_arguments(ids, self_param_type, |
225 | | - TypeType(original_type), |
226 | | - is_supertype=True)[0]) |
| 226 | + typeargs = infer_type_arguments(all_ids, self_param_type, TypeType(original_type), |
| 227 | + is_supertype=True) |
| 228 | + |
| 229 | + ids = [tid for tid in all_ids |
| 230 | + if any(tid == t.id for t in get_type_vars(self_param_type))] |
| 231 | + |
| 232 | + # Technically, some constrains might be unsolvable, make them <nothing>. |
| 233 | + to_apply = [t if t is not None else UninhabitedType() for t in typeargs] |
227 | 234 |
|
228 | 235 | def expand(target: Type) -> Type: |
229 | | - assert typearg is not None |
230 | | - return expand_type(target, {func.variables[0].id: typearg}) |
| 236 | + return expand_type(target, {id: to_apply[all_ids.index(id)] for id in ids}) |
231 | 237 |
|
232 | 238 | arg_types = [expand(x) for x in func.arg_types[1:]] |
233 | 239 | ret_type = expand(func.ret_type) |
234 | | - variables = func.variables[1:] |
| 240 | + variables = [v for v in func.variables if v.id not in ids] |
235 | 241 | else: |
236 | 242 | arg_types = func.arg_types[1:] |
237 | 243 | ret_type = func.ret_type |
@@ -587,3 +593,21 @@ def coerce_to_literal(typ: Type) -> ProperType: |
587 | 593 | if len(enum_values) == 1: |
588 | 594 | return LiteralType(value=enum_values[0], fallback=typ) |
589 | 595 | return typ |
| 596 | + |
| 597 | + |
| 598 | +def get_type_vars(tp: Type) -> List[TypeVarType]: |
| 599 | + return tp.accept(TypeVarExtractor()) |
| 600 | + |
| 601 | + |
| 602 | +class TypeVarExtractor(TypeQuery[List[TypeVarType]]): |
| 603 | + def __init__(self) -> None: |
| 604 | + super().__init__(self._merge) |
| 605 | + |
| 606 | + def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]: |
| 607 | + out = [] |
| 608 | + for item in iter: |
| 609 | + out.extend(item) |
| 610 | + return out |
| 611 | + |
| 612 | + def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]: |
| 613 | + return [t] |
0 commit comments