Skip to content

Commit f0e8288

Browse files
JukkaLilevkivskyi
authored andcommitted
Support TypedDicts with missing keys (total=False) (#3558)
* Basic support for TypedDicts with missing keys (total=False) Only the functional syntax is supported. * Support get(key, {}) and fix construction of partial typed dict * Fix subtyping of non-total typed dicts * Fix join with non-total typed dict * Fix meet with non-total typed dicts * Add serialization test case * Support TypedDict total keyword argument with class syntax * Attempt to fix Python 3.3 * Add minimal runtime `total` support to mypy_extensions There is no support for introspection of `total` yet. * Fix tests on pre-3.6 Python and improve introspection Make TypedDict `total` introspectable. * Fix lint * Fix problems caused by merge * Allow td['key'] even if td is not total * Fix lint * Add test case * Address review feedback * Update comment
1 parent 04a2aec commit f0e8288

16 files changed

+428
-122
lines changed

extensions/mypy_extensions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,18 @@ def _dict_new(cls, *args, **kwargs):
3030

3131

3232
def _typeddict_new(cls, _typename, _fields=None, **kwargs):
33+
total = kwargs.pop('total', True)
3334
if _fields is None:
3435
_fields = kwargs
3536
elif kwargs:
3637
raise TypeError("TypedDict takes either a dict or keyword arguments,"
3738
" but not both")
38-
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)})
39+
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields),
40+
'__total__': total})
3941

4042

4143
class _TypedDictMeta(type):
42-
def __new__(cls, name, bases, ns):
44+
def __new__(cls, name, bases, ns, total=True):
4345
# Create new typed dict class object.
4446
# This method is called directly when TypedDict is subclassed,
4547
# or via _typeddict_new when TypedDict is instantiated. This way
@@ -59,6 +61,8 @@ def __new__(cls, name, bases, ns):
5961
for base in bases:
6062
anns.update(base.__dict__.get('__annotations__', {}))
6163
tp_dict.__annotations__ = anns
64+
if not hasattr(tp_dict, '__total__'):
65+
tp_dict.__total__ = total
6266
return tp_dict
6367

6468
__instancecheck__ = __subclasscheck__ = _check_fails

mypy/checkexpr.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,31 +292,31 @@ def check_typeddict_call_with_dict(self, callee: TypedDictType,
292292
def check_typeddict_call_with_kwargs(self, callee: TypedDictType,
293293
kwargs: 'OrderedDict[str, Expression]',
294294
context: Context) -> Type:
295-
if callee.items.keys() != kwargs.keys():
296-
callee_item_names = callee.items.keys()
297-
kwargs_item_names = kwargs.keys()
298-
295+
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
296+
expected_item_names = [key for key in callee.items.keys()
297+
if key in callee.required_keys or key in kwargs.keys()]
298+
actual_item_names = kwargs.keys()
299299
self.msg.typeddict_instantiated_with_unexpected_items(
300-
expected_item_names=list(callee_item_names),
301-
actual_item_names=list(kwargs_item_names),
300+
expected_item_names=list(expected_item_names),
301+
actual_item_names=list(actual_item_names),
302302
context=context)
303303
return AnyType()
304304

305305
items = OrderedDict() # type: OrderedDict[str, Type]
306306
for (item_name, item_expected_type) in callee.items.items():
307-
item_value = kwargs[item_name]
308-
309-
self.chk.check_simple_assignment(
310-
lvalue_type=item_expected_type, rvalue=item_value, context=item_value,
311-
msg=messages.INCOMPATIBLE_TYPES,
312-
lvalue_name='TypedDict item "{}"'.format(item_name),
313-
rvalue_name='expression')
307+
if item_name in kwargs:
308+
item_value = kwargs[item_name]
309+
self.chk.check_simple_assignment(
310+
lvalue_type=item_expected_type, rvalue=item_value, context=item_value,
311+
msg=messages.INCOMPATIBLE_TYPES,
312+
lvalue_name='TypedDict item "{}"'.format(item_name),
313+
rvalue_name='expression')
314314
items[item_name] = item_expected_type
315315

316316
mapping_value_type = join.join_type_list(list(items.values()))
317317
fallback = self.chk.named_generic_type('typing.Mapping',
318318
[self.chk.str_type(), mapping_value_type])
319-
return TypedDictType(items, fallback)
319+
return TypedDictType(items, set(callee.required_keys), fallback)
320320

321321
# Types and methods that can be used to infer partial types.
322322
item_args = {'builtins.list': ['append'],

mypy/fastparse.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,12 +467,15 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
467467
metaclass = stringify_name(metaclass_arg.value)
468468
if metaclass is None:
469469
metaclass = '<error>' # To be reported later
470+
keywords = [(kw.arg, self.visit(kw.value))
471+
for kw in n.keywords]
470472

471473
cdef = ClassDef(n.name,
472474
self.as_block(n.body, n.lineno),
473475
None,
474476
self.translate_expr_list(n.bases),
475-
metaclass=metaclass)
477+
metaclass=metaclass,
478+
keywords=keywords)
476479
cdef.decorators = self.translate_expr_list(n.decorator_list)
477480
self.class_nesting -= 1
478481
return cdef

mypy/join.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,15 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
228228
items = OrderedDict([
229229
(item_name, s_item_type)
230230
for (item_name, s_item_type, t_item_type) in self.s.zip(t)
231-
if is_equivalent(s_item_type, t_item_type)
231+
if (is_equivalent(s_item_type, t_item_type) and
232+
(item_name in t.required_keys) == (item_name in self.s.required_keys))
232233
])
233234
mapping_value_type = join_type_list(list(items.values()))
234235
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
235-
return TypedDictType(items, fallback)
236+
# We need to filter by items.keys() since some required keys present in both t and
237+
# self.s might be missing from the join if the types are incompatible.
238+
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
239+
return TypedDictType(items, required_keys, fallback)
236240
elif isinstance(self.s, Instance):
237241
return join_instances(self.s, t.fallback)
238242
else:

mypy/meet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
252252

253253
def visit_typeddict_type(self, t: TypedDictType) -> Type:
254254
if isinstance(self.s, TypedDictType):
255-
for (_, l, r) in self.s.zip(t):
256-
if not is_equivalent(l, r):
255+
for (name, l, r) in self.s.zip(t):
256+
if (not is_equivalent(l, r) or
257+
(name in t.required_keys) != (name in self.s.required_keys)):
257258
return self.default(self.s)
258259
item_list = [] # type: List[Tuple[str, Type]]
259260
for (item_name, s_item_type, t_item_type) in self.s.zipall(t):
@@ -266,7 +267,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
266267
items = OrderedDict(item_list)
267268
mapping_value_type = join_type_list(list(items.values()))
268269
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
269-
return TypedDictType(items, fallback)
270+
required_keys = t.required_keys | self.s.required_keys
271+
return TypedDictType(items, required_keys, fallback)
270272
else:
271273
return self.default(self.s)
272274

mypy/nodes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
from abc import abstractmethod
5+
from collections import OrderedDict
56

67
from typing import (
78
Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional, Callable,
@@ -730,6 +731,7 @@ class ClassDef(Statement):
730731
info = None # type: TypeInfo # Related TypeInfo
731732
metaclass = '' # type: Optional[str]
732733
decorators = None # type: List[Expression]
734+
keywords = None # type: OrderedDict[str, Expression]
733735
analyzed = None # type: Optional[Expression]
734736
has_incompatible_baseclass = False
735737

@@ -738,13 +740,15 @@ def __init__(self,
738740
defs: 'Block',
739741
type_vars: List['mypy.types.TypeVarDef'] = None,
740742
base_type_exprs: List[Expression] = None,
741-
metaclass: str = None) -> None:
743+
metaclass: str = None,
744+
keywords: List[Tuple[str, Expression]] = None) -> None:
742745
self.name = name
743746
self.defs = defs
744747
self.type_vars = type_vars or []
745748
self.base_type_exprs = base_type_exprs or []
746749
self.metaclass = metaclass
747750
self.decorators = []
751+
self.keywords = OrderedDict(keywords or [])
748752

749753
def accept(self, visitor: StatementVisitor[T]) -> T:
750754
return visitor.visit_class_def(self)

mypy/plugin.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Plugin system for extending mypy."""
22

3+
from collections import OrderedDict
34
from abc import abstractmethod
45
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar
56

6-
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
7+
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr
78
from mypy.types import (
89
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType,
910
AnyType, TypeList, UnboundType
@@ -263,17 +264,26 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
263264
and len(ctx.args[0]) == 1
264265
and isinstance(ctx.args[0][0], StrExpr)
265266
and len(signature.arg_types) == 2
266-
and len(signature.variables) == 1):
267+
and len(signature.variables) == 1
268+
and len(ctx.args[1]) == 1):
267269
key = ctx.args[0][0].value
268270
value_type = ctx.type.items.get(key)
271+
ret_type = signature.ret_type
269272
if value_type:
273+
default_arg = ctx.args[1][0]
274+
if (isinstance(value_type, TypedDictType)
275+
and isinstance(default_arg, DictExpr)
276+
and len(default_arg.items) == 0):
277+
# Caller has empty dict {} as default for typed dict.
278+
value_type = value_type.copy_modified(required_keys=set())
270279
# Tweak the signature to include the value type as context. It's
271280
# only needed for type inference since there's a union with a type
272281
# variable that accepts everything.
273282
tv = TypeVarType(signature.variables[0])
274283
return signature.copy_modified(
275284
arg_types=[signature.arg_types[0],
276-
UnionType.make_simplified_union([value_type, tv])])
285+
UnionType.make_simplified_union([value_type, tv])],
286+
ret_type=ret_type)
277287
return signature
278288

279289

@@ -288,8 +298,15 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
288298
if value_type:
289299
if len(ctx.arg_types) == 1:
290300
return UnionType.make_simplified_union([value_type, NoneTyp()])
291-
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1:
292-
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
301+
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
302+
and len(ctx.args[1]) == 1):
303+
default_arg = ctx.args[1][0]
304+
if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
305+
and isinstance(value_type, TypedDictType)):
306+
# Special case '{}' as the default for a typed dict type.
307+
return value_type.copy_modified(required_keys=set())
308+
else:
309+
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
293310
else:
294311
ctx.api.msg.typeddict_item_name_not_found(ctx.type, key, ctx.context)
295312
return AnyType()

0 commit comments

Comments
 (0)