Skip to content

Commit cda376a

Browse files
Renkai and carljm authored
[ty] eliminate definitely-impossible types from union in equality narrowing (#20164)
solves astral-sh/ty#939 --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent b14fc96 commit cda376a

File tree

3 files changed

+193
-26
lines changed

3 files changed

+193
-26
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,106 @@ if (x := f()) in (1,):
9292
else:
9393
reveal_type(x) # revealed: Literal[2, 3]
9494
```
95+
96+
## Union with `Literal`, `None` and `int`
97+
98+
```py
99+
from typing import Literal
100+
101+
def test(x: Literal["a", "b", "c"] | None | int = None):
102+
if x in ("a", "b"):
103+
# int is included because custom __eq__ methods could make
104+
# an int equal to "a" or "b", so we can't eliminate it
105+
reveal_type(x) # revealed: Literal["a", "b"] | int
106+
else:
107+
reveal_type(x) # revealed: Literal["c"] | None | int
108+
```
109+
110+
## Direct `not in` conditional
111+
112+
```py
113+
from typing import Literal
114+
115+
def test(x: Literal["a", "b", "c"] | None | int = None):
116+
if x not in ("a", "c"):
117+
# int is included because custom __eq__ methods could make
118+
# an int equal to "a" or "c", so we can't eliminate it
119+
reveal_type(x) # revealed: Literal["b"] | None | int
120+
else:
121+
reveal_type(x) # revealed: Literal["a", "c"] | int
122+
```
123+
124+
## bool
125+
126+
```py
127+
def _(x: bool):
128+
if x in (True,):
129+
reveal_type(x) # revealed: Literal[True]
130+
else:
131+
reveal_type(x) # revealed: Literal[False]
132+
133+
def _(x: bool | str):
134+
if x in (False,):
135+
# `str` remains due to possible custom __eq__ methods on a subclass
136+
reveal_type(x) # revealed: Literal[False] | str
137+
else:
138+
reveal_type(x) # revealed: Literal[True] | str
139+
```
140+
141+
## LiteralString
142+
143+
```py
144+
from typing_extensions import LiteralString
145+
146+
def _(x: LiteralString):
147+
if x in ("a", "b", "c"):
148+
reveal_type(x) # revealed: Literal["a", "b", "c"]
149+
else:
150+
reveal_type(x) # revealed: LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"]
151+
152+
def _(x: LiteralString | int):
153+
if x in ("a", "b", "c"):
154+
reveal_type(x) # revealed: Literal["a", "b", "c"] | int
155+
else:
156+
reveal_type(x) # revealed: (LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"]) | int
157+
```
158+
159+
## enums
160+
161+
```py
162+
from enum import Enum
163+
164+
class Color(Enum):
165+
RED = "red"
166+
GREEN = "green"
167+
BLUE = "blue"
168+
169+
def _(x: Color):
170+
if x in (Color.RED, Color.GREEN):
171+
# TODO should be `Literal[Color.RED, Color.GREEN]`
172+
reveal_type(x) # revealed: Color
173+
else:
174+
# TODO should be `Literal[Color.BLUE]`
175+
reveal_type(x) # revealed: Color
176+
```
177+
178+
## Union with enum and `int`
179+
180+
```py
181+
from enum import Enum
182+
183+
class Status(Enum):
184+
PENDING = 1
185+
APPROVED = 2
186+
REJECTED = 3
187+
188+
def test(x: Status | int):
189+
if x in (Status.PENDING, Status.APPROVED):
190+
# TODO should be `Literal[Status.PENDING, Status.APPROVED] | int`
191+
# int is included because custom __eq__ methods could make
192+
# an int equal to Status.PENDING or Status.APPROVED, so we can't eliminate it
193+
reveal_type(x) # revealed: Status | int
194+
else:
195+
# TODO should be `Literal[Status.REJECTED] | int`
196+
reveal_type(x) # revealed: Status | int
197+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,16 @@ impl<'db> Type<'db> {
10541054
|| self.is_literal_string()
10551055
}
10561056

1057+
pub(crate) fn is_union_with_single_valued(&self, db: &'db dyn Db) -> bool {
1058+
self.into_union().is_some_and(|union| {
1059+
union
1060+
.elements(db)
1061+
.iter()
1062+
.any(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
1063+
}) || self.is_bool(db)
1064+
|| self.is_literal_string()
1065+
}
1066+
10571067
pub(crate) fn into_string_literal(self) -> Option<StringLiteralType<'db>> {
10581068
match self {
10591069
Type::StringLiteral(string_literal) => Some(string_literal),
@@ -9953,14 +9963,6 @@ impl<'db> StringLiteralType<'db> {
99539963
pub(crate) fn python_len(self, db: &'db dyn Db) -> usize {
99549964
self.value(db).chars().count()
99559965
}
9956-
9957-
/// Return an iterator over each character in the string literal.
9958-
/// as would be returned by Python's `iter()`.
9959-
pub(crate) fn iter_each_char(self, db: &'db dyn Db) -> impl Iterator<Item = Self> {
9960-
self.value(db)
9961-
.chars()
9962-
.map(|c| StringLiteralType::new(db, c.to_string().into_boxed_str()))
9963-
}
99649966
}
99659967

99669968
/// # Ordering

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -615,24 +615,88 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
615615
}
616616
}
617617

618+
// TODO `expr_in` and `expr_not_in` should perhaps be unified with `expr_eq` and `expr_ne`,
619+
// since `eq` and `ne` are equivalent to `in` and `not in` with only one element in the RHS.
618620
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
619621
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
620-
if let Type::StringLiteral(string_literal) = rhs_ty {
621-
Some(UnionType::from_elements(
622-
self.db,
623-
string_literal
624-
.iter_each_char(self.db)
625-
.map(Type::StringLiteral),
626-
))
627-
} else if let Some(tuple_spec) = rhs_ty.tuple_instance_spec(self.db) {
628-
// N.B. Strictly speaking this is unsound, since a tuple subclass might override `__contains__`
629-
// but we'd still apply the narrowing here. This seems unlikely, however, and narrowing is
630-
// generally unsound in numerous ways anyway (attribute narrowing, subscript, narrowing,
631-
// narrowing of globals, etc.). So this doesn't seem worth worrying about too much.
632-
Some(UnionType::from_elements(self.db, tuple_spec.all_elements()))
633-
} else {
634-
None
622+
rhs_ty
623+
.try_iterate(self.db)
624+
.ok()
625+
.map(|iterable| iterable.homogeneous_element_type(self.db))
626+
} else if lhs_ty.is_union_with_single_valued(self.db) {
627+
let rhs_values = rhs_ty
628+
.try_iterate(self.db)
629+
.ok()?
630+
.homogeneous_element_type(self.db);
631+
632+
let mut builder = UnionBuilder::new(self.db);
633+
634+
// Add the narrowed values from the RHS first, to keep literals before broader types.
635+
builder = builder.add(rhs_values);
636+
637+
if let Some(lhs_union) = lhs_ty.into_union() {
638+
for element in lhs_union.elements(self.db) {
639+
// Keep only the non-single-valued portion of the original type.
640+
if !element.is_single_valued(self.db)
641+
&& !element.is_literal_string()
642+
&& !element.is_bool(self.db)
643+
{
644+
builder = builder.add(*element);
645+
}
646+
}
635647
}
648+
Some(builder.build())
649+
} else {
650+
None
651+
}
652+
}
653+
654+
fn evaluate_expr_not_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
655+
let rhs_values = rhs_ty
656+
.try_iterate(self.db)
657+
.ok()?
658+
.homogeneous_element_type(self.db);
659+
660+
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
661+
// Exclude the RHS values from the entire (single-valued) LHS domain.
662+
let complement = IntersectionBuilder::new(self.db)
663+
.add_positive(lhs_ty)
664+
.add_negative(rhs_values)
665+
.build();
666+
Some(complement)
667+
} else if lhs_ty.is_union_with_single_valued(self.db) {
668+
// Split LHS into single-valued portion and the rest. Exclude RHS values from the
669+
// single-valued portion, keep the rest intact.
670+
let mut single_builder = UnionBuilder::new(self.db);
671+
let mut rest_builder = UnionBuilder::new(self.db);
672+
673+
if let Some(lhs_union) = lhs_ty.into_union() {
674+
for element in lhs_union.elements(self.db) {
675+
if element.is_single_valued(self.db)
676+
|| element.is_literal_string()
677+
|| element.is_bool(self.db)
678+
{
679+
single_builder = single_builder.add(*element);
680+
} else {
681+
rest_builder = rest_builder.add(*element);
682+
}
683+
}
684+
}
685+
686+
let single_union = single_builder.build();
687+
let rest_union = rest_builder.build();
688+
689+
let narrowed_single = IntersectionBuilder::new(self.db)
690+
.add_positive(single_union)
691+
.add_negative(rhs_values)
692+
.build();
693+
694+
// Keep order: first literal complement, then broader arms.
695+
let result = UnionBuilder::new(self.db)
696+
.add(narrowed_single)
697+
.add(rest_union)
698+
.build();
699+
Some(result)
636700
} else {
637701
None
638702
}
@@ -660,9 +724,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
660724
ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
661725
ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
662726
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
663-
ast::CmpOp::NotIn => self
664-
.evaluate_expr_in(lhs_ty, rhs_ty)
665-
.map(|ty| ty.negate(self.db)),
727+
ast::CmpOp::NotIn => self.evaluate_expr_not_in(lhs_ty, rhs_ty),
666728
_ => None,
667729
}
668730
}

0 commit comments

Comments
 (0)