190
190
NoneType ,
191
191
Overloaded ,
192
192
PartialType ,
193
+ PlaceholderType ,
193
194
ProperType ,
194
195
StarType ,
195
196
TupleType ,
@@ -338,8 +339,6 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
338
339
# Used for collecting inferred attribute types so that they can be checked
339
340
# for consistency.
340
341
inferred_attribute_types : dict [Var , Type ] | None = None
341
- # Don't infer partial None types if we are processing assignment from Union
342
- no_partial_types : bool = False
343
342
344
343
# The set of all dependencies (suppressed or not) that this module accesses, either
345
344
# directly or indirectly.
@@ -3375,7 +3374,6 @@ def check_multi_assignment(
3375
3374
context : Context ,
3376
3375
infer_lvalue_type : bool = True ,
3377
3376
rv_type : Type | None = None ,
3378
- undefined_rvalue : bool = False ,
3379
3377
) -> None :
3380
3378
"""Check the assignment of one rvalue to a number of lvalues."""
3381
3379
@@ -3386,12 +3384,6 @@ def check_multi_assignment(
3386
3384
if isinstance (rvalue_type , TypeVarLikeType ):
3387
3385
rvalue_type = get_proper_type (rvalue_type .upper_bound )
3388
3386
3389
- if isinstance (rvalue_type , UnionType ):
3390
- # If this is an Optional type in non-strict Optional code, unwrap it.
3391
- relevant_items = rvalue_type .relevant_items ()
3392
- if len (relevant_items ) == 1 :
3393
- rvalue_type = get_proper_type (relevant_items [0 ])
3394
-
3395
3387
if isinstance (rvalue_type , AnyType ):
3396
3388
for lv in lvalues :
3397
3389
if isinstance (lv , StarExpr ):
@@ -3402,7 +3394,7 @@ def check_multi_assignment(
3402
3394
self .check_assignment (lv , temp_node , infer_lvalue_type )
3403
3395
elif isinstance (rvalue_type , TupleType ):
3404
3396
self .check_multi_assignment_from_tuple (
3405
- lvalues , rvalue , rvalue_type , context , undefined_rvalue , infer_lvalue_type
3397
+ lvalues , rvalue , rvalue_type , context , infer_lvalue_type
3406
3398
)
3407
3399
elif isinstance (rvalue_type , UnionType ):
3408
3400
self .check_multi_assignment_from_union (
@@ -3430,58 +3422,86 @@ def check_multi_assignment_from_union(
3430
3422
x, y = t
3431
3423
reveal_type(x) # Union[int, str]
3432
3424
3433
- The idea in this case is to process the assignment for every item of the union.
3434
- Important note: the types are collected in two places, 'union_types' contains
3435
- inferred types for first assignments, 'assignments' contains the narrowed types
3436
- for binder.
3425
+ The idea is to convert unions of tuples or other iterables to tuples of (simplified)
3426
+ unions and then simply apply `check_multi_assignment_from_tuple`.
3437
3427
"""
3438
- self .no_partial_types = True
3439
- transposed : tuple [list [Type ], ...] = tuple ([] for _ in self .flatten_lvalues (lvalues ))
3440
- # Notify binder that we want to defer bindings and instead collect types.
3441
- with self .binder .accumulate_type_assignments () as assignments :
3442
- for item in rvalue_type .items :
3443
- # Type check the assignment separately for each union item and collect
3444
- # the inferred lvalue types for each union item.
3445
- self .check_multi_assignment (
3446
- lvalues ,
3447
- rvalue ,
3448
- context ,
3449
- infer_lvalue_type = infer_lvalue_type ,
3450
- rv_type = item ,
3451
- undefined_rvalue = True ,
3452
- )
3453
- for t , lv in zip (transposed , self .flatten_lvalues (lvalues )):
3454
- # We can access _type_maps directly since temporary type maps are
3455
- # only created within expressions.
3456
- t .append (self ._type_maps [0 ].pop (lv , AnyType (TypeOfAny .special_form )))
3457
- union_types = tuple (make_simplified_union (col ) for col in transposed )
3458
- for expr , items in assignments .items ():
3459
- # Bind a union of types collected in 'assignments' to every expression.
3460
- if isinstance (expr , StarExpr ):
3461
- expr = expr .expr
3462
-
3463
- # TODO: See todo in binder.py, ConditionalTypeBinder.assign_type
3464
- # It's unclear why the 'declared_type' param is sometimes 'None'
3465
- clean_items : list [tuple [Type , Type ]] = []
3466
- for type , declared_type in items :
3467
- assert declared_type is not None
3468
- clean_items .append ((type , declared_type ))
3469
-
3470
- types , declared_types = zip (* clean_items )
3471
- self .binder .assign_type (
3472
- expr ,
3473
- make_simplified_union (list (types )),
3474
- make_simplified_union (list (declared_types )),
3475
- False ,
3428
+ # if `rvalue_type` is Optional type in non-strict Optional code, unwap it:
3429
+ relevant_items = rvalue_type .relevant_items ()
3430
+ if len (relevant_items ) == 1 :
3431
+ self .check_multi_assignment (
3432
+ lvalues , rvalue , context , infer_lvalue_type , relevant_items [0 ]
3476
3433
)
3477
- for union , lv in zip (union_types , self .flatten_lvalues (lvalues )):
3478
- # Properly store the inferred types.
3479
- _1 , _2 , inferred = self .check_lvalue (lv )
3480
- if inferred :
3481
- self .set_inferred_type (inferred , lv , union )
3434
+ return
3435
+
3436
+ # union to tuple conversion:
3437
+ star_idx = next ((i for i , lv in enumerate (lvalues ) if isinstance (lv , StarExpr )), None )
3438
+
3439
+ def handle_star_index (orig_types : list [Type ]) -> list [Type ]:
3440
+ if star_idx is not None :
3441
+ orig_types [star_idx ] = self .named_generic_type (
3442
+ "builtins.list" , [orig_types [star_idx ]]
3443
+ )
3444
+ return orig_types
3445
+
3446
+ nmb_subitems = len (lvalues )
3447
+ items : list [list [Type ]] = []
3448
+ for idx , item in enumerate (rvalue_type .items ):
3449
+ item = get_proper_type (item )
3450
+ if isinstance (item , TupleType ):
3451
+ delta = len (item .items ) - nmb_subitems
3452
+ if star_idx is None :
3453
+ if delta != 0 : # a, b = x, y, z or a, b, c = x, y
3454
+ self .msg .wrong_number_values_to_unpack (
3455
+ len (item .items ), nmb_subitems , context
3456
+ )
3457
+ return
3458
+ items .append (item .items .copy ()) # a, b = x, y
3459
+ else :
3460
+ if delta < - 1 : # a, b, c, *d = x, y
3461
+ self .msg .wrong_number_values_to_unpack (
3462
+ len (item .items ), nmb_subitems - 1 , context
3463
+ )
3464
+ return
3465
+ if delta == - 1 : # a, b, *c = x, y
3466
+ items .append (item .items .copy ())
3467
+ # to be removed after transposing:
3468
+ items [- 1 ].insert (star_idx , PlaceholderType ("temp" , [], - 1 ))
3469
+ elif delta == 0 : # a, b, *c = x, y, z
3470
+ items .append (handle_star_index (item .items .copy ()))
3471
+ else : # a, *b = x, y, z
3472
+ union = make_simplified_union (item .items [star_idx : star_idx + delta + 1 ])
3473
+ subitems = (
3474
+ item .items [:star_idx ] + [union ] + item .items [star_idx + delta + 1 :]
3475
+ )
3476
+ items .append (handle_star_index (subitems ))
3482
3477
else :
3483
- self .store_type (lv , union )
3484
- self .no_partial_types = False
3478
+ if isinstance (item , AnyType ):
3479
+ items .append (handle_star_index (nmb_subitems * [cast (Type , item )]))
3480
+ elif isinstance (item , Instance ) and (item .type .fullname == "builtins.str" ):
3481
+ self .msg .unpacking_strings_disallowed (context )
3482
+ return
3483
+ elif isinstance (item , Instance ) and self .type_is_iterable (item ):
3484
+ items .append (handle_star_index (nmb_subitems * [self .iterable_item_type (item )]))
3485
+ else :
3486
+ self .msg .type_not_iterable (item , context )
3487
+ return
3488
+ items_transposed = zip (* items )
3489
+ items_cleared = []
3490
+ for subitems_ in items_transposed :
3491
+ subitems = []
3492
+ for item in subitems_ :
3493
+ item = get_proper_type (item )
3494
+ if not isinstance (item , PlaceholderType ):
3495
+ subitems .append (item )
3496
+ items_cleared .append (subitems )
3497
+ tupletype = TupleType (
3498
+ [make_simplified_union (subitems ) for subitems in items_cleared ],
3499
+ fallback = self .named_type ("builtins.tuple" ),
3500
+ )
3501
+ self .check_multi_assignment_from_tuple (
3502
+ lvalues , rvalue , tupletype , context , infer_lvalue_type , False
3503
+ )
3504
+ return
3485
3505
3486
3506
def flatten_lvalues (self , lvalues : list [Expression ]) -> list [Expression ]:
3487
3507
res : list [Expression ] = []
@@ -3500,8 +3520,8 @@ def check_multi_assignment_from_tuple(
3500
3520
rvalue : Expression ,
3501
3521
rvalue_type : TupleType ,
3502
3522
context : Context ,
3503
- undefined_rvalue : bool ,
3504
3523
infer_lvalue_type : bool = True ,
3524
+ convert_star_rvalue_type : bool = True ,
3505
3525
) -> None :
3506
3526
if self .check_rvalue_count_in_assignment (lvalues , len (rvalue_type .items ), context ):
3507
3527
star_index = next (
@@ -3512,46 +3532,23 @@ def check_multi_assignment_from_tuple(
3512
3532
star_lv = cast (StarExpr , lvalues [star_index ]) if star_index != len (lvalues ) else None
3513
3533
right_lvs = lvalues [star_index + 1 :]
3514
3534
3515
- if not undefined_rvalue :
3516
- # Infer rvalue again, now in the correct type context.
3517
- lvalue_type = self .lvalue_type_for_inference (lvalues , rvalue_type )
3518
- reinferred_rvalue_type = get_proper_type (
3519
- self .expr_checker .accept (rvalue , lvalue_type )
3520
- )
3521
-
3522
- if isinstance (reinferred_rvalue_type , UnionType ):
3523
- # If this is an Optional type in non-strict Optional code, unwrap it.
3524
- relevant_items = reinferred_rvalue_type .relevant_items ()
3525
- if len (relevant_items ) == 1 :
3526
- reinferred_rvalue_type = get_proper_type (relevant_items [0 ])
3527
- if isinstance (reinferred_rvalue_type , UnionType ):
3528
- self .check_multi_assignment_from_union (
3529
- lvalues , rvalue , reinferred_rvalue_type , context , infer_lvalue_type
3530
- )
3531
- return
3532
- if isinstance (reinferred_rvalue_type , AnyType ):
3533
- # We can get Any if the current node is
3534
- # deferred. Doing more inference in deferred nodes
3535
- # is hard, so give up for now. We can also get
3536
- # here if reinferring types above changes the
3537
- # inferred return type for an overloaded function
3538
- # to be ambiguous.
3539
- return
3540
- assert isinstance (reinferred_rvalue_type , TupleType )
3541
- rvalue_type = reinferred_rvalue_type
3542
-
3543
3535
left_rv_types , star_rv_types , right_rv_types = self .split_around_star (
3544
3536
rvalue_type .items , star_index , len (lvalues )
3545
3537
)
3546
3538
3547
3539
for lv , rv_type in zip (left_lvs , left_rv_types ):
3548
3540
self .check_assignment (lv , self .temp_node (rv_type , context ), infer_lvalue_type )
3549
3541
if star_lv :
3550
- list_expr = ListExpr (
3551
- [self .temp_node (rv_type , context ) for rv_type in star_rv_types ]
3552
- )
3553
- list_expr .set_line (context )
3554
- self .check_assignment (star_lv .expr , list_expr , infer_lvalue_type )
3542
+ if convert_star_rvalue_type :
3543
+ list_expr = ListExpr (
3544
+ [self .temp_node (rv_type , context ) for rv_type in star_rv_types ]
3545
+ )
3546
+ list_expr .set_line (context )
3547
+ self .check_assignment (star_lv .expr , list_expr , infer_lvalue_type )
3548
+ else :
3549
+ self .check_assignment (
3550
+ star_lv .expr , self .temp_node (star_rv_types [0 ], context ), infer_lvalue_type
3551
+ )
3555
3552
for lv , rv_type in zip (right_lvs , right_rv_types ):
3556
3553
self .check_assignment (lv , self .temp_node (rv_type , context ), infer_lvalue_type )
3557
3554
@@ -3702,10 +3699,7 @@ def infer_variable_type(
3702
3699
"""Infer the type of initialized variables from initializer type."""
3703
3700
if isinstance (init_type , DeletedType ):
3704
3701
self .msg .deleted_as_rvalue (init_type , context )
3705
- elif (
3706
- not is_valid_inferred_type (init_type , is_lvalue_final = name .is_final )
3707
- and not self .no_partial_types
3708
- ):
3702
+ elif not is_valid_inferred_type (init_type , is_lvalue_final = name .is_final ):
3709
3703
# We cannot use the type of the initialization expression for full type
3710
3704
# inference (it's not specific enough), but we might be able to give
3711
3705
# partial type which will be made more specific later. A partial type
0 commit comments