Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/pyrefly_types/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::class::ClassType;
use crate::literal::Lit;
use crate::stdlib::Stdlib;
use crate::tuple::Tuple;
use crate::typed_dict::TypedDict;
use crate::types::Type;
use crate::types::Union;

Expand Down Expand Up @@ -74,6 +75,7 @@ fn unions_internal(
let mut res = flatten_and_dedup(xs);
if let Some(stdlib) = stdlib {
collapse_literals(&mut res, stdlib, enum_members.unwrap_or(&|_| None));
promote_anonymous_typed_dicts(&mut res, stdlib);
}
collapse_tuple_unions_with_empty(&mut res);
// `res` is collapsible again if `flatten_and_dedup` drops `xs` to 0 or 1 elements
Expand Down Expand Up @@ -254,6 +256,17 @@ fn collapse_literals(
}
}

/// Promote anonymous typed dicts to `dict[str, value_type]`
fn promote_anonymous_typed_dicts(types: &mut [Type], stdlib: &Stdlib) {
for ty in types.iter_mut() {
if let Type::TypedDict(TypedDict::Anonymous(inner)) = ty {
*ty = stdlib
.dict(stdlib.str().clone().to_type(), inner.value_type.clone())
.to_type();
}
}
}

fn collapse_tuple_unions_with_empty(types: &mut Vec<Type>) {
let Some(empty_idx) = types.iter().position(|t| match t {
Type::Tuple(Tuple::Concrete(elts)) => elts.is_empty(),
Expand Down
45 changes: 24 additions & 21 deletions pyrefly/lib/state/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1468,27 +1468,30 @@ impl<'a> Transaction<'a> {
// Check; demand; check - the second check is guaranteed to work.
for _ in 0..2 {
let lock = module_data.state.read();
if let Some(solutions) = &lock.steps.solutions
&& lock.epochs.checked == self.data.now
&& lock.steps.last_step == Some(Step::Solutions)
{
return solutions.get_hashed_opt(key).duped();
} else if let Some(answers) = &lock.steps.answers {
let load = lock.steps.load.dupe().unwrap();
let answers = answers.dupe();
drop(lock);
let stdlib = self.get_stdlib(&module_data.handle);
let lookup = self.lookup(module_data);
return answers.1.solve_exported_key(
&lookup,
&lookup,
&answers.0,
&load.errors,
&stdlib,
&self.data.state.uniques,
key,
thread_state,
);
if lock.epochs.checked == self.data.now {
// Only use existing solutions or answers if the module data is current.
// Otherwise, the module might be dirty and require computation.
if let Some(solutions) = &lock.steps.solutions
&& lock.steps.last_step == Some(Step::Solutions)
{
return solutions.get_hashed_opt(key).duped();
} else if let Some(answers) = &lock.steps.answers {
let load = lock.steps.load.dupe().unwrap();
let answers = answers.dupe();
drop(lock);
let stdlib = self.get_stdlib(&module_data.handle);
let lookup = self.lookup(module_data);
return answers.1.solve_exported_key(
&lookup,
&lookup,
&answers.0,
&load.errors,
&stdlib,
&self.data.state.uniques,
key,
thread_state,
);
}
}
drop(lock);
self.demand(&module_data, Step::Answers);
Expand Down
3 changes: 2 additions & 1 deletion pyrefly/lib/test/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2284,7 +2284,8 @@ class A:
self.y = {"x": 0} if check else 42
def f(a: A):
x: TD = a.x
y: TD | int = a.y
# anoynmous typed dicts are promoted away when unioned
y: dict[str, int] | int = a.y
"#,
);

Expand Down
17 changes: 16 additions & 1 deletion pyrefly/lib/test/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ dict(x = 1, y = "test")
"#,
);

testcase!(
test_anonymous_typed_dict_union_promotion,
r#"
from typing import assert_type

def test(cond: bool):
x = {"a": 1, "b": "2"}
y = {"a": 1, "b": "2", "c": 3}
# we promote anonymous typed dicts when unioning
z = x if cond else y
assert_type(z["a"], int | str)
assert_type(z, dict[str, int | str])
"#,
);

testcase!(
test_unpack_empty,
r#"
Expand Down Expand Up @@ -48,6 +63,6 @@ def bar(yes: bool) -> None:
else:
kwargs = {"goodbye": 1}

foo(**kwargs)
foo(**kwargs)
"#,
);
Loading