Skip to content

Commit 621ee50

Browse files
committed
Add support for narrowing booleans to literals
1 parent 7189a23 commit 621ee50

File tree

3 files changed

+50
-21
lines changed

3 files changed

+50
-21
lines changed

mypy/checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4536,7 +4536,8 @@ def refine_identity_comparison_expression(self,
45364536

45374537
enum_name = None
45384538
target = get_proper_type(target)
4539-
if isinstance(target, LiteralType) and target.is_enum_literal():
4539+
if (isinstance(target, LiteralType) and
4540+
(target.is_enum_literal() or target.fallback.type.fullname == "builtins.bool")):
45404541
enum_name = target.fallback.type.fullname
45414542

45424543
target_type = [TypeRange(target, is_upper_bound=False)]

mypy/typeops.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -693,26 +693,32 @@ class Status(Enum):
693693
if isinstance(typ, UnionType):
694694
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
695695
return make_simplified_union(items, contract_literals=False)
696-
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
697-
new_items = []
698-
for name, symbol in typ.type.names.items():
699-
if not isinstance(symbol.node, Var):
700-
continue
701-
# Skip "_order_" and "__order__", since Enum will remove it
702-
if name in ("_order_", "__order__"):
703-
continue
704-
new_items.append(LiteralType(name, typ))
705-
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
706-
# insertion order only starting with Python 3.7. So, we sort these for older
707-
# versions of Python to help make tests deterministic.
708-
#
709-
# We could probably skip the sort for Python 3.6 since people probably run mypy
710-
# only using CPython, but we might as well for the sake of full correctness.
711-
if sys.version_info < (3, 7):
712-
new_items.sort(key=lambda lit: lit.value)
713-
return make_simplified_union(new_items, contract_literals=False)
714-
else:
715-
return typ
696+
elif isinstance(typ, Instance) and typ.type.fullname == target_fullname:
697+
if typ.type.is_enum:
698+
new_items = []
699+
for name, symbol in typ.type.names.items():
700+
if not isinstance(symbol.node, Var):
701+
continue
702+
# Skip "_order_" and "__order__", since Enum will remove it
703+
if name in ("_order_", "__order__"):
704+
continue
705+
new_items.append(LiteralType(name, typ))
706+
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
707+
# insertion order only starting with Python 3.7. So, we sort these for older
708+
# versions of Python to help make tests deterministic.
709+
#
710+
# We could probably skip the sort for Python 3.6 since people probably run mypy
711+
# only using CPython, but we might as well for the sake of full correctness.
712+
if sys.version_info < (3, 7):
713+
new_items.sort(key=lambda lit: lit.value)
714+
return make_simplified_union(new_items, contract_literals=False)
715+
elif typ.type.fullname == "builtins.bool":
716+
return make_simplified_union(
717+
[LiteralType(True, typ), LiteralType(False, typ)],
718+
contract_literals=False
719+
)
720+
721+
return typ
716722

717723

718724
def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:

test-data/unit/check-narrowing.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,5 +1027,27 @@ if str_or_bool_literal is not True and str_or_bool_literal is not False:
10271027
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str"
10281028
else:
10291029
reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]"
1030+
[builtins fixtures/primitives.pyi]
1031+
1032+
[case testNarrowingBoolean]
1033+
# flags: --strict-optional
1034+
from typing import Optional
1035+
from typing_extensions import Literal
1036+
1037+
bool_val: bool
10301038

1039+
if bool_val is not False:
1040+
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
1041+
else:
1042+
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
1043+
1044+
opt_bool_val: Optional[bool]
1045+
1046+
if opt_bool_val is not None:
1047+
reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool"
1048+
1049+
if opt_bool_val is not False:
1050+
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]"
1051+
else:
1052+
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
10311053
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)