Skip to content

Commit e63faf3

Browse files
grievejiameta-codesync[bot]
authored andcommitted
Fix unwrap_iterable for union types by distributing over members
Summary: The `unwrap_iterable` function checks whether a type is a subtype of `Iterable[T]` by creating a fresh type variable `T` and calling `is_subset_eq`. For union types like `tuple[int, ...] | tuple[str, ...]`, the solver pins `T` on the first member (e.g. `T = int`), then rejects later members with different element types (e.g. `str`). Fix this by using `distribute_over_union` to check each member independently with its own fresh type variable, then union the results. This also fixes `unwrap_async_iterable` which had the same issue. This addresses a regression from D95478317 where `tuple()` constructor calls now return structural `Type::Tuple` instead of nominal `ClassType`. When the result is a union of structural tuples used in star unpacking (`*expr`), pyrefly incorrectly reported a `not-iterable` error because `unwrap_iterable` couldn't handle the union. Reviewed By: ndmitchell Differential Revision: D95621346 fbshipit-source-id: 66b6ba5be8f7b6dbeca5f69bc5af74b14a70d72f
1 parent 13849da commit e63faf3

File tree

2 files changed

+51
-18
lines changed

2 files changed

+51
-18
lines changed

pyrefly/lib/alt/unwrap.rs

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -259,28 +259,42 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
259259

260260
/// Warning: this returns `Some` if the type is `Any` or a class that extends `Any`
261261
pub fn unwrap_iterable(&self, ty: &Type) -> Option<Type> {
262-
let iter_ty = self.fresh_var();
263-
let iterable_ty = self
264-
.heap
265-
.mk_class_type(self.stdlib.iterable(iter_ty.to_type(self.heap)));
266-
if self.is_subset_eq(ty, &iterable_ty) {
267-
Some(self.resolve_var(ty, iter_ty))
268-
} else {
269-
None
270-
}
262+
// Distribute over union members so each gets its own fresh type
263+
// variable. Checking the whole union at once fails because the solver
264+
// pins the variable on the first member, then rejects later members
265+
// with different element types.
266+
let mut failed = false;
267+
let result = self.distribute_over_union(ty, |member| {
268+
let iter_ty = self.fresh_var();
269+
let iterable_ty = self
270+
.heap
271+
.mk_class_type(self.stdlib.iterable(iter_ty.to_type(self.heap)));
272+
if self.is_subset_eq(member, &iterable_ty) {
273+
self.resolve_var(member, iter_ty)
274+
} else {
275+
failed = true;
276+
self.heap.mk_never()
277+
}
278+
});
279+
if failed { None } else { Some(result) }
271280
}
272281

273282
/// Warning: this returns `Some` if the type is `Any` or a class that extends `Any`
274283
pub fn unwrap_async_iterable(&self, ty: &Type) -> Option<Type> {
275-
let iter_ty = self.fresh_var();
276-
let iterable_ty = self
277-
.heap
278-
.mk_class_type(self.stdlib.async_iterable(iter_ty.to_type(self.heap)));
279-
if self.is_subset_eq(ty, &iterable_ty) {
280-
Some(self.resolve_var(ty, iter_ty))
281-
} else {
282-
None
283-
}
284+
let mut failed = false;
285+
let result = self.distribute_over_union(ty, |member| {
286+
let iter_ty = self.fresh_var();
287+
let iterable_ty = self
288+
.heap
289+
.mk_class_type(self.stdlib.async_iterable(iter_ty.to_type(self.heap)));
290+
if self.is_subset_eq(member, &iterable_ty) {
291+
self.resolve_var(member, iter_ty)
292+
} else {
293+
failed = true;
294+
self.heap.mk_never()
295+
}
296+
});
297+
if failed { None } else { Some(result) }
284298
}
285299

286300
/// Warning: this returns `Some` if the type is `Any` or a class that extends `Any`

pyrefly/lib/test/tuple.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,25 @@ assert_type(m, MyTuple)
429429
"#,
430430
);
431431

432+
testcase!(
433+
test_star_unpack_single_unbounded_tuple,
434+
r#"
435+
from typing import assert_type
436+
def test(x: tuple[int, ...]) -> None:
437+
y = (*x,)
438+
"#,
439+
);
440+
441+
testcase!(
442+
test_star_unpack_union_of_tuples,
443+
r#"
444+
from typing import assert_type
445+
def f() -> tuple[int, ...] | tuple[str, ...]:
446+
...
447+
x = (*f(),)
448+
"#,
449+
);
450+
432451
testcase!(
433452
test_tuple_aug_assign,
434453
r#"

0 commit comments

Comments
 (0)