Skip to content

Fix incorrect tracking of "final" Instances #6763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
from mypy.semanal import set_callable_name, refers_to_fullname
from mypy.mro import calculate_mro
from mypy.erasetype import erase_typevars
from mypy.erasetype import erase_typevars, remove_instance_last_known_values
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.visitor import NodeVisitor
from mypy.join import join_types
Expand Down Expand Up @@ -1868,10 +1868,9 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
rvalue_type = self.expr_checker.accept(
rvalue,
in_final_declaration=inferred.is_final,
)
rvalue_type = self.expr_checker.accept(rvalue)
if not inferred.is_final:
rvalue_type = remove_instance_last_known_values(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
Expand Down
36 changes: 13 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,6 @@ def __init__(self,
self.plugin = plugin
self.type_context = [None]

# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
# For example, if we're checking the "3" in a statement like "var: Final = 3".
#
# This flag changes the type that eventually gets inferred for "var". Instead of
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
# of the underlying literal value. See the comments in Instance's constructors for
# more details.
self.in_final_declaration = False

# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
Expand Down Expand Up @@ -224,8 +215,8 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
if isinstance(var.type, Instance):
if self.is_literal_context() and var.type.final_value is not None:
return var.type.final_value
if self.is_literal_context() and var.type.last_known_value is not None:
return var.type.last_known_value
if var.name() in {'True', 'False'}:
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
return var.type
Expand Down Expand Up @@ -1812,15 +1803,13 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
typ = self.named_type(fallback_name)
if self.is_literal_context():
return LiteralType(value=value, fallback=typ)
elif self.in_final_declaration:
else:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
else:
return typ

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
Expand Down Expand Up @@ -2450,7 +2439,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type:
It may also represent type application.
"""
result = self.visit_index_expr_helper(e)
return self.narrow_type_from_binder(e, result)
result = self.narrow_type_from_binder(e, result)
if (self.is_literal_context() and isinstance(result, Instance)
and result.last_known_value is not None):
result = result.last_known_value
return result

def visit_index_expr_helper(self, e: IndexExpr) -> Type:
if e.analyzed:
Expand Down Expand Up @@ -2542,8 +2535,8 @@ def _get_value(self, index: Expression) -> Optional[int]:
if isinstance(operand, IntExpr):
return -1 * operand.value
typ = self.accept(index)
if isinstance(typ, Instance) and typ.final_value is not None:
typ = typ.final_value
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
return typ.value
return None
Expand All @@ -2553,8 +2546,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
item_name = index.value
else:
typ = self.accept(index)
if isinstance(typ, Instance) and typ.final_value is not None:
typ = typ.final_value
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value

if isinstance(typ, LiteralType) and isinstance(typ.value, str):
item_name = typ.value
Expand Down Expand Up @@ -3253,16 +3246,13 @@ def accept(self,
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
in_final_declaration: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
old_in_final_declaration = self.in_final_declaration
self.in_final_declaration = in_final_declaration
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand All @@ -3274,8 +3264,8 @@ def accept(self,
except Exception as err:
report_internal_error(err, self.chk.errors.file,
node.line, self.chk.errors, self.chk.options)

self.type_context.pop()
self.in_final_declaration = old_in_final_declaration
assert typ is not None
self.chk.store_type(node, typ)

Expand Down
4 changes: 2 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def analyze_member_access(name: str,
msg,
chk=chk)
result = _analyze_member_access(name, typ, mx, override_info)
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
return result.final_value
if in_literal_context and isinstance(result, Instance) and result.last_known_value is not None:
return result.last_known_value
else:
return result

Expand Down
14 changes: 14 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,17 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t


def remove_instance_last_known_values(t: Type) -> Type:
return t.accept(LastKnownValueEraser())


class LastKnownValueEraser(TypeTranslator):
"""Removes the Literal[...] type that may be associated with any
Instance types."""

def visit_instance(self, t: Instance) -> Type:
if t.last_known_value:
return t.copy_modified(final_value=None)
return t
4 changes: 2 additions & 2 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
base.accept(self)
for a in inst.args:
a.accept(self)
if inst.final_value is not None:
inst.final_value.accept(self)
if inst.last_known_value is not None:
inst.last_known_value.accept(self)

def visit_any(self, o: Any) -> None:
pass # Nothing to descend into.
Expand Down
4 changes: 2 additions & 2 deletions mypy/newsemanal/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
elif isinstance(arg, (NoneType, LiteralType)):
# Types that we can just add directly to the literal/potential union of literals.
return [arg]
elif isinstance(arg, Instance) and arg.final_value is not None:
elif isinstance(arg, Instance) and arg.last_known_value is not None:
# Types generated from declarations like "var: Final = 4".
return [arg.final_value]
return [arg.last_known_value]
elif isinstance(arg, UnionType):
out = []
for union_arg in arg.items:
Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
"""If this expression is a string literal, or if the corresponding type
is something like 'Literal["some string here"]', returns the underlying
string value. Otherwise, returns None."""
if isinstance(typ, Instance) and typ.final_value is not None:
typ = typ.final_value
if isinstance(typ, Instance) and typ.last_known_value is not None:
typ = typ.last_known_value

if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
val = typ.value
Expand Down
2 changes: 1 addition & 1 deletion mypy/sametypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def visit_instance(self, left: Instance) -> bool:
return (isinstance(self.right, Instance) and
left.type == self.right.type and
is_same_types(left.args, self.right.args) and
left.final_value == self.right.final_value)
left.last_known_value == self.right.last_known_value)

def visit_type_var(self, left: TypeVarType) -> bool:
return (isinstance(self.right, TypeVarType) and
Expand Down
2 changes: 1 addition & 1 deletion mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def visit_instance(self, typ: Instance) -> SnapshotItem:
return ('Instance',
typ.type.fullname(),
snapshot_types(typ.args),
None if typ.final_value is None else snapshot_type(typ.final_value))
None if typ.last_known_value is None else snapshot_type(typ.last_known_value))

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
return ('TypeVar',
Expand Down
4 changes: 2 additions & 2 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
typ.type = self.fixup(typ.type)
for arg in typ.args:
arg.accept(self)
if typ.final_value:
typ.final_value.accept(self)
if typ.last_known_value:
typ.last_known_value.accept(self)

def visit_any(self, typ: AnyType) -> None:
pass
Expand Down
4 changes: 2 additions & 2 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]:
triggers = [trigger]
for arg in typ.args:
triggers.extend(self.get_type_triggers(arg))
if typ.final_value:
triggers.extend(self.get_type_triggers(typ.final_value))
if typ.last_known_value:
triggers.extend(self.get_type_triggers(typ.last_known_value))
return triggers

def visit_any(self, typ: AnyType) -> List[str]:
Expand Down
6 changes: 3 additions & 3 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ def visit_deleted_type(self, t: DeletedType) -> Type:

def visit_instance(self, t: Instance) -> Type:
final_value = None # type: Optional[LiteralType]
if t.final_value is not None:
raw_final_value = t.final_value.accept(self)
if t.last_known_value is not None:
raw_final_value = t.last_known_value.accept(self)
assert isinstance(raw_final_value, LiteralType)
final_value = raw_final_value
return Instance(
typ=t.type,
args=self.translate_types(t.args),
line=t.line,
column=t.column,
final_value=final_value,
last_known_value=final_value,
)

def visit_type_var(self, t: TypeVarType) -> Type:
Expand Down
4 changes: 2 additions & 2 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
elif isinstance(arg, (NoneType, LiteralType)):
# Types that we can just add directly to the literal/potential union of literals.
return [arg]
elif isinstance(arg, Instance) and arg.final_value is not None:
elif isinstance(arg, Instance) and arg.last_known_value is not None:
# Types generated from declarations like "var: Final = 4".
return [arg.final_value]
return [arg.last_known_value]
elif isinstance(arg, UnionType):
out = []
for union_arg in arg.items:
Expand Down
58 changes: 36 additions & 22 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,11 @@ class Instance(Type):
The list of type variables may be empty.
"""

__slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'final_value')
__slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'last_known_value')

def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
line: int = -1, column: int = -1, erased: bool = False,
final_value: Optional['LiteralType'] = None) -> None:
last_known_value: Optional['LiteralType'] = None) -> None:
super().__init__(line, column)
self.type = typ
self.args = args
Expand All @@ -595,15 +595,31 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
# True if recovered after incorrect number of type arguments error
self.invalid = False

# This field keeps track of the underlying Literal[...] value if this instance
# was created via a Final declaration. For example, if we did `x: Final = 3`, x
# would have an instance with a `final_value` of `LiteralType(3, int_fallback)`.
# This field keeps track of the underlying Literal[...] value associated with
# this instance, if one is known.
#
# Or more broadly, this field lets this Instance "remember" its original declaration.
# We want this behavior because we want implicit Final declarations to act pretty
# much identically with constants: we should be able to replace any places where we
# use some Final variable with the original value and get the same type-checking
# behavior. For example, we want this program:
# This field is set whenever possible within expressions, but is erased upon
# variable assignment (see erasetype.remove_instance_last_known_values) unless
# the variable is declared to be final.
#
# For example, consider the following program:
#
# a = 1
# b: Final[int] = 2
# c: Final = 3
# print(a + b + c + 4)
#
# The 'Instance' objects associated with the expressions '1', '2', '3', and '4' will
# have last_known_values of type Literal[1], Literal[2], Literal[3], and Literal[4]
# respectively. However, the Instance object assigned to 'a' and 'b' will have their
# last_known_value erased: variable 'a' is mutable; variable 'b' was declared to be
# specifically an int.
#
# Or more broadly, this field lets this Instance "remember" its original declaration
# when applicable. We want this behavior because we want implicit Final declarations
# to act pretty much identically with constants: we should be able to replace any
# places where we use some Final variable with the original value and get the same
# type-checking behavior. For example, we want this program:
#
# def expects_literal(x: Literal[3]) -> None: pass
# var: Final = 3
Expand All @@ -617,39 +633,37 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type],
# In order to make this work (especially with literal types), we need var's type
# (an Instance) to remember the "original" value.
#
# This field is currently set only when we encounter an *implicit* final declaration
# like `x: Final = 3` where the RHS is some literal expression. This field remains 'None'
# when we do things like `x: Final[int] = 3` or `x: Final = foo + bar`.
# Preserving this value within expressions is useful for similar reasons.
#
# Currently most of mypy will ignore this field and will continue to treat this type like
# a regular Instance. We end up using this field only when we are explicitly within a
# Literal context.
self.final_value = final_value
self.last_known_value = last_known_value

def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_instance(self)

def __hash__(self) -> int:
return hash((self.type, tuple(self.args), self.final_value))
return hash((self.type, tuple(self.args), self.last_known_value))

def __eq__(self, other: object) -> bool:
if not isinstance(other, Instance):
return NotImplemented
return (self.type == other.type
and self.args == other.args
and self.final_value == other.final_value)
and self.last_known_value == other.last_known_value)

def serialize(self) -> Union[JsonDict, str]:
assert self.type is not None
type_ref = self.type.fullname()
if not self.args and not self.final_value:
if not self.args and not self.last_known_value:
return type_ref
data = {'.class': 'Instance',
} # type: JsonDict
data['type_ref'] = type_ref
data['args'] = [arg.serialize() for arg in self.args]
if self.final_value is not None:
data['final_value'] = self.final_value.serialize()
if self.last_known_value is not None:
data['last_known_value'] = self.last_known_value.serialize()
return data

@classmethod
Expand All @@ -666,8 +680,8 @@ def deserialize(cls, data: Union[JsonDict, str]) -> 'Instance':
args = [deserialize_type(arg) for arg in args_list]
inst = Instance(NOT_READY, args)
inst.type_ref = data['type_ref'] # Will be fixed up by fixup.py later.
if 'final_value' in data:
inst.final_value = LiteralType.deserialize(data['final_value'])
if 'last_known_value' in data:
inst.last_known_value = LiteralType.deserialize(data['last_known_value'])
return inst

def copy_modified(self, *,
Expand All @@ -679,7 +693,7 @@ def copy_modified(self, *,
self.line,
self.column,
self.erased,
final_value if final_value is not _dummy else self.final_value,
final_value if final_value is not _dummy else self.last_known_value,
)

def has_readable_member(self, name: str) -> bool:
Expand Down
Loading