@@ -283,6 +283,12 @@ def __init__(self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: P
283
283
284
284
self .resolved_type = {}
285
285
286
+ # Callee in a call expression is in some sense both runtime context and
287
+ # type context, because we support things like C[int](...). Store information
288
+ # on whether current expression is a callee, to give better error messages
289
+ # related to type context.
290
+ self .is_callee = False
291
+
286
292
def reset (self ) -> None :
287
293
self .resolved_type = {}
288
294
@@ -319,7 +325,11 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
319
325
result = node .type
320
326
elif isinstance (node , TypeInfo ):
321
327
# Reference to a type object.
322
- result = type_object_type (node , self .named_type )
328
+ if node .typeddict_type :
329
+ # We special-case TypedDict, because they don't define any constructor.
330
+ result = self .typeddict_callable (node )
331
+ else :
332
+ result = type_object_type (node , self .named_type )
323
333
if isinstance (result , CallableType ) and isinstance ( # type: ignore
324
334
result .ret_type , Instance
325
335
):
@@ -386,17 +396,29 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
386
396
return self .accept (e .analyzed , self .type_context [- 1 ])
387
397
return self .visit_call_expr_inner (e , allow_none_return = allow_none_return )
388
398
399
+ def refers_to_typeddict (self , base : Expression ) -> bool :
400
+ if not isinstance (base , RefExpr ):
401
+ return False
402
+ if isinstance (base .node , TypeInfo ) and base .node .typeddict_type is not None :
403
+ # Direct reference.
404
+ return True
405
+ return isinstance (base .node , TypeAlias ) and isinstance (
406
+ get_proper_type (base .node .target ), TypedDictType
407
+ )
408
+
389
409
def visit_call_expr_inner (self , e : CallExpr , allow_none_return : bool = False ) -> Type :
390
410
if (
391
- isinstance (e .callee , RefExpr )
392
- and isinstance (e .callee . node , TypeInfo )
393
- and e .callee .node . typeddict_type is not None
411
+ self . refers_to_typeddict (e .callee )
412
+ or isinstance (e .callee , IndexExpr )
413
+ and self . refers_to_typeddict ( e .callee .base )
394
414
):
395
- # Use named fallback for better error messages.
396
- typeddict_type = e .callee .node .typeddict_type .copy_modified (
397
- fallback = Instance (e .callee .node , [])
398
- )
399
- return self .check_typeddict_call (typeddict_type , e .arg_kinds , e .arg_names , e .args , e )
415
+ typeddict_callable = get_proper_type (self .accept (e .callee , is_callee = True ))
416
+ if isinstance (typeddict_callable , CallableType ):
417
+ typeddict_type = get_proper_type (typeddict_callable .ret_type )
418
+ assert isinstance (typeddict_type , TypedDictType )
419
+ return self .check_typeddict_call (
420
+ typeddict_type , e .arg_kinds , e .arg_names , e .args , e , typeddict_callable
421
+ )
400
422
if (
401
423
isinstance (e .callee , NameExpr )
402
424
and e .callee .name in ("isinstance" , "issubclass" )
@@ -457,7 +479,9 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
457
479
ret_type = self .object_type (),
458
480
fallback = self .named_type ("builtins.function" ),
459
481
)
460
- callee_type = get_proper_type (self .accept (e .callee , type_context , always_allow_any = True ))
482
+ callee_type = get_proper_type (
483
+ self .accept (e .callee , type_context , always_allow_any = True , is_callee = True )
484
+ )
461
485
if (
462
486
self .chk .options .disallow_untyped_calls
463
487
and self .chk .in_checked_function ()
@@ -628,28 +652,33 @@ def check_typeddict_call(
628
652
arg_names : Sequence [Optional [str ]],
629
653
args : List [Expression ],
630
654
context : Context ,
655
+ orig_callee : Optional [Type ],
631
656
) -> Type :
632
657
if len (args ) >= 1 and all ([ak == ARG_NAMED for ak in arg_kinds ]):
633
658
# ex: Point(x=42, y=1337)
634
659
assert all (arg_name is not None for arg_name in arg_names )
635
660
item_names = cast (List [str ], arg_names )
636
661
item_args = args
637
662
return self .check_typeddict_call_with_kwargs (
638
- callee , dict (zip (item_names , item_args )), context
663
+ callee , dict (zip (item_names , item_args )), context , orig_callee
639
664
)
640
665
641
666
if len (args ) == 1 and arg_kinds [0 ] == ARG_POS :
642
667
unique_arg = args [0 ]
643
668
if isinstance (unique_arg , DictExpr ):
644
669
# ex: Point({'x': 42, 'y': 1337})
645
- return self .check_typeddict_call_with_dict (callee , unique_arg , context )
670
+ return self .check_typeddict_call_with_dict (
671
+ callee , unique_arg , context , orig_callee
672
+ )
646
673
if isinstance (unique_arg , CallExpr ) and isinstance (unique_arg .analyzed , DictExpr ):
647
674
# ex: Point(dict(x=42, y=1337))
648
- return self .check_typeddict_call_with_dict (callee , unique_arg .analyzed , context )
675
+ return self .check_typeddict_call_with_dict (
676
+ callee , unique_arg .analyzed , context , orig_callee
677
+ )
649
678
650
679
if len (args ) == 0 :
651
680
# ex: EmptyDict()
652
- return self .check_typeddict_call_with_kwargs (callee , {}, context )
681
+ return self .check_typeddict_call_with_kwargs (callee , {}, context , orig_callee )
653
682
654
683
self .chk .fail (message_registry .INVALID_TYPEDDICT_ARGS , context )
655
684
return AnyType (TypeOfAny .from_error )
@@ -683,18 +712,59 @@ def match_typeddict_call_with_dict(
683
712
return False
684
713
685
714
def check_typeddict_call_with_dict (
686
- self , callee : TypedDictType , kwargs : DictExpr , context : Context
715
+ self ,
716
+ callee : TypedDictType ,
717
+ kwargs : DictExpr ,
718
+ context : Context ,
719
+ orig_callee : Optional [Type ],
687
720
) -> Type :
688
721
validated_kwargs = self .validate_typeddict_kwargs (kwargs = kwargs )
689
722
if validated_kwargs is not None :
690
723
return self .check_typeddict_call_with_kwargs (
691
- callee , kwargs = validated_kwargs , context = context
724
+ callee , kwargs = validated_kwargs , context = context , orig_callee = orig_callee
692
725
)
693
726
else :
694
727
return AnyType (TypeOfAny .from_error )
695
728
729
+ def typeddict_callable (self , info : TypeInfo ) -> CallableType :
730
+ """Construct a reasonable type for a TypedDict type in runtime context.
731
+
732
+ If it appears as a callee, it will be special-cased anyway, e.g. it is
733
+ also allowed to accept a single positional argument if it is a dict literal.
734
+
735
+ Note it is not safe to move this to type_object_type() since it will crash
736
+ on plugin-generated TypedDicts, that may not have the special_alias.
737
+ """
738
+ assert info .special_alias is not None
739
+ target = info .special_alias .target
740
+ assert isinstance (target , ProperType ) and isinstance (target , TypedDictType )
741
+ expected_types = list (target .items .values ())
742
+ kinds = [ArgKind .ARG_NAMED ] * len (expected_types )
743
+ names = list (target .items .keys ())
744
+ return CallableType (
745
+ expected_types ,
746
+ kinds ,
747
+ names ,
748
+ target ,
749
+ self .named_type ("builtins.type" ),
750
+ variables = info .defn .type_vars ,
751
+ )
752
+
753
+ def typeddict_callable_from_context (self , callee : TypedDictType ) -> CallableType :
754
+ return CallableType (
755
+ list (callee .items .values ()),
756
+ [ArgKind .ARG_NAMED ] * len (callee .items ),
757
+ list (callee .items .keys ()),
758
+ callee ,
759
+ self .named_type ("builtins.type" ),
760
+ )
761
+
696
762
def check_typeddict_call_with_kwargs (
697
- self , callee : TypedDictType , kwargs : Dict [str , Expression ], context : Context
763
+ self ,
764
+ callee : TypedDictType ,
765
+ kwargs : Dict [str , Expression ],
766
+ context : Context ,
767
+ orig_callee : Optional [Type ],
698
768
) -> Type :
699
769
if not (callee .required_keys <= set (kwargs .keys ()) <= set (callee .items .keys ())):
700
770
expected_keys = [
@@ -708,7 +778,38 @@ def check_typeddict_call_with_kwargs(
708
778
)
709
779
return AnyType (TypeOfAny .from_error )
710
780
711
- for (item_name , item_expected_type ) in callee .items .items ():
781
+ orig_callee = get_proper_type (orig_callee )
782
+ if isinstance (orig_callee , CallableType ):
783
+ infer_callee = orig_callee
784
+ else :
785
+ # Try reconstructing from type context.
786
+ if callee .fallback .type .special_alias is not None :
787
+ infer_callee = self .typeddict_callable (callee .fallback .type )
788
+ else :
789
+ # Likely a TypedDict type generated by a plugin.
790
+ infer_callee = self .typeddict_callable_from_context (callee )
791
+
792
+ # We don't show any errors, just infer types in a generic TypedDict type,
793
+ # a custom error message will be given below, if there are errors.
794
+ with self .msg .filter_errors (), self .chk .local_type_map ():
795
+ orig_ret_type , _ = self .check_callable_call (
796
+ infer_callee ,
797
+ list (kwargs .values ()),
798
+ [ArgKind .ARG_NAMED ] * len (kwargs ),
799
+ context ,
800
+ list (kwargs .keys ()),
801
+ None ,
802
+ None ,
803
+ None ,
804
+ )
805
+
806
+ ret_type = get_proper_type (orig_ret_type )
807
+ if not isinstance (ret_type , TypedDictType ):
808
+ # If something went really wrong, type-check call with original type,
809
+ # this may give a better error message.
810
+ ret_type = callee
811
+
812
+ for (item_name , item_expected_type ) in ret_type .items .items ():
712
813
if item_name in kwargs :
713
814
item_value = kwargs [item_name ]
714
815
self .chk .check_simple_assignment (
@@ -721,7 +822,7 @@ def check_typeddict_call_with_kwargs(
721
822
code = codes .TYPEDDICT_ITEM ,
722
823
)
723
824
724
- return callee
825
+ return orig_ret_type
725
826
726
827
def get_partial_self_var (self , expr : MemberExpr ) -> Optional [Var ]:
727
828
"""Get variable node for a partial self attribute.
@@ -2547,7 +2648,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
2547
2648
return self .analyze_ref_expr (e )
2548
2649
else :
2549
2650
# This is a reference to a non-module attribute.
2550
- original_type = self .accept (e .expr )
2651
+ original_type = self .accept (e .expr , is_callee = self . is_callee )
2551
2652
base = e .expr
2552
2653
module_symbol_table = None
2553
2654
@@ -3670,6 +3771,8 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
3670
3771
elif isinstance (item , TupleType ) and item .partial_fallback .type .is_named_tuple :
3671
3772
tp = type_object_type (item .partial_fallback .type , self .named_type )
3672
3773
return self .apply_type_arguments_to_callable (tp , item .partial_fallback .args , tapp )
3774
+ elif isinstance (item , TypedDictType ):
3775
+ return self .typeddict_callable_from_context (item )
3673
3776
else :
3674
3777
self .chk .fail (message_registry .ONLY_CLASS_APPLICATION , tapp )
3675
3778
return AnyType (TypeOfAny .from_error )
@@ -3723,7 +3826,12 @@ class LongName(Generic[T]): ...
3723
3826
# For example:
3724
3827
# A = List[Tuple[T, T]]
3725
3828
# x = A() <- same as List[Tuple[Any, Any]], see PEP 484.
3726
- item = get_proper_type (set_any_tvars (alias , ctx .line , ctx .column ))
3829
+ disallow_any = self .chk .options .disallow_any_generics and self .is_callee
3830
+ item = get_proper_type (
3831
+ set_any_tvars (
3832
+ alias , ctx .line , ctx .column , disallow_any = disallow_any , fail = self .msg .fail
3833
+ )
3834
+ )
3727
3835
if isinstance (item , Instance ):
3728
3836
# Normally we get a callable type (or overloaded) with .is_type_obj() true
3729
3837
# representing the class's constructor
@@ -3738,6 +3846,8 @@ class LongName(Generic[T]): ...
3738
3846
tuple_fallback (item ).type .fullname != "builtins.tuple"
3739
3847
):
3740
3848
return type_object_type (tuple_fallback (item ).type , self .named_type )
3849
+ elif isinstance (item , TypedDictType ):
3850
+ return self .typeddict_callable_from_context (item )
3741
3851
elif isinstance (item , AnyType ):
3742
3852
return AnyType (TypeOfAny .from_another_any , source_any = item )
3743
3853
else :
@@ -3962,7 +4072,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
3962
4072
# to avoid the second error, we always return TypedDict type that was requested
3963
4073
typeddict_context = self .find_typeddict_context (self .type_context [- 1 ], e )
3964
4074
if typeddict_context :
3965
- self .check_typeddict_call_with_dict (callee = typeddict_context , kwargs = e , context = e )
4075
+ orig_ret_type = self .check_typeddict_call_with_dict (
4076
+ callee = typeddict_context , kwargs = e , context = e , orig_callee = None
4077
+ )
4078
+ ret_type = get_proper_type (orig_ret_type )
4079
+ if isinstance (ret_type , TypedDictType ):
4080
+ return ret_type .copy_modified ()
3966
4081
return typeddict_context .copy_modified ()
3967
4082
3968
4083
# fast path attempt
@@ -4494,6 +4609,7 @@ def accept(
4494
4609
type_context : Optional [Type ] = None ,
4495
4610
allow_none_return : bool = False ,
4496
4611
always_allow_any : bool = False ,
4612
+ is_callee : bool = False ,
4497
4613
) -> Type :
4498
4614
"""Type check a node in the given type context. If allow_none_return
4499
4615
is True and this expression is a call, allow it to return None. This
@@ -4502,6 +4618,8 @@ def accept(
4502
4618
if node in self .type_overrides :
4503
4619
return self .type_overrides [node ]
4504
4620
self .type_context .append (type_context )
4621
+ old_is_callee = self .is_callee
4622
+ self .is_callee = is_callee
4505
4623
try :
4506
4624
if allow_none_return and isinstance (node , CallExpr ):
4507
4625
typ = self .visit_call_expr (node , allow_none_return = True )
@@ -4517,7 +4635,7 @@ def accept(
4517
4635
report_internal_error (
4518
4636
err , self .chk .errors .file , node .line , self .chk .errors , self .chk .options
4519
4637
)
4520
-
4638
+ self . is_callee = old_is_callee
4521
4639
self .type_context .pop ()
4522
4640
assert typ is not None
4523
4641
self .chk .store_type (node , typ )
0 commit comments