Skip to content

Commit dc5f891

Browse files
authored
Enable generic TypedDicts (#13389)
Fixes #3863 This builds on top of some infra I added for recursive types (Ref #13297). Implementation is quite straightforward. The only non-trivial thing is that when extending/merging TypedDicts, the item types need to me mapped to supertype during semantic analysis. This means we can't call `is_subtype()` etc., and can in theory get types like `Union[int, int]`. But OTOH this equally applies to type aliases, and doesn't seem to cause problems.
1 parent 8deeaf3 commit dc5f891

18 files changed

+630
-129
lines changed

misc/proper_plugin.py

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def is_special_target(right: ProperType) -> bool:
9797
"mypy.types.PartialType",
9898
"mypy.types.ErasedType",
9999
"mypy.types.DeletedType",
100+
"mypy.types.RequiredType",
100101
):
101102
# Special case: these are not valid targets for a type alias and thus safe.
102103
# TODO: introduce a SyntheticType base to simplify this?

mypy/checkexpr.py

+141-23
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ def __init__(self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: P
283283

284284
self.resolved_type = {}
285285

286+
# Callee in a call expression is in some sense both runtime context and
287+
# type context, because we support things like C[int](...). Store information
288+
# on whether current expression is a callee, to give better error messages
289+
# related to type context.
290+
self.is_callee = False
291+
286292
def reset(self) -> None:
287293
self.resolved_type = {}
288294

@@ -319,7 +325,11 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
319325
result = node.type
320326
elif isinstance(node, TypeInfo):
321327
# Reference to a type object.
322-
result = type_object_type(node, self.named_type)
328+
if node.typeddict_type:
329+
# We special-case TypedDict, because they don't define any constructor.
330+
result = self.typeddict_callable(node)
331+
else:
332+
result = type_object_type(node, self.named_type)
323333
if isinstance(result, CallableType) and isinstance( # type: ignore
324334
result.ret_type, Instance
325335
):
@@ -386,17 +396,29 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
386396
return self.accept(e.analyzed, self.type_context[-1])
387397
return self.visit_call_expr_inner(e, allow_none_return=allow_none_return)
388398

399+
def refers_to_typeddict(self, base: Expression) -> bool:
400+
if not isinstance(base, RefExpr):
401+
return False
402+
if isinstance(base.node, TypeInfo) and base.node.typeddict_type is not None:
403+
# Direct reference.
404+
return True
405+
return isinstance(base.node, TypeAlias) and isinstance(
406+
get_proper_type(base.node.target), TypedDictType
407+
)
408+
389409
def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> Type:
390410
if (
391-
isinstance(e.callee, RefExpr)
392-
and isinstance(e.callee.node, TypeInfo)
393-
and e.callee.node.typeddict_type is not None
411+
self.refers_to_typeddict(e.callee)
412+
or isinstance(e.callee, IndexExpr)
413+
and self.refers_to_typeddict(e.callee.base)
394414
):
395-
# Use named fallback for better error messages.
396-
typeddict_type = e.callee.node.typeddict_type.copy_modified(
397-
fallback=Instance(e.callee.node, [])
398-
)
399-
return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
415+
typeddict_callable = get_proper_type(self.accept(e.callee, is_callee=True))
416+
if isinstance(typeddict_callable, CallableType):
417+
typeddict_type = get_proper_type(typeddict_callable.ret_type)
418+
assert isinstance(typeddict_type, TypedDictType)
419+
return self.check_typeddict_call(
420+
typeddict_type, e.arg_kinds, e.arg_names, e.args, e, typeddict_callable
421+
)
400422
if (
401423
isinstance(e.callee, NameExpr)
402424
and e.callee.name in ("isinstance", "issubclass")
@@ -457,7 +479,9 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
457479
ret_type=self.object_type(),
458480
fallback=self.named_type("builtins.function"),
459481
)
460-
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
482+
callee_type = get_proper_type(
483+
self.accept(e.callee, type_context, always_allow_any=True, is_callee=True)
484+
)
461485
if (
462486
self.chk.options.disallow_untyped_calls
463487
and self.chk.in_checked_function()
@@ -628,28 +652,33 @@ def check_typeddict_call(
628652
arg_names: Sequence[Optional[str]],
629653
args: List[Expression],
630654
context: Context,
655+
orig_callee: Optional[Type],
631656
) -> Type:
632657
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]):
633658
# ex: Point(x=42, y=1337)
634659
assert all(arg_name is not None for arg_name in arg_names)
635660
item_names = cast(List[str], arg_names)
636661
item_args = args
637662
return self.check_typeddict_call_with_kwargs(
638-
callee, dict(zip(item_names, item_args)), context
663+
callee, dict(zip(item_names, item_args)), context, orig_callee
639664
)
640665

641666
if len(args) == 1 and arg_kinds[0] == ARG_POS:
642667
unique_arg = args[0]
643668
if isinstance(unique_arg, DictExpr):
644669
# ex: Point({'x': 42, 'y': 1337})
645-
return self.check_typeddict_call_with_dict(callee, unique_arg, context)
670+
return self.check_typeddict_call_with_dict(
671+
callee, unique_arg, context, orig_callee
672+
)
646673
if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr):
647674
# ex: Point(dict(x=42, y=1337))
648-
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context)
675+
return self.check_typeddict_call_with_dict(
676+
callee, unique_arg.analyzed, context, orig_callee
677+
)
649678

650679
if len(args) == 0:
651680
# ex: EmptyDict()
652-
return self.check_typeddict_call_with_kwargs(callee, {}, context)
681+
return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee)
653682

654683
self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context)
655684
return AnyType(TypeOfAny.from_error)
@@ -683,18 +712,59 @@ def match_typeddict_call_with_dict(
683712
return False
684713

685714
def check_typeddict_call_with_dict(
686-
self, callee: TypedDictType, kwargs: DictExpr, context: Context
715+
self,
716+
callee: TypedDictType,
717+
kwargs: DictExpr,
718+
context: Context,
719+
orig_callee: Optional[Type],
687720
) -> Type:
688721
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs)
689722
if validated_kwargs is not None:
690723
return self.check_typeddict_call_with_kwargs(
691-
callee, kwargs=validated_kwargs, context=context
724+
callee, kwargs=validated_kwargs, context=context, orig_callee=orig_callee
692725
)
693726
else:
694727
return AnyType(TypeOfAny.from_error)
695728

729+
def typeddict_callable(self, info: TypeInfo) -> CallableType:
730+
"""Construct a reasonable type for a TypedDict type in runtime context.
731+
732+
If it appears as a callee, it will be special-cased anyway, e.g. it is
733+
also allowed to accept a single positional argument if it is a dict literal.
734+
735+
Note it is not safe to move this to type_object_type() since it will crash
736+
on plugin-generated TypedDicts, that may not have the special_alias.
737+
"""
738+
assert info.special_alias is not None
739+
target = info.special_alias.target
740+
assert isinstance(target, ProperType) and isinstance(target, TypedDictType)
741+
expected_types = list(target.items.values())
742+
kinds = [ArgKind.ARG_NAMED] * len(expected_types)
743+
names = list(target.items.keys())
744+
return CallableType(
745+
expected_types,
746+
kinds,
747+
names,
748+
target,
749+
self.named_type("builtins.type"),
750+
variables=info.defn.type_vars,
751+
)
752+
753+
def typeddict_callable_from_context(self, callee: TypedDictType) -> CallableType:
754+
return CallableType(
755+
list(callee.items.values()),
756+
[ArgKind.ARG_NAMED] * len(callee.items),
757+
list(callee.items.keys()),
758+
callee,
759+
self.named_type("builtins.type"),
760+
)
761+
696762
def check_typeddict_call_with_kwargs(
697-
self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context
763+
self,
764+
callee: TypedDictType,
765+
kwargs: Dict[str, Expression],
766+
context: Context,
767+
orig_callee: Optional[Type],
698768
) -> Type:
699769
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
700770
expected_keys = [
@@ -708,7 +778,38 @@ def check_typeddict_call_with_kwargs(
708778
)
709779
return AnyType(TypeOfAny.from_error)
710780

711-
for (item_name, item_expected_type) in callee.items.items():
781+
orig_callee = get_proper_type(orig_callee)
782+
if isinstance(orig_callee, CallableType):
783+
infer_callee = orig_callee
784+
else:
785+
# Try reconstructing from type context.
786+
if callee.fallback.type.special_alias is not None:
787+
infer_callee = self.typeddict_callable(callee.fallback.type)
788+
else:
789+
# Likely a TypedDict type generated by a plugin.
790+
infer_callee = self.typeddict_callable_from_context(callee)
791+
792+
# We don't show any errors, just infer types in a generic TypedDict type,
793+
# a custom error message will be given below, if there are errors.
794+
with self.msg.filter_errors(), self.chk.local_type_map():
795+
orig_ret_type, _ = self.check_callable_call(
796+
infer_callee,
797+
list(kwargs.values()),
798+
[ArgKind.ARG_NAMED] * len(kwargs),
799+
context,
800+
list(kwargs.keys()),
801+
None,
802+
None,
803+
None,
804+
)
805+
806+
ret_type = get_proper_type(orig_ret_type)
807+
if not isinstance(ret_type, TypedDictType):
808+
# If something went really wrong, type-check call with original type,
809+
# this may give a better error message.
810+
ret_type = callee
811+
812+
for (item_name, item_expected_type) in ret_type.items.items():
712813
if item_name in kwargs:
713814
item_value = kwargs[item_name]
714815
self.chk.check_simple_assignment(
@@ -721,7 +822,7 @@ def check_typeddict_call_with_kwargs(
721822
code=codes.TYPEDDICT_ITEM,
722823
)
723824

724-
return callee
825+
return orig_ret_type
725826

726827
def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
727828
"""Get variable node for a partial self attribute.
@@ -2547,7 +2648,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
25472648
return self.analyze_ref_expr(e)
25482649
else:
25492650
# This is a reference to a non-module attribute.
2550-
original_type = self.accept(e.expr)
2651+
original_type = self.accept(e.expr, is_callee=self.is_callee)
25512652
base = e.expr
25522653
module_symbol_table = None
25532654

@@ -3670,6 +3771,8 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
36703771
elif isinstance(item, TupleType) and item.partial_fallback.type.is_named_tuple:
36713772
tp = type_object_type(item.partial_fallback.type, self.named_type)
36723773
return self.apply_type_arguments_to_callable(tp, item.partial_fallback.args, tapp)
3774+
elif isinstance(item, TypedDictType):
3775+
return self.typeddict_callable_from_context(item)
36733776
else:
36743777
self.chk.fail(message_registry.ONLY_CLASS_APPLICATION, tapp)
36753778
return AnyType(TypeOfAny.from_error)
@@ -3723,7 +3826,12 @@ class LongName(Generic[T]): ...
37233826
# For example:
37243827
# A = List[Tuple[T, T]]
37253828
# x = A() <- same as List[Tuple[Any, Any]], see PEP 484.
3726-
item = get_proper_type(set_any_tvars(alias, ctx.line, ctx.column))
3829+
disallow_any = self.chk.options.disallow_any_generics and self.is_callee
3830+
item = get_proper_type(
3831+
set_any_tvars(
3832+
alias, ctx.line, ctx.column, disallow_any=disallow_any, fail=self.msg.fail
3833+
)
3834+
)
37273835
if isinstance(item, Instance):
37283836
# Normally we get a callable type (or overloaded) with .is_type_obj() true
37293837
# representing the class's constructor
@@ -3738,6 +3846,8 @@ class LongName(Generic[T]): ...
37383846
tuple_fallback(item).type.fullname != "builtins.tuple"
37393847
):
37403848
return type_object_type(tuple_fallback(item).type, self.named_type)
3849+
elif isinstance(item, TypedDictType):
3850+
return self.typeddict_callable_from_context(item)
37413851
elif isinstance(item, AnyType):
37423852
return AnyType(TypeOfAny.from_another_any, source_any=item)
37433853
else:
@@ -3962,7 +4072,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
39624072
# to avoid the second error, we always return TypedDict type that was requested
39634073
typeddict_context = self.find_typeddict_context(self.type_context[-1], e)
39644074
if typeddict_context:
3965-
self.check_typeddict_call_with_dict(callee=typeddict_context, kwargs=e, context=e)
4075+
orig_ret_type = self.check_typeddict_call_with_dict(
4076+
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
4077+
)
4078+
ret_type = get_proper_type(orig_ret_type)
4079+
if isinstance(ret_type, TypedDictType):
4080+
return ret_type.copy_modified()
39664081
return typeddict_context.copy_modified()
39674082

39684083
# fast path attempt
@@ -4494,6 +4609,7 @@ def accept(
44944609
type_context: Optional[Type] = None,
44954610
allow_none_return: bool = False,
44964611
always_allow_any: bool = False,
4612+
is_callee: bool = False,
44974613
) -> Type:
44984614
"""Type check a node in the given type context. If allow_none_return
44994615
is True and this expression is a call, allow it to return None. This
@@ -4502,6 +4618,8 @@ def accept(
45024618
if node in self.type_overrides:
45034619
return self.type_overrides[node]
45044620
self.type_context.append(type_context)
4621+
old_is_callee = self.is_callee
4622+
self.is_callee = is_callee
45054623
try:
45064624
if allow_none_return and isinstance(node, CallExpr):
45074625
typ = self.visit_call_expr(node, allow_none_return=True)
@@ -4517,7 +4635,7 @@ def accept(
45174635
report_internal_error(
45184636
err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options
45194637
)
4520-
4638+
self.is_callee = old_is_callee
45214639
self.type_context.pop()
45224640
assert typ is not None
45234641
self.chk.store_type(node, typ)

mypy/checkmember.py

+2
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
331331
assert isinstance(ret_type, ProperType)
332332
if isinstance(ret_type, TupleType):
333333
ret_type = tuple_fallback(ret_type)
334+
if isinstance(ret_type, TypedDictType):
335+
ret_type = ret_type.fallback
334336
if isinstance(ret_type, Instance):
335337
if not mx.is_operator:
336338
# When Python sees an operator (eg `3 == 4`), it automatically translates that

mypy/expandtype.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ def visit_tuple_type(self, t: TupleType) -> Type:
307307
return items
308308

309309
def visit_typeddict_type(self, t: TypedDictType) -> Type:
310-
return t.copy_modified(item_types=self.expand_types(t.items.values()))
310+
fallback = t.fallback.accept(self)
311+
fallback = get_proper_type(fallback)
312+
if not isinstance(fallback, Instance):
313+
fallback = t.fallback
314+
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
311315

312316
def visit_literal_type(self, t: LiteralType) -> Type:
313317
# TODO: Verify this implementation is correct

mypy/fixup.py

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def visit_type_info(self, info: TypeInfo) -> None:
8080
info.update_tuple_type(info.tuple_type)
8181
if info.typeddict_type:
8282
info.typeddict_type.accept(self.type_fixer)
83+
info.update_typeddict_type(info.typeddict_type)
8384
if info.declared_metaclass:
8485
info.declared_metaclass.accept(self.type_fixer)
8586
if info.metaclass_type:

mypy/nodes.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3192,8 +3192,8 @@ class TypeAlias(SymbolNode):
31923192
following:
31933193
31943194
1. An alias targeting a generic class without explicit variables act as
3195-
the given class (this doesn't apply to Tuple and Callable, which are not proper
3196-
classes but special type constructors):
3195+
the given class (this doesn't apply to TypedDict, Tuple and Callable, which
3196+
are not proper classes but special type constructors):
31973197
31983198
A = List
31993199
AA = List[Any]
@@ -3305,7 +3305,9 @@ def from_typeddict_type(cls, info: TypeInfo) -> TypeAlias:
33053305
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
33063306
assert info.typeddict_type
33073307
return TypeAlias(
3308-
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
3308+
info.typeddict_type.copy_modified(
3309+
fallback=mypy.types.Instance(info, info.defn.type_vars)
3310+
),
33093311
info.fullname,
33103312
info.line,
33113313
info.column,

0 commit comments

Comments
 (0)