Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
from mypy.semanal import set_callable_name, refers_to_fullname
from mypy.mro import calculate_mro
from mypy.erasetype import erase_typevars
from mypy.erasetype import erase_typevars, remove_literal_metadata
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.visitor import NodeVisitor
from mypy.join import join_types
Expand Down Expand Up @@ -1868,10 +1868,9 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
rvalue_type = self.expr_checker.accept(
rvalue,
in_final_declaration=inferred.is_final,
)
rvalue_type = self.expr_checker.accept(rvalue)
if not inferred.is_final:
rvalue_type = remove_literal_metadata(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
Expand Down
24 changes: 7 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,6 @@ def __init__(self,
self.plugin = plugin
self.type_context = [None]

# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
# For example, if we're checking the "3" in a statement like "var: Final = 3".
#
# This flag changes the type that eventually gets inferred for "var". Instead of
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
# of the underlying literal value. See the comments in Instance's constructors for
# more details.
self.in_final_declaration = False

# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
Expand Down Expand Up @@ -1812,15 +1803,13 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
typ = self.named_type(fallback_name)
if self.is_literal_context():
return LiteralType(value=value, fallback=typ)
elif self.in_final_declaration:
else:
return typ.copy_modified(final_value=LiteralType(
value=value,
fallback=typ,
line=typ.line,
column=typ.column,
))
else:
return typ

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
Expand Down Expand Up @@ -2450,7 +2439,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type:
It may also represent type application.
"""
result = self.visit_index_expr_helper(e)
return self.narrow_type_from_binder(e, result)
result = self.narrow_type_from_binder(e, result)
if (self.is_literal_context() and isinstance(result, Instance)
and result.final_value is not None):
result = result.final_value
return result

def visit_index_expr_helper(self, e: IndexExpr) -> Type:
if e.analyzed:
Expand Down Expand Up @@ -3253,16 +3246,13 @@ def accept(self,
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
in_final_declaration: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
old_in_final_declaration = self.in_final_declaration
self.in_final_declaration = in_final_declaration
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand All @@ -3274,8 +3264,8 @@ def accept(self,
except Exception as err:
report_internal_error(err, self.chk.errors.file,
node.line, self.chk.errors, self.chk.options)

self.type_context.pop()
self.in_final_declaration = old_in_final_declaration
assert typ is not None
self.chk.store_type(node, typ)

Expand Down
14 changes: 14 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,17 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t


def remove_literal_metadata(t: Type) -> Type:
return t.accept(LiteralMetadataEraser())


class LiteralMetadataEraser(TypeTranslator):
"""Removes the Literal[...] type that may be associated with any
Instance types."""

def visit_instance(self, t: Instance) -> Type:
if t.final_value:
return t.copy_modified(final_value=None)
return t
39 changes: 39 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -2806,3 +2806,42 @@ Alias = Test
x: Literal[Alias.FOO]
reveal_type(x) # E: Revealed type is 'Literal[__main__.Test.FOO]'
[out]

[case testLiteralWithFinalPropagation]
from typing_extensions import Final, Literal

a: Final = 3
b: Final = a
c = a

def expect_3(x: Literal[3]) -> None: pass
expect_3(a)
expect_3(b)
expect_3(c) # E: Argument 1 to "expect_3" has incompatible type "int"; expected "Literal[3]"
[out]

[case testLiteralWithFinalPropagationIsNotLeaking]
from typing_extensions import Final, Literal

final_tuple_direct: Final = (2, 3)
final_tuple_indirect: Final = final_tuple_direct
mutable_tuple = final_tuple_direct
final_list_1: Final = [2]
final_list_2: Final = [2, 2]
final_dict: Final = {"foo": 2}
final_set_1: Final = {2}
final_set_2: Final = {2, 2}

def expect_2(x: Literal[2]) -> None: pass

expect_2(final_tuple_direct[0])
expect_2(final_tuple_indirect[0])

expect_2(mutable_tuple[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
expect_2(final_list_1[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
expect_2(final_list_2[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
expect_2(final_dict["foo"]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
expect_2(final_set_1.pop()) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
expect_2(final_set_2.pop()) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]"
[builtins fixtures/isinstancelist.pyi]
[out]
3 changes: 2 additions & 1 deletion test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -8553,8 +8553,9 @@ from typing_extensions import Literal
def expect_3(x: Literal[3]) -> None: pass
expect_3(foo)
[file mod1.py]
from typing_extensions import Final
from mod2 import bar
foo = bar
foo: Final = bar
[file mod2.py]
from mod3 import qux as bar
[file mod3.py]
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/isinstancelist.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ class set(Generic[T]):
def add(self, x: T) -> None: pass
def discard(self, x: T) -> None: pass
def update(self, x: Set[T]) -> None: pass
def pop(self) -> T: pass