Skip to content

Commit d7a8c04

Browse files
committed
Fix checking multiple assignments based on tuple unpacking involving partially initialised variables (Fixes python#12915).
This proposal is an alternative to python#14423. Similar to python#14423, the main idea is to convert unions of tuples to tuples of (simplified) unions during multi-assignment checks. In addition, it extends this idea to other iterable types, which allows removing the `undefined_rvalue` logic and the `no_partial_types` logic. Hence, the problem reported in python#12915 with partially initialised variables should be fixed for unions that combine, for example, tuples and lists, as well. Besides the new test case also provided by python#14423 (`testDefinePartiallyInitialisedVariableDuringTupleUnpacking`), this commit also adds the test cases `testUnionUnpackingIncludingListPackingSameItemTypes`, `testUnionUnpackingIncludingListPackingDifferentItemTypes`, and `testUnionUnpackingIncludingListPackingForVariousItemTypes`.
1 parent c4a5f56 commit d7a8c04

File tree

6 files changed

+301
-131
lines changed

6 files changed

+301
-131
lines changed

mypy/checker.py

Lines changed: 91 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@
190190
NoneType,
191191
Overloaded,
192192
PartialType,
193+
PlaceholderType,
193194
ProperType,
194195
StarType,
195196
TupleType,
@@ -338,8 +339,6 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
338339
# Used for collecting inferred attribute types so that they can be checked
339340
# for consistency.
340341
inferred_attribute_types: dict[Var, Type] | None = None
341-
# Don't infer partial None types if we are processing assignment from Union
342-
no_partial_types: bool = False
343342

344343
# The set of all dependencies (suppressed or not) that this module accesses, either
345344
# directly or indirectly.
@@ -3375,7 +3374,6 @@ def check_multi_assignment(
33753374
context: Context,
33763375
infer_lvalue_type: bool = True,
33773376
rv_type: Type | None = None,
3378-
undefined_rvalue: bool = False,
33793377
) -> None:
33803378
"""Check the assignment of one rvalue to a number of lvalues."""
33813379

@@ -3386,12 +3384,6 @@ def check_multi_assignment(
33863384
if isinstance(rvalue_type, TypeVarLikeType):
33873385
rvalue_type = get_proper_type(rvalue_type.upper_bound)
33883386

3389-
if isinstance(rvalue_type, UnionType):
3390-
# If this is an Optional type in non-strict Optional code, unwrap it.
3391-
relevant_items = rvalue_type.relevant_items()
3392-
if len(relevant_items) == 1:
3393-
rvalue_type = get_proper_type(relevant_items[0])
3394-
33953387
if isinstance(rvalue_type, AnyType):
33963388
for lv in lvalues:
33973389
if isinstance(lv, StarExpr):
@@ -3402,7 +3394,7 @@ def check_multi_assignment(
34023394
self.check_assignment(lv, temp_node, infer_lvalue_type)
34033395
elif isinstance(rvalue_type, TupleType):
34043396
self.check_multi_assignment_from_tuple(
3405-
lvalues, rvalue, rvalue_type, context, undefined_rvalue, infer_lvalue_type
3397+
lvalues, rvalue, rvalue_type, context, infer_lvalue_type
34063398
)
34073399
elif isinstance(rvalue_type, UnionType):
34083400
self.check_multi_assignment_from_union(
@@ -3430,58 +3422,86 @@ def check_multi_assignment_from_union(
34303422
x, y = t
34313423
reveal_type(x) # Union[int, str]
34323424
3433-
The idea in this case is to process the assignment for every item of the union.
3434-
Important note: the types are collected in two places, 'union_types' contains
3435-
inferred types for first assignments, 'assignments' contains the narrowed types
3436-
for binder.
3425+
The idea is to convert unions of tuples or other iterables to tuples of (simplified)
3426+
unions and then simply apply `check_multi_assignment_from_tuple`.
34373427
"""
3438-
self.no_partial_types = True
3439-
transposed: tuple[list[Type], ...] = tuple([] for _ in self.flatten_lvalues(lvalues))
3440-
# Notify binder that we want to defer bindings and instead collect types.
3441-
with self.binder.accumulate_type_assignments() as assignments:
3442-
for item in rvalue_type.items:
3443-
# Type check the assignment separately for each union item and collect
3444-
# the inferred lvalue types for each union item.
3445-
self.check_multi_assignment(
3446-
lvalues,
3447-
rvalue,
3448-
context,
3449-
infer_lvalue_type=infer_lvalue_type,
3450-
rv_type=item,
3451-
undefined_rvalue=True,
3452-
)
3453-
for t, lv in zip(transposed, self.flatten_lvalues(lvalues)):
3454-
# We can access _type_maps directly since temporary type maps are
3455-
# only created within expressions.
3456-
t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form)))
3457-
union_types = tuple(make_simplified_union(col) for col in transposed)
3458-
for expr, items in assignments.items():
3459-
# Bind a union of types collected in 'assignments' to every expression.
3460-
if isinstance(expr, StarExpr):
3461-
expr = expr.expr
3462-
3463-
# TODO: See todo in binder.py, ConditionalTypeBinder.assign_type
3464-
# It's unclear why the 'declared_type' param is sometimes 'None'
3465-
clean_items: list[tuple[Type, Type]] = []
3466-
for type, declared_type in items:
3467-
assert declared_type is not None
3468-
clean_items.append((type, declared_type))
3469-
3470-
types, declared_types = zip(*clean_items)
3471-
self.binder.assign_type(
3472-
expr,
3473-
make_simplified_union(list(types)),
3474-
make_simplified_union(list(declared_types)),
3475-
False,
3428+
# if `rvalue_type` is Optional type in non-strict Optional code, unwap it:
3429+
relevant_items = rvalue_type.relevant_items()
3430+
if len(relevant_items) == 1:
3431+
self.check_multi_assignment(
3432+
lvalues, rvalue, context, infer_lvalue_type, relevant_items[0]
34763433
)
3477-
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
3478-
# Properly store the inferred types.
3479-
_1, _2, inferred = self.check_lvalue(lv)
3480-
if inferred:
3481-
self.set_inferred_type(inferred, lv, union)
3434+
return
3435+
3436+
# union to tuple conversion:
3437+
star_idx = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), None)
3438+
3439+
def handle_star_index(orig_types: list[Type]) -> list[Type]:
3440+
if star_idx is not None:
3441+
orig_types[star_idx] = self.named_generic_type(
3442+
"builtins.list", [orig_types[star_idx]]
3443+
)
3444+
return orig_types
3445+
3446+
nmb_subitems = len(lvalues)
3447+
items: list[list[Type]] = []
3448+
for idx, item in enumerate(rvalue_type.items):
3449+
item = get_proper_type(item)
3450+
if isinstance(item, TupleType):
3451+
delta = len(item.items) - nmb_subitems
3452+
if star_idx is None:
3453+
if delta != 0: # a, b = x, y, z or a, b, c = x, y
3454+
self.msg.wrong_number_values_to_unpack(
3455+
len(item.items), nmb_subitems, context
3456+
)
3457+
return
3458+
items.append(item.items.copy()) # a, b = x, y
3459+
else:
3460+
if delta < -1: # a, b, c, *d = x, y
3461+
self.msg.wrong_number_values_to_unpack(
3462+
len(item.items), nmb_subitems - 1, context
3463+
)
3464+
return
3465+
if delta == -1: # a, b, *c = x, y
3466+
items.append(item.items.copy())
3467+
# to be removed after transposing:
3468+
items[-1].insert(star_idx, PlaceholderType("temp", [], -1))
3469+
elif delta == 0: # a, b, *c = x, y, z
3470+
items.append(handle_star_index(item.items.copy()))
3471+
else: # a, *b = x, y, z
3472+
union = make_simplified_union(item.items[star_idx : star_idx + delta + 1])
3473+
subitems = (
3474+
item.items[:star_idx] + [union] + item.items[star_idx + delta + 1 :]
3475+
)
3476+
items.append(handle_star_index(subitems))
34823477
else:
3483-
self.store_type(lv, union)
3484-
self.no_partial_types = False
3478+
if isinstance(item, AnyType):
3479+
items.append(handle_star_index(nmb_subitems * [cast(Type, item)]))
3480+
elif isinstance(item, Instance) and (item.type.fullname == "builtins.str"):
3481+
self.msg.unpacking_strings_disallowed(context)
3482+
return
3483+
elif isinstance(item, Instance) and self.type_is_iterable(item):
3484+
items.append(handle_star_index(nmb_subitems * [self.iterable_item_type(item)]))
3485+
else:
3486+
self.msg.type_not_iterable(item, context)
3487+
return
3488+
items_transposed = zip(*items)
3489+
items_cleared = []
3490+
for subitems_ in items_transposed:
3491+
subitems = []
3492+
for item in subitems_:
3493+
item = get_proper_type(item)
3494+
if not isinstance(item, PlaceholderType):
3495+
subitems.append(item)
3496+
items_cleared.append(subitems)
3497+
tupletype = TupleType(
3498+
[make_simplified_union(subitems) for subitems in items_cleared],
3499+
fallback=self.named_type("builtins.tuple"),
3500+
)
3501+
self.check_multi_assignment_from_tuple(
3502+
lvalues, rvalue, tupletype, context, infer_lvalue_type, False
3503+
)
3504+
return
34853505

34863506
def flatten_lvalues(self, lvalues: list[Expression]) -> list[Expression]:
34873507
res: list[Expression] = []
@@ -3500,8 +3520,8 @@ def check_multi_assignment_from_tuple(
35003520
rvalue: Expression,
35013521
rvalue_type: TupleType,
35023522
context: Context,
3503-
undefined_rvalue: bool,
35043523
infer_lvalue_type: bool = True,
3524+
convert_star_rvalue_type: bool = True,
35053525
) -> None:
35063526
if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context):
35073527
star_index = next(
@@ -3512,46 +3532,23 @@ def check_multi_assignment_from_tuple(
35123532
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
35133533
right_lvs = lvalues[star_index + 1 :]
35143534

3515-
if not undefined_rvalue:
3516-
# Infer rvalue again, now in the correct type context.
3517-
lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type)
3518-
reinferred_rvalue_type = get_proper_type(
3519-
self.expr_checker.accept(rvalue, lvalue_type)
3520-
)
3521-
3522-
if isinstance(reinferred_rvalue_type, UnionType):
3523-
# If this is an Optional type in non-strict Optional code, unwrap it.
3524-
relevant_items = reinferred_rvalue_type.relevant_items()
3525-
if len(relevant_items) == 1:
3526-
reinferred_rvalue_type = get_proper_type(relevant_items[0])
3527-
if isinstance(reinferred_rvalue_type, UnionType):
3528-
self.check_multi_assignment_from_union(
3529-
lvalues, rvalue, reinferred_rvalue_type, context, infer_lvalue_type
3530-
)
3531-
return
3532-
if isinstance(reinferred_rvalue_type, AnyType):
3533-
# We can get Any if the current node is
3534-
# deferred. Doing more inference in deferred nodes
3535-
# is hard, so give up for now. We can also get
3536-
# here if reinferring types above changes the
3537-
# inferred return type for an overloaded function
3538-
# to be ambiguous.
3539-
return
3540-
assert isinstance(reinferred_rvalue_type, TupleType)
3541-
rvalue_type = reinferred_rvalue_type
3542-
35433535
left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
35443536
rvalue_type.items, star_index, len(lvalues)
35453537
)
35463538

35473539
for lv, rv_type in zip(left_lvs, left_rv_types):
35483540
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
35493541
if star_lv:
3550-
list_expr = ListExpr(
3551-
[self.temp_node(rv_type, context) for rv_type in star_rv_types]
3552-
)
3553-
list_expr.set_line(context)
3554-
self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type)
3542+
if convert_star_rvalue_type:
3543+
list_expr = ListExpr(
3544+
[self.temp_node(rv_type, context) for rv_type in star_rv_types]
3545+
)
3546+
list_expr.set_line(context)
3547+
self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type)
3548+
else:
3549+
self.check_assignment(
3550+
star_lv.expr, self.temp_node(star_rv_types[0], context), infer_lvalue_type
3551+
)
35553552
for lv, rv_type in zip(right_lvs, right_rv_types):
35563553
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
35573554

@@ -3702,10 +3699,7 @@ def infer_variable_type(
37023699
"""Infer the type of initialized variables from initializer type."""
37033700
if isinstance(init_type, DeletedType):
37043701
self.msg.deleted_as_rvalue(init_type, context)
3705-
elif (
3706-
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
3707-
and not self.no_partial_types
3708-
):
3702+
elif not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final):
37093703
# We cannot use the type of the initialization expression for full type
37103704
# inference (it's not specific enough), but we might be able to give
37113705
# partial type which will be made more specific later. A partial type

test-data/unit/check-inference-context.test

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,8 @@ if int():
168168
ab, ao = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
169169
if int():
170170
ao, ab = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
171-
172171
if int():
173-
ao, ao = f(b)
172+
ao, ao = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
174173
if int():
175174
ab, ab = f(b)
176175
if int():
@@ -199,11 +198,10 @@ if int():
199198
ao, ab, ab, ab = h(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
200199
if int():
201200
ab, ab, ao, ab = h(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
202-
203201
if int():
204-
ao, ab, ab = f(b, b)
202+
ao, ab, ab = f(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
205203
if int():
206-
ab, ab, ao = g(b, b)
204+
ab, ab, ao = g(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
207205
if int():
208206
ab, ab, ab, ab = h(b, b)
209207

test-data/unit/check-inference.test

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,6 +1919,54 @@ class C:
19191919
a = 42
19201920
[out]
19211921

1922+
[case testDefinePartiallyInitialisedVariableDuringTupleUnpacking]
1923+
# flags: --strict-optional
1924+
from typing import Tuple, Union
1925+
1926+
t1: Union[Tuple[None], Tuple[str]]
1927+
x1 = None
1928+
x1, = t1
1929+
reveal_type(x1) # N: Revealed type is "Union[None, builtins.str]"
1930+
1931+
t2: Union[Tuple[str], Tuple[None]]
1932+
x2 = None
1933+
x2, = t2
1934+
reveal_type(x2) # N: Revealed type is "Union[builtins.str, None]"
1935+
1936+
t3: Union[Tuple[int], Tuple[str]]
1937+
x3 = None
1938+
x3, = t3
1939+
reveal_type(x3) # N: Revealed type is "Union[builtins.int, builtins.str]"
1940+
1941+
def f() -> Union[
1942+
Tuple[None, None, None, int, int, int, int, int, int],
1943+
Tuple[None, None, None, int, int, int, str, str, str]
1944+
]: ...
1945+
a1 = None
1946+
b1 = None
1947+
c1 = None
1948+
a2: object
1949+
b2: object
1950+
c2: object
1951+
a1, a2, a3, b1, b2, b3, c1, c2, c3 = f()
1952+
reveal_type(a1) # N: Revealed type is "None"
1953+
reveal_type(a2) # N: Revealed type is "None"
1954+
reveal_type(a3) # N: Revealed type is "None"
1955+
reveal_type(b1) # N: Revealed type is "builtins.int"
1956+
reveal_type(b2) # N: Revealed type is "builtins.int"
1957+
reveal_type(b3) # N: Revealed type is "builtins.int"
1958+
reveal_type(c1) # N: Revealed type is "Union[builtins.int, builtins.str]"
1959+
reveal_type(c2) # N: Revealed type is "Union[builtins.int, builtins.str]"
1960+
reveal_type(c3) # N: Revealed type is "Union[builtins.int, builtins.str]"
1961+
1962+
tt: Tuple[Union[Tuple[None], Tuple[str], Tuple[int]]]
1963+
z = None
1964+
z, = tt[0]
1965+
reveal_type(z) # N: Revealed type is "Union[None, builtins.str, builtins.int]"
1966+
1967+
[builtins fixtures/tuple.pyi]
1968+
1969+
19221970
-- More partial type errors
19231971
-- ------------------------
19241972

test-data/unit/check-tuples.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,8 @@ def g(x: T) -> Tuple[T, T]:
10541054
return (x, x)
10551055

10561056
z = 1
1057-
x, y = g(z) # E: Argument 1 to "g" has incompatible type "int"; expected "Tuple[B1, B2]"
1057+
x, y = g(z) # E: Incompatible types in assignment (expression has type "int", variable has type "Tuple[A, ...]") \
1058+
# E: Incompatible types in assignment (expression has type "int", variable has type "Tuple[Union[B1, C], Union[B2, C]]")
10581059
[builtins fixtures/tuple.pyi]
10591060
[out]
10601061

0 commit comments

Comments
 (0)