Skip to content

Commit 4ed3d3a

Browse files
rchen152meta-codesync[bot]
authored andcommitted
Convert union of types to type of union in unions_internal (#2708)
Summary: Pull Request resolved: #2708 The spec says that `type` distributes over unions: https://typing.python.org/en/latest/spec/special-types.html#type, but we weren't always respecting this. Fixed by standardizing how we represent a union of types/type of union. I also fixed up a few places where we weren't handling `type[Union[...]]` correctly, which is a pre-existing issue that is more obvious now that we're standardizing to that form. Differential Revision: D95638889
1 parent 9805d0b commit 4ed3d3a

File tree

6 files changed

+119
-2
lines changed

6 files changed

+119
-2
lines changed

crates/pyrefly_types/src/simplify.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ fn unions_internal(
8080
promote_anonymous_typed_dicts(&mut res, stdlib, heap);
8181
}
8282
collapse_tuple_unions_with_empty(&mut res, heap);
83+
collapse_builtins_type(&mut res, heap);
8384
// `res` is collapsible again if `flatten_and_dedup` drops `xs` to 0 or 1 elements
8485
try_collapse(res, heap).unwrap_or_else(|members| heap.mk_union(members))
8586
})
@@ -320,6 +321,38 @@ fn flatten_unpacked_concrete_tuples(elts: Vec<Type>) -> Vec<Type> {
320321
result
321322
}
322323

324+
/// `type[int] | type[str]` => `type[int | str]`
325+
fn collapse_builtins_type(types: &mut Vec<Type>, heap: &TypeHeap) {
326+
let mut idx = 0;
327+
let mut first_elt = None;
328+
let mut additional_elts = Vec::new();
329+
types.retain(|t| {
330+
let retain = match t {
331+
Type::Type(box t) if first_elt.is_none() => {
332+
first_elt = Some((idx, t.clone()));
333+
true
334+
}
335+
Type::Type(box t) => {
336+
additional_elts.push(t.clone());
337+
false
338+
}
339+
_ => true,
340+
};
341+
idx += 1;
342+
retain
343+
});
344+
if let Some((idx, first_elt)) = first_elt
345+
&& !additional_elts.is_empty()
346+
{
347+
let mut elts = vec![first_elt.clone()];
348+
elts.extend(additional_elts);
349+
*(types
350+
.get_mut(idx)
351+
.expect("idx out of bounds when collapsing type members in union")) =
352+
heap.mk_type_form(heap.mk_union(elts));
353+
}
354+
}
355+
323356
// After a TypeVarTuple gets substituted with a tuple type, try to simplify the type
324357
pub fn simplify_tuples(tuple: Tuple, _heap: &TypeHeap) -> Tuple {
325358
match tuple {

pyrefly/lib/alt/call.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
234234
Type::Type(box Type::ClassType(cls)) => CallTargetLookup::Ok(Box::new(
235235
CallTarget::Class(cls, ConstructorKind::TypeOfClass, None),
236236
)),
237+
// `type[A | B]` is equivalent to `type[A] | type[B]` for call target resolution.
238+
// Distribute `type[...]` over union members and resolve as a union.
239+
Type::Type(box Type::Union(box Union { members: xs, .. })) => {
240+
let union_of_types = self
241+
.heap
242+
.mk_union(xs.into_iter().map(|x| self.heap.mk_type_form(x)).collect());
243+
self.as_call_target_impl(union_of_types, quantified)
244+
}
237245
Type::Type(box Type::SelfType(cls)) => CallTargetLookup::Ok(Box::new(
238246
CallTarget::Class(cls, ConstructorKind::TypeOfSelf, None),
239247
)),
@@ -326,7 +334,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
326334
}
327335
Type::Any(style) => CallTargetLookup::Ok(Box::new(CallTarget::Any(style))),
328336
Type::TypeAlias(ta) => {
329-
self.as_call_target_impl(self.get_type_alias(&ta).as_value(self.stdlib), quantified)
337+
let body = self.get_type_alias(&ta).as_value(self.stdlib);
338+
match body {
339+
// This comes from an expression like `int | str`, which is not callable.
340+
Type::Type(box Type::Union(_)) => CallTargetLookup::Error(vec![]),
341+
_ => self.as_call_target_impl(body, quantified),
342+
}
330343
}
331344
Type::ClassType(cls) => {
332345
let maybe_dunder_call = if let Some(quantified) = &quantified {

pyrefly/lib/report/pysa/types.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,20 @@ fn get_classes_of_type(type_: &Type, context: &ModuleContext) -> ClassNamesFromT
292292
ClassNamesFromType::from_class(class_type.class_object(), context)
293293
.prepend_modifier(TypeModifier::Type)
294294
}
295+
Type::Type(box Type::Union(box Union {
296+
members: elements, ..
297+
})) if !elements.is_empty() => elements
298+
.iter()
299+
.map(|inner| match inner {
300+
Type::ClassType(class_type) => {
301+
ClassNamesFromType::from_class(class_type.class_object(), context)
302+
.prepend_modifier(TypeModifier::Type)
303+
}
304+
_ => ClassNamesFromType::not_a_class(),
305+
})
306+
.reduce(|acc, next| acc.join_with(next))
307+
.expect("expected at least one element in union")
308+
.sort_and_dedup(),
295309
Type::Tuple(_) => ClassNamesFromType::from_class(context.stdlib.tuple_object(), context),
296310
Type::TypedDict(TypedDict::TypedDict(inner)) => {
297311
ClassNamesFromType::from_class(inner.class_object(), context)

pyrefly/lib/test/pysa/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class MyTypedDict(TypedDict):
492492

493493
assert_eq!(
494494
PysaType::new(
495-
"type[test.A] | type[test.B]".to_owned(),
495+
"type[test.A | test.B]".to_owned(),
496496
ClassNamesFromType::from_classes(
497497
vec![
498498
get_class_ref("test", "A", &context),

pyrefly/lib/test/simple.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,54 @@ def f(x: type[int] | type[str], y: type[int | str]) -> None:
5858
"#,
5959
);
6060

61+
testcase!(
62+
test_distribute_type_assignability,
63+
r#"
64+
def f() -> type[int | str]: ...
65+
def g() -> type[int] | type[str]: ...
66+
x: type[int] | type[str] = f()
67+
y: type[int | str] = g()
68+
"#,
69+
);
70+
71+
testcase!(
72+
test_type_of_union_matches_type_param,
73+
r#"
74+
from typing import assert_type
75+
def f[T](x: type[T]) -> T: ...
76+
assert_type(f(int | None), int | None)
77+
"#,
78+
);
79+
80+
testcase!(
81+
test_type_of_union_matches_type_param_or_none,
82+
r#"
83+
from typing import assert_type
84+
def f[T](x: type[T] | None) -> T: ...
85+
assert_type(f(int | str), int | str)
86+
"#,
87+
);
88+
89+
testcase!(
90+
test_type_of_union_partially_matches_type_param,
91+
r#"
92+
from typing import assert_type
93+
def f[T](x: type[T] | type[int]) -> T: ...
94+
def g(x: type[int | str]):
95+
assert_type(f(x), str)
96+
"#,
97+
);
98+
99+
testcase!(
100+
test_union_of_type_matches_type_param,
101+
r#"
102+
from typing import assert_type
103+
def f[T](x: type[T]) -> T: ...
104+
def g(x: type[int] | type[str]):
105+
assert_type(f(x), int | str)
106+
"#,
107+
);
108+
61109
testcase!(
62110
test_simple_call,
63111
r#"

pyrefly/lib/test/type_alias.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,3 +1262,12 @@ def f3(x: C) -> None:
12621262
assert_type(x, list[int])
12631263
"#,
12641264
);
1265+
1266+
testcase!(
1267+
test_union_is_not_callable,
1268+
r#"
1269+
from typing import TypeAlias
1270+
X: TypeAlias = int | str
1271+
X() # E: Expected a callable
1272+
"#,
1273+
);

0 commit comments

Comments
 (0)