Skip to content

Commit b83ac9c

Browse files
authored
Try empty context when assigning to union typed variables (#14151)
Fixes #4805 Fixes #13936 It is known that mypy can overuse outer type context sometimes (especially when it is a union). This prevents a common use case for narrowing types (see issue and test cases). This is a somewhat major semantic change, but I think it should match better what a user would expect.
1 parent 3c5f368 commit b83ac9c

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

mypy/checker.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
AssignmentStmt,
7777
Block,
7878
BreakStmt,
79+
BytesExpr,
7980
CallExpr,
8081
ClassDef,
8182
ComparisonExpr,
@@ -86,6 +87,7 @@
8687
EllipsisExpr,
8788
Expression,
8889
ExpressionStmt,
90+
FloatExpr,
8991
ForStmt,
9092
FuncBase,
9193
FuncDef,
@@ -115,6 +117,7 @@
115117
ReturnStmt,
116118
StarExpr,
117119
Statement,
120+
StrExpr,
118121
SymbolNode,
119122
SymbolTable,
120123
SymbolTableNode,
@@ -3826,6 +3829,23 @@ def inference_error_fallback_type(self, type: Type) -> Type:
38263829
# we therefore need to erase them.
38273830
return erase_typevars(fallback)
38283831

3832+
def simple_rvalue(self, rvalue: Expression) -> bool:
3833+
"""Returns True for expressions for which inferred type should not depend on context.
3834+
3835+
Note that this function can still return False for some expressions where inferred type
3836+
does not depend on context. It only exists for performance optimizations.
3837+
"""
3838+
if isinstance(rvalue, (IntExpr, StrExpr, BytesExpr, FloatExpr, RefExpr)):
3839+
return True
3840+
if isinstance(rvalue, CallExpr):
3841+
if isinstance(rvalue.callee, RefExpr) and isinstance(rvalue.callee.node, FuncBase):
3842+
typ = rvalue.callee.node.type
3843+
if isinstance(typ, CallableType):
3844+
return not typ.variables
3845+
elif isinstance(typ, Overloaded):
3846+
return not any(item.variables for item in typ.items)
3847+
return False
3848+
38293849
def check_simple_assignment(
38303850
self,
38313851
lvalue_type: Type | None,
@@ -3847,6 +3867,30 @@ def check_simple_assignment(
38473867
rvalue_type = self.expr_checker.accept(
38483868
rvalue, lvalue_type, always_allow_any=always_allow_any
38493869
)
3870+
if (
3871+
isinstance(get_proper_type(lvalue_type), UnionType)
3872+
# Skip literal types, as they have special logic (for better errors).
3873+
and not isinstance(get_proper_type(rvalue_type), LiteralType)
3874+
and not self.simple_rvalue(rvalue)
3875+
):
3876+
# Try re-inferring r.h.s. in empty context, and use that if it
3877+
# results in a narrower type. We don't do this always because this
3878+
# may cause some perf impact, plus we want to partially preserve
3879+
# the old behavior. This helps with various practical examples, see
3880+
# e.g. testOptionalTypeNarrowedByGenericCall.
3881+
with self.msg.filter_errors() as local_errors, self.local_type_map() as type_map:
3882+
alt_rvalue_type = self.expr_checker.accept(
3883+
rvalue, None, always_allow_any=always_allow_any
3884+
)
3885+
if (
3886+
not local_errors.has_new_errors()
3887+
# Skip Any type, since it is special cased in binder.
3888+
and not isinstance(get_proper_type(alt_rvalue_type), AnyType)
3889+
and is_valid_inferred_type(alt_rvalue_type)
3890+
and is_proper_subtype(alt_rvalue_type, rvalue_type)
3891+
):
3892+
rvalue_type = alt_rvalue_type
3893+
self.store_types(type_map)
38503894
if isinstance(rvalue_type, DeletedType):
38513895
self.msg.deleted_as_rvalue(rvalue_type, context)
38523896
if isinstance(lvalue_type, DeletedType):

test-data/unit/check-inference-context.test

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,3 +1419,60 @@ def bar(x: Union[Mapping[Any, Any], Dict[Any, Sequence[Any]]]) -> None:
14191419
...
14201420
bar({1: 2})
14211421
[builtins fixtures/dict.pyi]
1422+
1423+
[case testOptionalTypeNarrowedByGenericCall]
1424+
# flags: --strict-optional
1425+
from typing import Dict, Optional
1426+
1427+
d: Dict[str, str] = {}
1428+
1429+
def foo(arg: Optional[str] = None) -> None:
1430+
if arg is None:
1431+
arg = d.get("a", "b")
1432+
reveal_type(arg) # N: Revealed type is "builtins.str"
1433+
[builtins fixtures/dict.pyi]
1434+
1435+
[case testOptionalTypeNarrowedByGenericCall2]
1436+
# flags: --strict-optional
1437+
from typing import Dict, Optional
1438+
1439+
d: Dict[str, str] = {}
1440+
x: Optional[str]
1441+
if x:
1442+
reveal_type(x) # N: Revealed type is "builtins.str"
1443+
x = d.get(x, x)
1444+
reveal_type(x) # N: Revealed type is "builtins.str"
1445+
[builtins fixtures/dict.pyi]
1446+
1447+
[case testOptionalTypeNarrowedByGenericCall3]
1448+
# flags: --strict-optional
1449+
from typing import Generic, TypeVar, Union
1450+
1451+
T = TypeVar("T")
1452+
def bar(arg: Union[str, T]) -> Union[str, T]: ...
1453+
1454+
def foo(arg: Union[str, int]) -> None:
1455+
if isinstance(arg, int):
1456+
arg = bar("default")
1457+
reveal_type(arg) # N: Revealed type is "builtins.str"
1458+
[builtins fixtures/isinstance.pyi]
1459+
1460+
[case testOptionalTypeNarrowedByGenericCall4]
1461+
# flags: --strict-optional
1462+
from typing import Optional, List, Generic, TypeVar
1463+
1464+
T = TypeVar("T", covariant=True)
1465+
class C(Generic[T]): ...
1466+
1467+
x: Optional[C[int]] = None
1468+
y = x = C()
1469+
reveal_type(y) # N: Revealed type is "__main__.C[builtins.int]"
1470+
1471+
[case testOptionalTypeNarrowedByGenericCall5]
1472+
from typing import Any, Tuple, Union
1473+
1474+
i: Union[Tuple[Any, ...], int]
1475+
b: Any
1476+
i = i if isinstance(i, int) else b
1477+
reveal_type(i) # N: Revealed type is "Union[Any, builtins.int]"
1478+
[builtins fixtures/isinstance.pyi]

test-data/unit/check-typeddict.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ B = TypedDict('B', {'@type': Literal['b-type'], 'b': int})
893893

894894
c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'}
895895
reveal_type(c) # N: Revealed type is "Union[TypedDict('__main__.A', {'@type': Literal['a-type'], 'a': builtins.str}), TypedDict('__main__.B', {'@type': Literal['b-type'], 'b': builtins.int})]"
896-
[builtins fixtures/tuple.pyi]
896+
[builtins fixtures/dict.pyi]
897897

898898
[case testTypedDictUnionAmbiguousCase]
899899
from typing import Union, Mapping, Any, cast

0 commit comments

Comments
 (0)