Skip to content

Enable generic TypedDicts #13389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.PartialType",
"mypy.types.ErasedType",
"mypy.types.DeletedType",
"mypy.types.RequiredType",
):
# Special case: these are not valid targets for a type alias and thus safe.
# TODO: introduce a SyntheticType base to simplify this?
Expand Down
164 changes: 141 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,12 @@ def __init__(self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: P

self.resolved_type = {}

# Callee in a call expression is in some sense both runtime context and
# type context, because we support things like C[int](...). Store information
# on whether current expression is a callee, to give better error messages
# related to type context.
self.is_callee = False

def reset(self) -> None:
self.resolved_type = {}

Expand Down Expand Up @@ -319,7 +325,11 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = node.type
elif isinstance(node, TypeInfo):
# Reference to a type object.
result = type_object_type(node, self.named_type)
if node.typeddict_type:
# We special-case TypedDict, because they don't define any constructor.
result = self.typeddict_callable(node)
else:
result = type_object_type(node, self.named_type)
if isinstance(result, CallableType) and isinstance( # type: ignore
result.ret_type, Instance
):
Expand Down Expand Up @@ -386,17 +396,29 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
return self.accept(e.analyzed, self.type_context[-1])
return self.visit_call_expr_inner(e, allow_none_return=allow_none_return)

def refers_to_typeddict(self, base: Expression) -> bool:
if not isinstance(base, RefExpr):
return False
if isinstance(base.node, TypeInfo) and base.node.typeddict_type is not None:
# Direct reference.
return True
return isinstance(base.node, TypeAlias) and isinstance(
get_proper_type(base.node.target), TypedDictType
)

def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> Type:
if (
isinstance(e.callee, RefExpr)
and isinstance(e.callee.node, TypeInfo)
and e.callee.node.typeddict_type is not None
self.refers_to_typeddict(e.callee)
or isinstance(e.callee, IndexExpr)
and self.refers_to_typeddict(e.callee.base)
):
# Use named fallback for better error messages.
typeddict_type = e.callee.node.typeddict_type.copy_modified(
fallback=Instance(e.callee.node, [])
)
return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
typeddict_callable = get_proper_type(self.accept(e.callee, is_callee=True))
if isinstance(typeddict_callable, CallableType):
typeddict_type = get_proper_type(typeddict_callable.ret_type)
assert isinstance(typeddict_type, TypedDictType)
return self.check_typeddict_call(
typeddict_type, e.arg_kinds, e.arg_names, e.args, e, typeddict_callable
)
if (
isinstance(e.callee, NameExpr)
and e.callee.name in ("isinstance", "issubclass")
Expand Down Expand Up @@ -457,7 +479,9 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
ret_type=self.object_type(),
fallback=self.named_type("builtins.function"),
)
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
callee_type = get_proper_type(
self.accept(e.callee, type_context, always_allow_any=True, is_callee=True)
)
if (
self.chk.options.disallow_untyped_calls
and self.chk.in_checked_function()
Expand Down Expand Up @@ -628,28 +652,33 @@ def check_typeddict_call(
arg_names: Sequence[Optional[str]],
args: List[Expression],
context: Context,
orig_callee: Optional[Type],
) -> Type:
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]):
# ex: Point(x=42, y=1337)
assert all(arg_name is not None for arg_name in arg_names)
item_names = cast(List[str], arg_names)
item_args = args
return self.check_typeddict_call_with_kwargs(
callee, dict(zip(item_names, item_args)), context
callee, dict(zip(item_names, item_args)), context, orig_callee
)

if len(args) == 1 and arg_kinds[0] == ARG_POS:
unique_arg = args[0]
if isinstance(unique_arg, DictExpr):
# ex: Point({'x': 42, 'y': 1337})
return self.check_typeddict_call_with_dict(callee, unique_arg, context)
return self.check_typeddict_call_with_dict(
callee, unique_arg, context, orig_callee
)
if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr):
# ex: Point(dict(x=42, y=1337))
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context)
return self.check_typeddict_call_with_dict(
callee, unique_arg.analyzed, context, orig_callee
)

if len(args) == 0:
# ex: EmptyDict()
return self.check_typeddict_call_with_kwargs(callee, {}, context)
return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee)

self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -683,18 +712,59 @@ def match_typeddict_call_with_dict(
return False

def check_typeddict_call_with_dict(
self, callee: TypedDictType, kwargs: DictExpr, context: Context
self,
callee: TypedDictType,
kwargs: DictExpr,
context: Context,
orig_callee: Optional[Type],
) -> Type:
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs)
if validated_kwargs is not None:
return self.check_typeddict_call_with_kwargs(
callee, kwargs=validated_kwargs, context=context
callee, kwargs=validated_kwargs, context=context, orig_callee=orig_callee
)
else:
return AnyType(TypeOfAny.from_error)

def typeddict_callable(self, info: TypeInfo) -> CallableType:
"""Construct a reasonable type for a TypedDict type in runtime context.

If it appears as a callee, it will be special-cased anyway, e.g. it is
also allowed to accept a single positional argument if it is a dict literal.

Note it is not safe to move this to type_object_type() since it will crash
on plugin-generated TypedDicts, that may not have the special_alias.
"""
assert info.special_alias is not None
target = info.special_alias.target
assert isinstance(target, ProperType) and isinstance(target, TypedDictType)
expected_types = list(target.items.values())
kinds = [ArgKind.ARG_NAMED] * len(expected_types)
names = list(target.items.keys())
return CallableType(
expected_types,
kinds,
names,
target,
self.named_type("builtins.type"),
variables=info.defn.type_vars,
)

def typeddict_callable_from_context(self, callee: TypedDictType) -> CallableType:
return CallableType(
list(callee.items.values()),
[ArgKind.ARG_NAMED] * len(callee.items),
list(callee.items.keys()),
callee,
self.named_type("builtins.type"),
)

def check_typeddict_call_with_kwargs(
self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context
self,
callee: TypedDictType,
kwargs: Dict[str, Expression],
context: Context,
orig_callee: Optional[Type],
) -> Type:
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
expected_keys = [
Expand All @@ -708,7 +778,38 @@ def check_typeddict_call_with_kwargs(
)
return AnyType(TypeOfAny.from_error)

for (item_name, item_expected_type) in callee.items.items():
orig_callee = get_proper_type(orig_callee)
if isinstance(orig_callee, CallableType):
infer_callee = orig_callee
else:
# Try reconstructing from type context.
if callee.fallback.type.special_alias is not None:
infer_callee = self.typeddict_callable(callee.fallback.type)
else:
# Likely a TypedDict type generated by a plugin.
infer_callee = self.typeddict_callable_from_context(callee)

# We don't show any errors, just infer types in a generic TypedDict type,
# a custom error message will be given below, if there are errors.
with self.msg.filter_errors(), self.chk.local_type_map():
orig_ret_type, _ = self.check_callable_call(
infer_callee,
list(kwargs.values()),
[ArgKind.ARG_NAMED] * len(kwargs),
context,
list(kwargs.keys()),
None,
None,
None,
)

ret_type = get_proper_type(orig_ret_type)
if not isinstance(ret_type, TypedDictType):
# If something went really wrong, type-check call with original type,
# this may give a better error message.
ret_type = callee

for (item_name, item_expected_type) in ret_type.items.items():
if item_name in kwargs:
item_value = kwargs[item_name]
self.chk.check_simple_assignment(
Expand All @@ -721,7 +822,7 @@ def check_typeddict_call_with_kwargs(
code=codes.TYPEDDICT_ITEM,
)

return callee
return orig_ret_type

def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
"""Get variable node for a partial self attribute.
Expand Down Expand Up @@ -2547,7 +2648,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
return self.analyze_ref_expr(e)
else:
# This is a reference to a non-module attribute.
original_type = self.accept(e.expr)
original_type = self.accept(e.expr, is_callee=self.is_callee)
base = e.expr
module_symbol_table = None

Expand Down Expand Up @@ -3670,6 +3771,8 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
elif isinstance(item, TupleType) and item.partial_fallback.type.is_named_tuple:
tp = type_object_type(item.partial_fallback.type, self.named_type)
return self.apply_type_arguments_to_callable(tp, item.partial_fallback.args, tapp)
elif isinstance(item, TypedDictType):
return self.typeddict_callable_from_context(item)
else:
self.chk.fail(message_registry.ONLY_CLASS_APPLICATION, tapp)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -3723,7 +3826,12 @@ class LongName(Generic[T]): ...
# For example:
# A = List[Tuple[T, T]]
# x = A() <- same as List[Tuple[Any, Any]], see PEP 484.
item = get_proper_type(set_any_tvars(alias, ctx.line, ctx.column))
disallow_any = self.chk.options.disallow_any_generics and self.is_callee
item = get_proper_type(
set_any_tvars(
alias, ctx.line, ctx.column, disallow_any=disallow_any, fail=self.msg.fail
)
)
if isinstance(item, Instance):
# Normally we get a callable type (or overloaded) with .is_type_obj() true
# representing the class's constructor
Expand All @@ -3738,6 +3846,8 @@ class LongName(Generic[T]): ...
tuple_fallback(item).type.fullname != "builtins.tuple"
):
return type_object_type(tuple_fallback(item).type, self.named_type)
elif isinstance(item, TypedDictType):
return self.typeddict_callable_from_context(item)
elif isinstance(item, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=item)
else:
Expand Down Expand Up @@ -3962,7 +4072,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
# to avoid the second error, we always return TypedDict type that was requested
typeddict_context = self.find_typeddict_context(self.type_context[-1], e)
if typeddict_context:
self.check_typeddict_call_with_dict(callee=typeddict_context, kwargs=e, context=e)
orig_ret_type = self.check_typeddict_call_with_dict(
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
)
ret_type = get_proper_type(orig_ret_type)
if isinstance(ret_type, TypedDictType):
return ret_type.copy_modified()
return typeddict_context.copy_modified()

# fast path attempt
Expand Down Expand Up @@ -4494,6 +4609,7 @@ def accept(
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
is_callee: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
Expand All @@ -4502,6 +4618,8 @@ def accept(
if node in self.type_overrides:
return self.type_overrides[node]
self.type_context.append(type_context)
old_is_callee = self.is_callee
self.is_callee = is_callee
try:
if allow_none_return and isinstance(node, CallExpr):
typ = self.visit_call_expr(node, allow_none_return=True)
Expand All @@ -4517,7 +4635,7 @@ def accept(
report_internal_error(
err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options
)

self.is_callee = old_is_callee
self.type_context.pop()
assert typ is not None
self.chk.store_type(node, typ)
Expand Down
2 changes: 2 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
assert isinstance(ret_type, ProperType)
if isinstance(ret_type, TupleType):
ret_type = tuple_fallback(ret_type)
if isinstance(ret_type, TypedDictType):
ret_type = ret_type.fallback
if isinstance(ret_type, Instance):
if not mx.is_operator:
# When Python sees an operator (eg `3 == 4`), it automatically translates that
Expand Down
6 changes: 5 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,11 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return items

def visit_typeddict_type(self, t: TypedDictType) -> Type:
return t.copy_modified(item_types=self.expand_types(t.items.values()))
fallback = t.fallback.accept(self)
fallback = get_proper_type(fallback)
if not isinstance(fallback, Instance):
fallback = t.fallback
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)

def visit_literal_type(self, t: LiteralType) -> Type:
# TODO: Verify this implementation is correct
Expand Down
1 change: 1 addition & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def visit_type_info(self, info: TypeInfo) -> None:
info.update_tuple_type(info.tuple_type)
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
info.update_typeddict_type(info.typeddict_type)
if info.declared_metaclass:
info.declared_metaclass.accept(self.type_fixer)
if info.metaclass_type:
Expand Down
8 changes: 5 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3192,8 +3192,8 @@ class TypeAlias(SymbolNode):
following:

1. An alias targeting a generic class without explicit variables act as
the given class (this doesn't apply to Tuple and Callable, which are not proper
classes but special type constructors):
the given class (this doesn't apply to TypedDict, Tuple and Callable, which
are not proper classes but special type constructors):

A = List
AA = List[Any]
Expand Down Expand Up @@ -3305,7 +3305,9 @@ def from_typeddict_type(cls, info: TypeInfo) -> TypeAlias:
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
assert info.typeddict_type
return TypeAlias(
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.typeddict_type.copy_modified(
fallback=mypy.types.Instance(info, info.defn.type_vars)
),
info.fullname,
info.line,
info.column,
Expand Down
Loading