Skip to content

Commit 28815f1

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Fix incorrect tracking of "final" Instances (#6763)
* Refine how Literal and Final interact This diff makes three changes: it fixes a bug where we incorrectly track "final" Instances, does some related refactoring, and finally modifies tuple indexing to be aware of literal contexts. Specifically, here is an example of the bug. Note that mypy ignores the mutable nature of `bar`: def expect_3(x: Literal[3]) -> None: ... foo: Final = 3 bar = foo for i in range(10): bar = i # Currently type-check; this PR makes mypy correctly report an error expect_3(bar) To fix this bug, I decided to adjust the variable assignment logic: if the variable is non-final, we now scan the inferred type we try assigning and recursively erase all set `instance.final_value` fields. This change ended up making the `in_final_declaration` field redundant -- after all, we're going to be actively erasing types on non-final assignments anyways. So, I decided to just remove this field. I suspect this change will also result in some nice dividends down the road: defaulting to preserving the underlying literal when inferring expression types will probably make it easier to add more sophisticated literal-related inference down the road. In the process of implementing the above two, I discovered that "nested" Instance types are effectively ignored. So, the following program does not type check, despite the `Final` and despite that tuples are immutable: bar: Final = (3, 2, 1) # 'bar[0] == 3' is always true, but we currently report an error expect_3(bar[0]) This is mildly annoying, and also made it slightly harder for me to verify my changes above, so I decided to modify `visit_index_expr` to also examine the literal context. (Actually, I found I could move this check directly into the 'accept' method instead of special-casing things within `visit_index_expr` and `analyze_var_ref`. But I decided against this approach: the special-casing feels less intrusive, easier to audit, and slightly more efficient.) * Rename 'final_value' field to 'last_known_value' * Adjust one existing test
1 parent 04d9649 commit 28815f1

File tree

17 files changed

+135
-70
lines changed

17 files changed

+135
-70
lines changed

mypy/checker.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
5454
from mypy.semanal import set_callable_name, refers_to_fullname
5555
from mypy.mro import calculate_mro
56-
from mypy.erasetype import erase_typevars
56+
from mypy.erasetype import erase_typevars, remove_instance_last_known_values
5757
from mypy.expandtype import expand_type, expand_type_by_instance
5858
from mypy.visitor import NodeVisitor
5959
from mypy.join import join_types
@@ -1898,10 +1898,9 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
18981898
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
18991899

19001900
if inferred:
1901-
rvalue_type = self.expr_checker.accept(
1902-
rvalue,
1903-
in_final_declaration=inferred.is_final,
1904-
)
1901+
rvalue_type = self.expr_checker.accept(rvalue)
1902+
if not inferred.is_final:
1903+
rvalue_type = remove_instance_last_known_values(rvalue_type)
19051904
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
19061905

19071906
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],

mypy/checkexpr.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,6 @@ def __init__(self,
141141
self.plugin = plugin
142142
self.type_context = [None]
143143

144-
# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
145-
# For example, if we're checking the "3" in a statement like "var: Final = 3".
146-
#
147-
# This flag changes the type that eventually gets inferred for "var". Instead of
148-
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
149-
# of the underlying literal value. See the comments in Instance's constructors for
150-
# more details.
151-
self.in_final_declaration = False
152-
153144
# Temporary overrides for expression types. This is currently
154145
# used by the union math in overloads.
155146
# TODO: refactor this to use a pattern similar to one in
@@ -224,8 +215,8 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
224215
def analyze_var_ref(self, var: Var, context: Context) -> Type:
225216
if var.type:
226217
if isinstance(var.type, Instance):
227-
if self.is_literal_context() and var.type.final_value is not None:
228-
return var.type.final_value
218+
if self.is_literal_context() and var.type.last_known_value is not None:
219+
return var.type.last_known_value
229220
if var.name() in {'True', 'False'}:
230221
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
231222
return var.type
@@ -1812,15 +1803,13 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
18121803
typ = self.named_type(fallback_name)
18131804
if self.is_literal_context():
18141805
return LiteralType(value=value, fallback=typ)
1815-
elif self.in_final_declaration:
1806+
else:
18161807
return typ.copy_modified(final_value=LiteralType(
18171808
value=value,
18181809
fallback=typ,
18191810
line=typ.line,
18201811
column=typ.column,
18211812
))
1822-
else:
1823-
return typ
18241813

18251814
def visit_int_expr(self, e: IntExpr) -> Type:
18261815
"""Type check an integer literal (trivial)."""
@@ -2450,7 +2439,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type:
24502439
It may also represent type application.
24512440
"""
24522441
result = self.visit_index_expr_helper(e)
2453-
return self.narrow_type_from_binder(e, result)
2442+
result = self.narrow_type_from_binder(e, result)
2443+
if (self.is_literal_context() and isinstance(result, Instance)
2444+
and result.last_known_value is not None):
2445+
result = result.last_known_value
2446+
return result
24542447

24552448
def visit_index_expr_helper(self, e: IndexExpr) -> Type:
24562449
if e.analyzed:
@@ -2542,8 +2535,8 @@ def _get_value(self, index: Expression) -> Optional[int]:
25422535
if isinstance(operand, IntExpr):
25432536
return -1 * operand.value
25442537
typ = self.accept(index)
2545-
if isinstance(typ, Instance) and typ.final_value is not None:
2546-
typ = typ.final_value
2538+
if isinstance(typ, Instance) and typ.last_known_value is not None:
2539+
typ = typ.last_known_value
25472540
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
25482541
return typ.value
25492542
return None
@@ -2553,8 +2546,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
25532546
item_name = index.value
25542547
else:
25552548
typ = self.accept(index)
2556-
if isinstance(typ, Instance) and typ.final_value is not None:
2557-
typ = typ.final_value
2549+
if isinstance(typ, Instance) and typ.last_known_value is not None:
2550+
typ = typ.last_known_value
25582551

25592552
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
25602553
item_name = typ.value
@@ -3253,16 +3246,13 @@ def accept(self,
32533246
type_context: Optional[Type] = None,
32543247
allow_none_return: bool = False,
32553248
always_allow_any: bool = False,
3256-
in_final_declaration: bool = False,
32573249
) -> Type:
32583250
"""Type check a node in the given type context. If allow_none_return
32593251
is True and this expression is a call, allow it to return None. This
32603252
applies only to this expression and not any subexpressions.
32613253
"""
32623254
if node in self.type_overrides:
32633255
return self.type_overrides[node]
3264-
old_in_final_declaration = self.in_final_declaration
3265-
self.in_final_declaration = in_final_declaration
32663256
self.type_context.append(type_context)
32673257
try:
32683258
if allow_none_return and isinstance(node, CallExpr):
@@ -3274,8 +3264,8 @@ def accept(self,
32743264
except Exception as err:
32753265
report_internal_error(err, self.chk.errors.file,
32763266
node.line, self.chk.errors, self.chk.options)
3267+
32773268
self.type_context.pop()
3278-
self.in_final_declaration = old_in_final_declaration
32793269
assert typ is not None
32803270
self.chk.store_type(node, typ)
32813271

mypy/checkmember.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def analyze_member_access(name: str,
101101
msg,
102102
chk=chk)
103103
result = _analyze_member_access(name, typ, mx, override_info)
104-
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
105-
return result.final_value
104+
if in_literal_context and isinstance(result, Instance) and result.last_known_value is not None:
105+
return result.last_known_value
106106
else:
107107
return result
108108

mypy/erasetype.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,17 @@ def visit_type_var(self, t: TypeVarType) -> Type:
119119
if self.erase_id(t.id):
120120
return self.replacement
121121
return t
122+
123+
124+
def remove_instance_last_known_values(t: Type) -> Type:
125+
return t.accept(LastKnownValueEraser())
126+
127+
128+
class LastKnownValueEraser(TypeTranslator):
129+
"""Removes the Literal[...] type that may be associated with any
130+
Instance types."""
131+
132+
def visit_instance(self, t: Instance) -> Type:
133+
if t.last_known_value:
134+
return t.copy_modified(final_value=None)
135+
return t

mypy/fixup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
155155
base.accept(self)
156156
for a in inst.args:
157157
a.accept(self)
158-
if inst.final_value is not None:
159-
inst.final_value.accept(self)
158+
if inst.last_known_value is not None:
159+
inst.last_known_value.accept(self)
160160

161161
def visit_any(self, o: Any) -> None:
162162
pass # Nothing to descend into.

mypy/newsemanal/typeanal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
700700
elif isinstance(arg, (NoneType, LiteralType)):
701701
# Types that we can just add directly to the literal/potential union of literals.
702702
return [arg]
703-
elif isinstance(arg, Instance) and arg.final_value is not None:
703+
elif isinstance(arg, Instance) and arg.last_known_value is not None:
704704
# Types generated from declarations like "var: Final = 4".
705-
return [arg.final_value]
705+
return [arg.last_known_value]
706706
elif isinstance(arg, UnionType):
707707
out = []
708708
for union_arg in arg.items:

mypy/plugins/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
125125
"""If this expression is a string literal, or if the corresponding type
126126
is something like 'Literal["some string here"]', returns the underlying
127127
string value. Otherwise, returns None."""
128-
if isinstance(typ, Instance) and typ.final_value is not None:
129-
typ = typ.final_value
128+
if isinstance(typ, Instance) and typ.last_known_value is not None:
129+
typ = typ.last_known_value
130130

131131
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
132132
val = typ.value

mypy/sametypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def visit_instance(self, left: Instance) -> bool:
7979
return (isinstance(self.right, Instance) and
8080
left.type == self.right.type and
8181
is_same_types(left.args, self.right.args) and
82-
left.final_value == self.right.final_value)
82+
left.last_known_value == self.right.last_known_value)
8383

8484
def visit_type_var(self, left: TypeVarType) -> bool:
8585
return (isinstance(self.right, TypeVarType) and

mypy/server/astdiff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def visit_instance(self, typ: Instance) -> SnapshotItem:
284284
return ('Instance',
285285
typ.type.fullname(),
286286
snapshot_types(typ.args),
287-
None if typ.final_value is None else snapshot_type(typ.final_value))
287+
None if typ.last_known_value is None else snapshot_type(typ.last_known_value))
288288

289289
def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
290290
return ('TypeVar',

mypy/server/astmerge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
342342
typ.type = self.fixup(typ.type)
343343
for arg in typ.args:
344344
arg.accept(self)
345-
if typ.final_value:
346-
typ.final_value.accept(self)
345+
if typ.last_known_value:
346+
typ.last_known_value.accept(self)
347347

348348
def visit_any(self, typ: AnyType) -> None:
349349
pass

mypy/server/deps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -883,8 +883,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
883883
triggers = [trigger]
884884
for arg in typ.args:
885885
triggers.extend(self.get_type_triggers(arg))
886-
if typ.final_value:
887-
triggers.extend(self.get_type_triggers(typ.final_value))
886+
if typ.last_known_value:
887+
triggers.extend(self.get_type_triggers(typ.last_known_value))
888888
return triggers
889889

890890
def visit_any(self, typ: AnyType) -> List[str]:

mypy/type_visitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,16 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
164164

165165
def visit_instance(self, t: Instance) -> Type:
166166
final_value = None # type: Optional[LiteralType]
167-
if t.final_value is not None:
168-
raw_final_value = t.final_value.accept(self)
167+
if t.last_known_value is not None:
168+
raw_final_value = t.last_known_value.accept(self)
169169
assert isinstance(raw_final_value, LiteralType)
170170
final_value = raw_final_value
171171
return Instance(
172172
typ=t.type,
173173
args=self.translate_types(t.args),
174174
line=t.line,
175175
column=t.column,
176-
final_value=final_value,
176+
last_known_value=final_value,
177177
)
178178

179179
def visit_type_var(self, t: TypeVarType) -> Type:

mypy/typeanal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
730730
elif isinstance(arg, (NoneType, LiteralType)):
731731
# Types that we can just add directly to the literal/potential union of literals.
732732
return [arg]
733-
elif isinstance(arg, Instance) and arg.final_value is not None:
733+
elif isinstance(arg, Instance) and arg.last_known_value is not None:
734734
# Types generated from declarations like "var: Final = 4".
735-
return [arg.final_value]
735+
return [arg.last_known_value]
736736
elif isinstance(arg, UnionType):
737737
out = []
738738
for union_arg in arg.items:

mypy/types.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -585,11 +585,11 @@ class Instance(Type):
585585
The list of type variables may be empty.
586586
"""
587587

588-
__slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'final_value')
588+
__slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'last_known_value')
589589

590590
def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
591591
line: int = -1, column: int = -1, erased: bool = False,
592-
final_value: Optional['LiteralType'] = None) -> None:
592+
last_known_value: Optional['LiteralType'] = None) -> None:
593593
super().__init__(line, column)
594594
self.type = typ
595595
self.args = args
@@ -601,15 +601,31 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
601601
# True if recovered after incorrect number of type arguments error
602602
self.invalid = False
603603

604-
# This field keeps track of the underlying Literal[...] value if this instance
605-
# was created via a Final declaration. For example, if we did `x: Final = 3`, x
606-
# would have an instance with a `final_value` of `LiteralType(3, int_fallback)`.
604+
# This field keeps track of the underlying Literal[...] value associated with
605+
# this instance, if one is known.
607606
#
608-
# Or more broadly, this field lets this Instance "remember" its original declaration.
609-
# We want this behavior because we want implicit Final declarations to act pretty
610-
# much identically with constants: we should be able to replace any places where we
611-
# use some Final variable with the original value and get the same type-checking
612-
# behavior. For example, we want this program:
607+
# This field is set whenever possible within expressions, but is erased upon
608+
# variable assignment (see erasetype.remove_instance_last_known_values) unless
609+
# the variable is declared to be final.
610+
#
611+
# For example, consider the following program:
612+
#
613+
# a = 1
614+
# b: Final[int] = 2
615+
# c: Final = 3
616+
# print(a + b + c + 4)
617+
#
618+
# The 'Instance' objects associated with the expressions '1', '2', '3', and '4' will
619+
# have last_known_values of type Literal[1], Literal[2], Literal[3], and Literal[4]
620+
# respectively. However, the Instance object assigned to 'a' and 'b' will have their
621+
# last_known_value erased: variable 'a' is mutable; variable 'b' was declared to be
622+
# specifically an int.
623+
#
624+
# Or more broadly, this field lets this Instance "remember" its original declaration
625+
# when applicable. We want this behavior because we want implicit Final declarations
626+
# to act pretty much identically with constants: we should be able to replace any
627+
# places where we use some Final variable with the original value and get the same
628+
# type-checking behavior. For example, we want this program:
613629
#
614630
# def expects_literal(x: Literal[3]) -> None: pass
615631
# var: Final = 3
@@ -623,39 +639,37 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
623639
# In order to make this work (especially with literal types), we need var's type
624640
# (an Instance) to remember the "original" value.
625641
#
626-
# This field is currently set only when we encounter an *implicit* final declaration
627-
# like `x: Final = 3` where the RHS is some literal expression. This field remains 'None'
628-
# when we do things like `x: Final[int] = 3` or `x: Final = foo + bar`.
642+
# Preserving this value within expressions is useful for similar reasons.
629643
#
630644
# Currently most of mypy will ignore this field and will continue to treat this type like
631645
# a regular Instance. We end up using this field only when we are explicitly within a
632646
# Literal context.
633-
self.final_value = final_value
647+
self.last_known_value = last_known_value
634648

635649
def accept(self, visitor: 'TypeVisitor[T]') -> T:
636650
return visitor.visit_instance(self)
637651

638652
def __hash__(self) -> int:
639-
return hash((self.type, tuple(self.args), self.final_value))
653+
return hash((self.type, tuple(self.args), self.last_known_value))
640654

641655
def __eq__(self, other: object) -> bool:
642656
if not isinstance(other, Instance):
643657
return NotImplemented
644658
return (self.type == other.type
645659
and self.args == other.args
646-
and self.final_value == other.final_value)
660+
and self.last_known_value == other.last_known_value)
647661

648662
def serialize(self) -> Union[JsonDict, str]:
649663
assert self.type is not None
650664
type_ref = self.type.fullname()
651-
if not self.args and not self.final_value:
665+
if not self.args and not self.last_known_value:
652666
return type_ref
653667
data = {'.class': 'Instance',
654668
} # type: JsonDict
655669
data['type_ref'] = type_ref
656670
data['args'] = [arg.serialize() for arg in self.args]
657-
if self.final_value is not None:
658-
data['final_value'] = self.final_value.serialize()
671+
if self.last_known_value is not None:
672+
data['last_known_value'] = self.last_known_value.serialize()
659673
return data
660674

661675
@classmethod
@@ -672,8 +686,8 @@ def deserialize(cls, data: Union[JsonDict, str]) -> 'Instance':
672686
args = [deserialize_type(arg) for arg in args_list]
673687
inst = Instance(NOT_READY, args)
674688
inst.type_ref = data['type_ref'] # Will be fixed up by fixup.py later.
675-
if 'final_value' in data:
676-
inst.final_value = LiteralType.deserialize(data['final_value'])
689+
if 'last_known_value' in data:
690+
inst.last_known_value = LiteralType.deserialize(data['last_known_value'])
677691
return inst
678692

679693
def copy_modified(self, *,
@@ -685,7 +699,7 @@ def copy_modified(self, *,
685699
self.line,
686700
self.column,
687701
self.erased,
688-
final_value if final_value is not _dummy else self.final_value,
702+
final_value if final_value is not _dummy else self.last_known_value,
689703
)
690704

691705
def has_readable_member(self, name: str) -> bool:

0 commit comments

Comments
 (0)