Skip to content

Implement basic subtyping & inferrence for variadic classes. #13105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike,
TypeVarLikeType, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType,
DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType, ParamSpecType,
ENUM_REMOVED_PROPS
TypeVarTupleType, ENUM_REMOVED_PROPS
)
from mypy.nodes import (
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, SymbolTable, Context,
Expand Down Expand Up @@ -693,6 +693,7 @@ def f(self: S) -> T: ...
new_items = []
if is_classmethod:
dispatched_arg_type = TypeType.make_normalized(dispatched_arg_type)

for item in items:
if not item.arg_types or item.arg_kinds[0] not in (ARG_POS, ARG_STAR):
# No positional first (self) argument (*args is okay).
Expand All @@ -701,12 +702,14 @@ def f(self: S) -> T: ...
# there is at least one such error.
return functype
else:
selfarg = item.arg_types[0]
selfarg = get_proper_type(item.arg_types[0])
if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))):
new_items.append(item)
elif isinstance(selfarg, ParamSpecType):
# TODO: This is not always right. What's the most reasonable thing to do here?
new_items.append(item)
elif isinstance(selfarg, TypeVarTupleType):
raise NotImplementedError
if not new_items:
# Choose first item for the message (it may be not very helpful for overloads).
msg.incompatible_self_argument(name, dispatched_arg_type, items[0],
Expand Down
93 changes: 65 additions & 28 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any,
UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES, TypeVarTupleType,
TypeList,
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
Expand All @@ -18,6 +19,12 @@
from mypy.nodes import COVARIANT, CONTRAVARIANT, ArgKind
from mypy.argmap import ArgTypeExpander
from mypy.typestate import TypeState
from mypy.typevartuples import (
split_with_instance,
split_with_prefix_and_suffix,
extract_unpack,
find_unpack_in_list,
)

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
Expand Down Expand Up @@ -486,15 +493,60 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
elif isinstance(tvar, TypeVarTupleType):
raise NotImplementedError

return res
elif (self.direction == SUPERTYPE_OF and
instance.type.has_base(template.type.fullname)):
mapped = map_instance_to_supertype(instance, template.type)
tvars = template.type.defn.type_vars
if template.type.has_type_var_tuple_type:
mapped_prefix, mapped_middle, mapped_suffix = (
split_with_instance(mapped)
)
template_prefix, template_middle, template_suffix = (
split_with_instance(template)
)

# Add a constraint for the type var tuple, and then
# remove it for the case below.
template_unpack = extract_unpack(template_middle)
if template_unpack is not None:
if isinstance(template_unpack, TypeVarTupleType):
res.append(Constraint(
template_unpack.id,
SUPERTYPE_OF,
TypeList(list(mapped_middle))
))
elif (
isinstance(template_unpack, Instance) and
template_unpack.type.fullname == "builtins.tuple"
):
# TODO: check homogenous tuple case
raise NotImplementedError
elif isinstance(template_unpack, TupleType):
# TODO: check tuple case
raise NotImplementedError

mapped_args = mapped_prefix + mapped_suffix
template_args = template_prefix + template_suffix

assert template.type.type_var_tuple_prefix is not None
assert template.type.type_var_tuple_suffix is not None
tvars_prefix, _, tvars_suffix = split_with_prefix_and_suffix(
tuple(tvars),
template.type.type_var_tuple_prefix,
template.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
else:
mapped_args = mapped.args
template_args = template.args
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
for tvar, mapped_arg, template_arg in zip(tvars, mapped_args, template_args):
assert not isinstance(tvar, TypeVarTupleType)
if isinstance(tvar, TypeVarType):
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
Expand Down Expand Up @@ -573,6 +625,8 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
return []
elif isinstance(actual, ParamSpecType):
return infer_constraints(template, actual.upper_bound, self.direction)
elif isinstance(actual, TypeVarTupleType):
raise NotImplementedError
else:
return []

Expand Down Expand Up @@ -696,13 +750,12 @@ def infer_against_overloaded(self, overloaded: Overloaded,

def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
actual = self.actual
# TODO: Support other items in the tuple besides Unpack
# TODO: Support subclasses of Tuple
is_varlength_tuple = (
isinstance(actual, Instance)
and actual.type.fullname == "builtins.tuple"
)
unpack_index = find_unpack_in_tuple(template)
unpack_index = find_unpack_in_list(template.items)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to migrate this to use the other helpers, but couldn't figure out how to cleanly do it.


if unpack_index is not None:
unpack_item = get_proper_type(template.items[unpack_index])
Expand All @@ -727,16 +780,15 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
modified_actual = actual
if isinstance(actual, TupleType):
# Exclude the items from before and after the unpack index.
head = unpack_index
tail = len(template.items) - unpack_index - 1
if tail:
modified_actual = actual.copy_modified(
items=actual.items[head:-tail],
)
else:
modified_actual = actual.copy_modified(
items=actual.items[head:],
)
# TODO: Support including constraints from the prefix/suffix.
_, actual_items, _ = split_with_prefix_and_suffix(
tuple(actual.items),
unpack_index,
len(template.items) - unpack_index - 1,
)
modified_actual = actual.copy_modified(
items=list(actual_items)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also infer constraints for the prefix and suffix? I guess they may contain regular type variable types?

return [Constraint(
type_var=unpacked_type.id,
op=self.direction,
Expand Down Expand Up @@ -854,18 +906,3 @@ def find_matching_overload_items(overloaded: Overloaded,
# it maintains backward compatibility.
res = items[:]
return res


def find_unpack_in_tuple(t: TupleType) -> Optional[int]:
unpack_index: Optional[int] = None
for i, item in enumerate(t.items):
proper_item = get_proper_type(item)
if isinstance(proper_item, UnpackType):
# We cannot fail here, so we must check this in an earlier
# semanal phase.
# Funky code here avoids mypyc narrowing the type of unpack_index.
old_index = unpack_index
assert old_index is None
# Don't return so that we can also sanity check there is only one.
unpack_index = i
return unpack_index
5 changes: 5 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def visit_type_var(self, t: TypeVarType) -> Type:
return self.replacement
return t

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t

def visit_param_spec(self, t: ParamSpecType) -> Type:
if self.erase_id(t.id):
return self.replacement
Expand Down
63 changes: 51 additions & 12 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Dict, Iterable, List, TypeVar, Mapping, cast, Union, Optional
from typing import Dict, Iterable, List, TypeVar, Mapping, cast, Union, Optional, Sequence

from mypy.types import (
Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType,
NoneType, Overloaded, TupleType, TypedDictType, UnionType,
ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId,
FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType,
TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor,
UnpackType, TypeVarTupleType
UnpackType, TypeVarTupleType, TypeList
)
from mypy.typevartuples import split_with_instance, split_with_prefix_and_suffix


def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
Expand All @@ -26,8 +27,26 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
return typ
else:
variables: Dict[TypeVarId, Type] = {}
for binder, arg in zip(instance.type.defn.type_vars, instance.args):
if instance.type.has_type_var_tuple_type:
assert instance.type.type_var_tuple_prefix is not None
assert instance.type.type_var_tuple_suffix is not None

args_prefix, args_middle, args_suffix = split_with_instance(instance)
tvars_prefix, tvars_middle, tvars_suffix = split_with_prefix_and_suffix(
tuple(instance.type.defn.type_vars),
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
)
variables = {tvars_middle[0].id: TypeList(list(args_middle))}
instance_args = args_prefix + args_suffix
tvars = tvars_prefix + tvars_suffix
else:
tvars = tuple(instance.type.defn.type_vars)
instance_args = instance.args

for binder, arg in zip(tvars, instance_args):
variables[binder.id] = arg

return expand_type(typ, variables)


Expand All @@ -46,6 +65,7 @@ def freshen_function_type_vars(callee: F) -> F:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
Expand Down Expand Up @@ -89,8 +109,11 @@ def visit_erased_type(self, t: ErasedType) -> Type:
raise RuntimeError()

def visit_instance(self, t: Instance) -> Type:
args = self.expand_types(t.args)
return Instance(t.type, args, t.line, t.column)
args = self.expand_types_with_unpack(list(t.args))
if isinstance(args, list):
return Instance(t.type, args, t.line, t.column)
else:
return args

def visit_type_var(self, t: TypeVarType) -> Type:
repl = get_proper_type(self.variables.get(t.id, t))
Expand Down Expand Up @@ -153,6 +176,8 @@ def expand_unpack(self, t: UnpackType) -> Optional[Union[List[Type], Instance, A
repl = get_proper_type(self.variables.get(proper_typ.id, t))
if isinstance(repl, TupleType):
return repl.items
if isinstance(repl, TypeList):
return repl.items
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
return repl
elif isinstance(repl, AnyType):
Expand All @@ -166,9 +191,9 @@ def expand_unpack(self, t: UnpackType) -> Optional[Union[List[Type], Instance, A
elif isinstance(repl, UninhabitedType):
return None
else:
raise NotImplementedError(f"Invalid type to expand: {repl}")
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
else:
raise NotImplementedError
raise NotImplementedError(f"Invalid type to expand: {proper_typ}")

def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
Expand Down Expand Up @@ -211,17 +236,25 @@ def visit_overloaded(self, t: Overloaded) -> Type:
items.append(new_item)
return Overloaded(items)

def visit_tuple_type(self, t: TupleType) -> Type:
items = []
for item in t.items:
def expand_types_with_unpack(
self, typs: Sequence[Type]
) -> Union[List[Type], AnyType, UninhabitedType, Instance]:
"""Expands a list of types that has an unpack.

In corner cases, this can return a type rather than a list, in which case this
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
items: List[Type] = []
for item in typs:
proper_item = get_proper_type(item)
if isinstance(proper_item, UnpackType):
unpacked_items = self.expand_unpack(proper_item)
if unpacked_items is None:
# TODO: better error, something like tuple of unknown?
return UninhabitedType()
elif isinstance(unpacked_items, Instance):
if len(t.items) == 1:
if len(typs) == 1:
return unpacked_items
else:
assert False, "Invalid unpack of variable length tuple"
Expand All @@ -231,8 +264,14 @@ def visit_tuple_type(self, t: TupleType) -> Type:
items.extend(unpacked_items)
else:
items.append(proper_item.accept(self))
return items

return t.copy_modified(items=items)
def visit_tuple_type(self, t: TupleType) -> Type:
items = self.expand_types_with_unpack(t.items)
if isinstance(items, list):
return t.copy_modified(items=items)
else:
return items

def visit_typeddict_type(self, t: TypedDictType) -> Type:
return t.copy_modified(item_types=self.expand_types(t.items.values()))
Expand Down
14 changes: 13 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2577,6 +2577,7 @@ class is generic then it will be a type constructor of higher kind.
'inferring', 'is_enum', 'fallback_to_any', 'type_vars', 'has_param_spec_type',
'bases', '_promote', 'tuple_type', 'is_named_tuple', 'typeddict_type',
'is_newtype', 'is_intersection', 'metadata', 'alt_promote',
'has_type_var_tuple_type', 'type_var_tuple_prefix', 'type_var_tuple_suffix'
)

_fullname: Bogus[str] # Fully qualified name
Expand Down Expand Up @@ -2719,6 +2720,7 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No
self.module_name = module_name
self.type_vars = []
self.has_param_spec_type = False
self.has_type_var_tuple_type = False
self.bases = []
self.mro = []
self._mro_refs = None
Expand All @@ -2734,6 +2736,8 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No
self.inferring = []
self.is_protocol = False
self.runtime_protocol = False
self.type_var_tuple_prefix: Optional[int] = None
self.type_var_tuple_suffix: Optional[int] = None
self.add_type_vars()
self.is_final = False
self.is_enum = False
Expand All @@ -2749,10 +2753,18 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No

def add_type_vars(self) -> None:
if self.defn.type_vars:
for vd in self.defn.type_vars:
for i, vd in enumerate(self.defn.type_vars):
if isinstance(vd, mypy.types.ParamSpecType):
self.has_param_spec_type = True
if isinstance(vd, mypy.types.TypeVarTupleType):
assert not self.has_type_var_tuple_type
self.has_type_var_tuple_type = True
self.type_var_tuple_prefix = i
self.type_var_tuple_suffix = len(self.defn.type_vars) - i - 1
self.type_vars.append(vd.name)
assert not (
self.has_param_spec_type and self.has_type_var_tuple_type
), "Mixing type var tuples and param specs not supported yet"

@property
def name(self) -> str:
Expand Down
5 changes: 5 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,11 @@ def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarLikeExpr]]
# It's bound by our type variable scope
return None
return unbound.name, sym.node
if sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return unbound.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr):
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
Expand Down
Loading