Skip to content

Commit 67c6753

Browse files
committed
Fix narrowing of union types containing StrEnum/IntEnum and Literal
When a union contains both StrEnum/IntEnum and Literal/None types, the ambiguity guard in narrow_type_by_identity_equality skips all narrowing. This processes Literal/None union items individually via conditional_types while keeping enum items as-is. Fixes #20915
1 parent 0177c0d commit 67c6753

File tree

4 files changed

+83
-3
lines changed

4 files changed

+83
-3
lines changed

mypy/checker.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6783,6 +6783,40 @@ def narrow_type_by_identity_equality(
67836783
enum_comparison_is_ambiguous
67846784
and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1
67856785
):
6786+
# For unions with both StrEnum/IntEnum and Literal/None items,
6787+
# narrow the Literal/None items while keeping enum items as-is.
6788+
orig_type = get_proper_type(coerce_to_literal(operand_types[i]))
6789+
if isinstance(orig_type, UnionType):
6790+
yes_items: list[Type] = []
6791+
no_items: list[Type] = []
6792+
has_narrowable = False
6793+
target = TypeRange(target_type, is_upper_bound=False)
6794+
for item in orig_type.items:
6795+
p_item = get_proper_type(item)
6796+
is_enum = bool(
6797+
ambiguous_enum_equality_keys(item) - {"<other>"}
6798+
)
6799+
if not is_enum and isinstance(p_item, (LiteralType, NoneType)):
6800+
has_narrowable = True
6801+
y, n = conditional_types(
6802+
item, [target], default=item, from_equality=True
6803+
)
6804+
yes_items.append(y)
6805+
no_items.append(n)
6806+
else:
6807+
yes_items.append(item)
6808+
no_items.append(item)
6809+
if has_narrowable:
6810+
if_map, else_map = conditional_types_to_typemaps(
6811+
operands[i],
6812+
UnionType.make_union(yes_items),
6813+
UnionType.make_union(no_items),
6814+
)
6815+
all_if_maps.append(if_map)
6816+
if is_target_for_value_narrowing(
6817+
get_proper_type(target_type)
6818+
):
6819+
all_else_maps.append(else_map)
67866820
continue
67876821

67886822
target = TypeRange(target_type, is_upper_bound=False)

mypy/stubtest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,6 +2049,10 @@ def _named_type(name: str) -> mypy.types.Instance:
20492049
return mypy.types.TupleType(items, fallback)
20502050

20512051
fallback = mypy.types.Instance(type_info, [anytype() for _ in type_info.type_vars])
2052+
if type(runtime) != runtime.__class__:
2053+
# Since `__class__` is redefined for an instance, we can't trust
2054+
# its `isinstance` checks, it can be dynamic. See #20919
2055+
return fallback
20522056

20532057
value: bool | int | str
20542058
if isinstance(runtime, enum.Enum) and isinstance(runtime.name, str):

mypy/test/teststubtest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,49 @@ def f(): return 3
15641564
error=None,
15651565
)
15661566

1567+
@collect_cases
1568+
def test_proxy_object(self) -> Iterator[Case]:
1569+
yield Case(
1570+
stub="""
1571+
class LazyObject:
1572+
def __init__(self, func: object) -> None: ...
1573+
def __bool__(self) -> bool: ...
1574+
""",
1575+
runtime="""
1576+
class LazyObject:
1577+
def __init__(self, func):
1578+
self.__dict__["_wrapped"] = None
1579+
self.__dict__["_setupfunc"] = func
1580+
def _setup(self):
1581+
self.__dict__["_wrapped"] = self._setupfunc()
1582+
@property
1583+
def __class__(self):
1584+
if self._wrapped is None:
1585+
self._setup()
1586+
return type(self._wrapped)
1587+
def __bool__(self):
1588+
if self._wrapped is None:
1589+
self._setup()
1590+
return bool(self._wrapped)
1591+
""",
1592+
error="test_module.LazyObject.__class__",
1593+
)
1594+
yield Case(
1595+
stub="""
1596+
def default_value() -> bool: ...
1597+
1598+
DEFAULT_VALUE: bool
1599+
""",
1600+
runtime="""
1601+
def default_value():
1602+
return True
1603+
1604+
DEFAULT_VALUE = LazyObject(default_value)
1605+
bool(DEFAULT_VALUE) # evaluate the lazy object
1606+
""",
1607+
error="test_module.DEFAULT_VALUE",
1608+
)
1609+
15671610
@collect_cases
15681611
def test_all_at_runtime_not_stub(self) -> Iterator[Case]:
15691612
yield Case(

test-data/unit/check-enum.test

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2784,9 +2784,8 @@ def f1(a: Foo | Literal['foo']) -> Foo:
27842784
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
27852785
return Foo.FOO
27862786

2787-
# Ideally this passes
2788-
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
2789-
return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo")
2787+
reveal_type(a) # N: Revealed type is "Literal[__main__.Foo.FOO]"
2788+
return a
27902789
[builtins fixtures/primitives.pyi]
27912790

27922791
[case testStrEnumEqualityAlias]

0 commit comments

Comments
 (0)