
Commit cb0728c

Add basic end-to-end support for Literal types (#5947)
This pull request primarily adds support for string, int, bool, and None Literal types in the parsing and semantic analysis layer. It also adds a small smidge of logic to the subtyping and overlapping types parts of the typechecking layer -- just enough so we can start using Literal types end-to-end and write some tests. It also adds a "skip path normalization" option to our test cases.

Basically, my previous diff made a bunch of gross changes in a horizontal way to lay out a foundation; this diff makes a bunch of gross changes in a vertical way to create a proof-of-concept of the proof-of-concept.

I'll probably need to submit a few follow-up PRs cleaning up some stuff in the semantic analysis layer, but my hope is to switch focus to fleshing out the type checking layer shortly after this PR lands.

Specific changes made:

1. I added a new 'RawLiteralType' synthetic type meant to represent any literal expression that appears in the earliest stages of semantic analysis. If these 'RawLiteralTypes' appear in a "Literal[...]" context, they are transformed into regular 'LiteralTypes' during phase 2 of semantic analysis (turning UnboundTypes into actual types). If they appear outside of the correct context, we report an error instead.

   I also added a string field to UnboundType to keep track of similar information. Basically, if we have `Foo["bar"]`, we don't actually know immediately whether that "bar" is meant to be a string literal or a forward reference. (Foo could be a regular class or an alias for 'Literal'.)

   Note: I wanted to avoid having to introduce yet another type, so I looked into attaching even more metadata to UnboundType or to LiteralType. Both of those approaches ended up looking pretty messy though, so I went with this.

2. As a consequence, some of the syntax checking logic had to move from the parsing layer to the semantic analysis layer, and some of the existing error messages changed slightly.

3. The PEP, at some point, provisionally rejected having Literals contain other Literal types. For example, something like this:

       RegularIds = Literal[1, 2, 3, 4]
       AdminIds = Literal[100, 101, 102]
       AllIds = Literal[RegularIds, AdminIds, 30, 32]

   This was mainly because I thought this would be challenging to implement, but it turned out to be easier than expected -- it happened almost by accident while I was implementing support for 'Literal[1, 2, 3]'. I can excise this part if we think supporting this kind of thing is a fundamentally bad idea.

4. I also penciled in some minimal code in the subtyping and overlapping types logic.

This diff also tweaks some of our test-case logic: we can now specify that we want to skip path normalization on certain test cases.

Specifically, the problem was that mypy attempts to normalize paths by replacing all instances of '\' with '/' so that the output when running the tests on Windows matches the specified errors. This is what we want most of the time, except when we want to print out Literal types containing strings with slashes.

I thought about changing the output of mypy in general so it always uses '/' for paths in error outputs, even on Windows: this would mean we would no longer have to do path normalization. However, I wasn't convinced this was the right thing to do: using '\' on Windows technically *is* the right thing to do, and I didn't want to complicate the codebase by forcing us to keep track of when to use os.sep vs '/'.

I don't think I'll add too many of these test cases, so I decided to just go with a localized solution instead of changing mypy's error output as a whole.
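
As a rough illustration of what item 3 looks like from the user's side, here is a minimal sketch. It assumes a Python version where `Literal` and `get_args` are available from `typing` (3.8 or later), which is not what this commit itself targets; the `is_admin` helper is hypothetical and only exists for the example.

```python
from typing import Literal, get_args

# Aliases taken from the commit message.
RegularIds = Literal[1, 2, 3, 4]
AdminIds = Literal[100, 101, 102]
# A Literal whose parameters include other Literal aliases.
AllIds = Literal[RegularIds, AdminIds, 30, 32]


def is_admin(user_id: AllIds) -> bool:
    """Hypothetical helper: True if user_id is one of the AdminIds values."""
    return user_id in get_args(AdminIds)
```

A checker that supports nested Literals would be expected to accept `is_admin(100)` and `is_admin(30)` and to reject a call such as `is_admin(5)`, since 5 is not among the listed values.
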
1 parent 4f2a88d commit cb0728c

29 files changed: +1100 -81 lines changed

mypy/exprtotype.py

Lines changed: 27 additions & 6 deletions
@@ -1,13 +1,14 @@
 """Translate an Expression to a Type value."""

 from mypy.nodes import (
-    Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr,
+    Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr, IntExpr, FloatExpr, UnaryExpr,
     ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
     get_member_expr_fullname
 )
 from mypy.fastparse import parse_type_comment
 from mypy.types import (
-    Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument, TypeOfAny
+    Type, UnboundType, TypeList, EllipsisType, AnyType, Optional, CallableArgument, TypeOfAny,
+    LiteralType, RawLiteralType,
 )


@@ -37,7 +38,12 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
     name = None # type: Optional[str]
     if isinstance(expr, NameExpr):
         name = expr.name
-        return UnboundType(name, line=expr.line, column=expr.column)
+        if name == 'True':
+            return RawLiteralType(True, 'builtins.bool', line=expr.line, column=expr.column)
+        elif name == 'False':
+            return RawLiteralType(False, 'builtins.bool', line=expr.line, column=expr.column)
+        else:
+            return UnboundType(name, line=expr.line, column=expr.column)
     elif isinstance(expr, MemberExpr):
         fullname = get_member_expr_fullname(expr)
         if fullname:
@@ -108,11 +114,26 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
     elif isinstance(expr, (StrExpr, BytesExpr, UnicodeExpr)):
         # Parse string literal type.
         try:
-            result = parse_type_comment(expr.value, expr.line, None)
-            assert result is not None
+            node = parse_type_comment(expr.value, expr.line, None)
+            assert node is not None
+            if isinstance(node, UnboundType) and node.original_str_expr is None:
+                node.original_str_expr = expr.value
+            return node
         except SyntaxError:
+            return RawLiteralType(expr.value, 'builtins.str', line=expr.line, column=expr.column)
+    elif isinstance(expr, UnaryExpr):
+        typ = expr_to_unanalyzed_type(expr.expr)
+        if isinstance(typ, RawLiteralType) and isinstance(typ.value, int) and expr.op == '-':
+            typ.value *= -1
+            return typ
+        else:
             raise TypeTranslationError()
-        return result
+    elif isinstance(expr, IntExpr):
+        return RawLiteralType(expr.value, 'builtins.int', line=expr.line, column=expr.column)
+    elif isinstance(expr, FloatExpr):
+        # Floats are not valid parameters for RawLiteralType, so we just
+        # pass in 'None' for now. We'll report the appropriate error at a later stage.
+        return RawLiteralType(None, 'builtins.float', line=expr.line, column=expr.column)
     elif isinstance(expr, EllipsisExpr):
         return EllipsisType(expr.line)
     else:
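
The UnaryExpr branch above exists because a negative number is not a single token: `-4` parses as a unary minus applied to `4`. A minimal sketch of the same idea using the standard library `ast` module rather than mypy's own node classes (the `literal_int` helper is hypothetical, and it is stricter than the diff in that it does not recurse into nested unary operators):

```python
import ast
from typing import Optional


def literal_int(source: str) -> Optional[int]:
    """Return the int written in `source`, accepting only `N` or `-N`."""
    node = ast.parse(source, mode="eval").body
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        inner = node.operand
        # `type(...) is int` deliberately excludes bool, a subclass of int.
        if isinstance(inner, ast.Constant) and type(inner.value) is int:
            return -inner.value
        return None
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return node.value
    # Anything else (+4, ~6, 4.0, names, ...) is not an int literal here.
    return None
```

Calling `literal_int("-4")` folds the minus sign into the value, while `literal_int("+4")` and `literal_int("~6")` return `None`, mirroring the "only unary minus" restriction.
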

mypy/fastparse.py

Lines changed: 41 additions & 4 deletions
@@ -32,7 +32,7 @@
 )
 from mypy.types import (
     Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument,
-    TypeOfAny, Instance,
+    TypeOfAny, Instance, RawLiteralType,
 )
 from mypy import defaults
 from mypy import messages
@@ -53,6 +53,9 @@
         Expression as ast3_Expression,
         Str,
         Index,
+        Num,
+        UnaryOp,
+        USub,
     )
 except ImportError:
     if sys.version_info.minor > 2:
@@ -1138,12 +1141,46 @@ def visit_Name(self, n: Name) -> Type:
         return UnboundType(n.id, line=self.line)

     def visit_NameConstant(self, n: NameConstant) -> Type:
-        return UnboundType(str(n.value))
+        if isinstance(n.value, bool):
+            return RawLiteralType(n.value, 'builtins.bool', line=self.line)
+        else:
+            return UnboundType(str(n.value), line=self.line)
+
+    # UnaryOp(op, operand)
+    def visit_UnaryOp(self, n: UnaryOp) -> Type:
+        # We support specifically Literal[-4] and nothing else.
+        # For example, Literal[+4] or Literal[~6] is not supported.
+        typ = self.visit(n.operand)
+        if isinstance(typ, RawLiteralType) and isinstance(n.op, USub):
+            if isinstance(typ.value, int):
+                typ.value *= -1
+                return typ
+        self.fail(TYPE_COMMENT_AST_ERROR, self.line, getattr(n, 'col_offset', -1))
+        return AnyType(TypeOfAny.from_error)
+
+    # Num(number n)
+    def visit_Num(self, n: Num) -> Type:
+        # Could be either float or int
+        numeric_value = n.n
+        if isinstance(numeric_value, int):
+            return RawLiteralType(numeric_value, 'builtins.int', line=self.line)
+        elif isinstance(numeric_value, float):
+            # Floats and other numbers are not valid parameters for RawLiteralType, so we just
+            # pass in 'None' for now. We'll report the appropriate error at a later stage.
+            return RawLiteralType(None, 'builtins.float', line=self.line)
+        else:
+            self.fail(TYPE_COMMENT_AST_ERROR, self.line, getattr(n, 'col_offset', -1))
+            return AnyType(TypeOfAny.from_error)

     # Str(string s)
     def visit_Str(self, n: Str) -> Type:
-        return (parse_type_comment(n.s.strip(), self.line, self.errors) or
-                AnyType(TypeOfAny.from_error))
+        try:
+            node = parse_type_comment(n.s.strip(), self.line, errors=None)
+            if isinstance(node, UnboundType) and node.original_str_expr is None:
+                node.original_str_expr = n.s
+            return node or AnyType(TypeOfAny.from_error)
+        except SyntaxError:
+            return RawLiteralType(n.s, 'builtins.str', line=self.line)

     # Subscript(expr value, slice slice, expr_context ctx)
     def visit_Subscript(self, n: ast3.Subscript) -> Type:
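
The new `visit_Str` shows the ambiguity described in the commit message: a string inside a subscript may be a forward reference or a string literal, and the parser cannot tell which. A minimal sketch of the try/except fallback, assuming the standard library `ast` parser as a stand-in for `parse_type_comment` (the `classify_string` helper and its tag names are hypothetical):

```python
import ast
from typing import Tuple


def classify_string(s: str) -> Tuple[str, str]:
    """Tag a string found inside a subscript such as Foo["..."]."""
    try:
        # If the text parses as an expression it *might* be a forward
        # reference, so the original string is kept for a later decision.
        ast.parse(s.strip(), mode="eval")
    except SyntaxError:
        # Not parseable as an expression: it can only be a string literal.
        return ("raw_literal", s)
    return ("maybe_forward_ref", s)
```

In this sketch `"bar"` and `"List[int]"` stay ambiguous, while `"foo bar"` can only be a literal. The diff handles the ambiguous case by recording the original text in `original_str_expr`, so a later phase can still recover the string if the enclosing type turns out to be `Literal`.
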

mypy/indirection.py

Lines changed: 3 additions & 0 deletions
@@ -90,6 +90,9 @@ def visit_tuple_type(self, t: types.TupleType) -> Set[str]:
     def visit_typeddict_type(self, t: types.TypedDictType) -> Set[str]:
         return self._visit(t.items.values()) | self._visit(t.fallback)

+    def visit_raw_literal_type(self, t: types.RawLiteralType) -> Set[str]:
+        assert False, "Unexpected RawLiteralType after semantic analysis phase"
+
     def visit_literal_type(self, t: types.LiteralType) -> Set[str]:
         return self._visit(t.fallback)


mypy/meet.py

Lines changed: 7 additions & 0 deletions
@@ -235,6 +235,13 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool:
     elif isinstance(right, CallableType):
         right = right.fallback

+    if isinstance(left, LiteralType) and isinstance(right, LiteralType):
+        return left == right
+    elif isinstance(left, LiteralType):
+        left = left.fallback
+    elif isinstance(right, LiteralType):
+        right = right.fallback
+
     # Finally, we handle the case where left and right are instances.

     if isinstance(left, Instance) and isinstance(right, Instance):
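
The overlap rule added here can be sketched with a toy model. Everything below is hypothetical: `Lit` stands in for `LiteralType`, plain strings stand in for `Instance` types, and the final comparison is a crude placeholder for mypy's real instance-overlap check.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Lit:
    """Toy stand-in for LiteralType: a value plus its fallback type name."""
    value: object
    fallback: str


ToyType = Union[Lit, str]  # a plain str plays the role of an Instance


def overlaps(left: ToyType, right: ToyType) -> bool:
    # Two literal types overlap only if they are the same literal.
    if isinstance(left, Lit) and isinstance(right, Lit):
        return left == right
    # Otherwise erase the one literal to its fallback and compare instances.
    if isinstance(left, Lit):
        left = left.fallback
    elif isinstance(right, Lit):
        right = right.fallback
    return left == right  # placeholder for the real instance overlap check
```

In this model `Lit(1, 'builtins.int')` overlaps with `'builtins.int'` but not with `Lit(2, 'builtins.int')`.
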

mypy/messages.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 from mypy.erasetype import erase_type
 from mypy.errors import Errors
 from mypy.types import (
-    Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType,
+    Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType,
     UnionType, NoneTyp, AnyType, Overloaded, FunctionLike, DeletedType, TypeType,
     UninhabitedType, TypeOfAny, ForwardRef, UnboundType
 )
@@ -297,6 +297,8 @@ def format_bare(self, typ: Type, verbosity: int = 0) -> str:
                                             self.format_bare(item_type)))
             s = 'TypedDict({{{}}})'.format(', '.join(items))
             return s
+        elif isinstance(typ, LiteralType):
+            return str(typ)
         elif isinstance(typ, UnionType):
             # Only print Unions as Optionals if the Optional wouldn't have to contain another Union
             print_as_optional = (len(typ.items) -

mypy/semanal_newtype.py

Lines changed: 10 additions & 2 deletions
@@ -5,7 +5,7 @@

 from typing import Tuple, Optional

-from mypy.types import Type, Instance, CallableType, NoneTyp, TupleType
+from mypy.types import Type, Instance, CallableType, NoneTyp, TupleType, AnyType, TypeOfAny
 from mypy.nodes import (
     AssignmentStmt, NewTypeExpr, CallExpr, NameExpr, RefExpr, Context, StrExpr, BytesExpr,
     UnicodeExpr, Block, FuncDef, Argument, TypeInfo, Var, SymbolTableNode, GDEF, MDEF, ARG_POS
@@ -107,13 +107,21 @@ def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Opt
             has_failed = True

         # Check second argument
+        msg = "Argument 2 to NewType(...) must be a valid type"
         try:
             unanalyzed_type = expr_to_unanalyzed_type(args[1])
         except TypeTranslationError:
-            self.fail("Argument 2 to NewType(...) must be a valid type", context)
+            self.fail(msg, context)
             return None
+
         old_type = self.api.anal_type(unanalyzed_type)

+        # The caller of this function assumes that if we return a Type, it's always
+        # a valid one. So, we translate AnyTypes created from errors into None.
+        if isinstance(old_type, AnyType) and old_type.type_of_any == TypeOfAny.from_error:
+            self.fail(msg, context)
+            return None
+
         return None if has_failed else old_type

     def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) -> TypeInfo:

mypy/semanal_pass3.py

Lines changed: 8 additions & 1 deletion
@@ -22,7 +22,7 @@
 )
 from mypy.types import (
     Type, Instance, AnyType, TypeOfAny, CallableType, TupleType, TypeVarType, TypedDictType,
-    UnionType, TypeType, Overloaded, ForwardRef, TypeTranslator, function_type
+    UnionType, TypeType, Overloaded, ForwardRef, TypeTranslator, function_type, LiteralType,
 )
 from mypy.errors import Errors, report_internal_error
 from mypy.options import Options
@@ -704,6 +704,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
         assert isinstance(fallback, Instance)
         return TypedDictType(items, t.required_keys, fallback, t.line, t.column)

+    def visit_literal_type(self, t: LiteralType) -> Type:
+        if self.check_recursion(t):
+            return AnyType(TypeOfAny.from_error)
+        fallback = self.visit_instance(t.fallback, from_fallback=True)
+        assert isinstance(fallback, Instance)
+        return LiteralType(t.value, fallback, t.line, t.column)
+
     def visit_union_type(self, t: UnionType) -> Type:
         if self.check_recursion(t):
             return AnyType(TypeOfAny.from_error)

mypy/server/astmerge.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,7 @@
     Type, SyntheticTypeVisitor, Instance, AnyType, NoneTyp, CallableType, DeletedType, PartialType,
     TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType,
     Overloaded, TypeVarDef, TypeList, CallableArgument, EllipsisType, StarType, LiteralType,
+    RawLiteralType,
 )
 from mypy.util import get_prefix, replace_object_state
 from mypy.typestate import TypeState
@@ -391,6 +392,9 @@ def visit_typeddict_type(self, typ: TypedDictType) -> None:
             value_type.accept(self)
         typ.fallback.accept(self)

+    def visit_raw_literal_type(self, t: RawLiteralType) -> None:
+        assert False, "Unexpected RawLiteralType after semantic analysis phase"
+
     def visit_literal_type(self, typ: LiteralType) -> None:
         typ.fallback.accept(self)


mypy/subtypes.py

Lines changed: 9 additions & 3 deletions
@@ -327,8 +327,11 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
         else:
             return False

-    def visit_literal_type(self, t: LiteralType) -> bool:
-        raise NotImplementedError()
+    def visit_literal_type(self, left: LiteralType) -> bool:
+        if isinstance(self.right, LiteralType):
+            return left == self.right
+        else:
+            return self._is_subtype(left.fallback, self.right)

     def visit_overloaded(self, left: Overloaded) -> bool:
         right = self.right
@@ -1172,7 +1175,10 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
         return self._is_proper_subtype(left.fallback, right)

     def visit_literal_type(self, left: LiteralType) -> bool:
-        raise NotImplementedError()
+        if isinstance(self.right, LiteralType):
+            return left == self.right
+        else:
+            return self._is_proper_subtype(left.fallback, self.right)

     def visit_overloaded(self, left: Overloaded) -> bool:
         # TODO: What's the right thing to do here?
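
The two `visit_literal_type` methods follow the same rule: a literal is a subtype of another literal only if they are equal, and otherwise the question is delegated to the literal's fallback. A toy sketch of that rule, where `Lit`, the string-named types, and the `SUPERTYPES` table are all hypothetical stand-ins. The branch for a non-literal on the left is an assumption of the sketch, since the diff does not show that case.

```python
from dataclasses import dataclass
from typing import Dict, Set, Union


@dataclass(frozen=True)
class Lit:
    """Toy stand-in for LiteralType: a value plus its fallback type name."""
    value: object
    fallback: str


ToyType = Union[Lit, str]

# Hypothetical nominal hierarchy: each name maps to all of its supertypes.
SUPERTYPES: Dict[str, Set[str]] = {
    'builtins.object': {'builtins.object'},
    'builtins.int': {'builtins.int', 'builtins.object'},
    'builtins.bool': {'builtins.bool', 'builtins.int', 'builtins.object'},
    'builtins.str': {'builtins.str', 'builtins.object'},
}


def is_subtype(left: ToyType, right: ToyType) -> bool:
    if isinstance(left, Lit):
        if isinstance(right, Lit):
            return left == right                  # literal vs literal: equality
        return is_subtype(left.fallback, right)   # literal vs instance: use fallback
    if isinstance(right, Lit):
        return False  # assumption: a plain instance is never below a literal
    return right in SUPERTYPES[left]
```

So `Lit(3, 'builtins.int')` is a subtype of `'builtins.int'` and of itself, but not of `Lit(4, 'builtins.int')`.
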

mypy/test/data.py

Lines changed: 10 additions & 1 deletion
@@ -42,6 +42,7 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
         join = posixpath.join # type: ignore

     out_section_missing = case.suite.required_out_section
+    normalize_output = True

     files = [] # type: List[Tuple[str, str]] # path and contents
     output_files = [] # type: List[Tuple[str, str]] # path and contents for output files
@@ -98,8 +99,11 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
             full = join(base_path, m.group(1))
             deleted_paths.setdefault(num, set()).add(full)
         elif re.match(r'out[0-9]*$', item.id):
+            if item.arg == 'skip-path-normalization':
+                normalize_output = False
+
             tmp_output = [expand_variables(line) for line in item.data]
-            if os.path.sep == '\\':
+            if os.path.sep == '\\' and normalize_output:
                 tmp_output = [fix_win_path(line) for line in tmp_output]
             if item.id == 'out' or item.id == 'out1':
                 output = tmp_output
@@ -147,6 +151,7 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None:
     case.expected_rechecked_modules = rechecked_modules
     case.deleted_paths = deleted_paths
     case.triggered = triggered or []
+    case.normalize_output = normalize_output


 class DataDrivenTestCase(pytest.Item): # type: ignore # inheriting from Any
@@ -168,6 +173,10 @@ class DataDrivenTestCase(pytest.Item): # type: ignore # inheriting from Any
     # Files/directories to clean up after test case; (is directory, path) tuples
     clean_up = None # type: List[Tuple[bool, str]]

+    # Whether or not we should normalize the output to standardize things like
+    # forward vs backward slashes in file paths for Windows vs Linux.
+    normalize_output = True
+
     def __init__(self,
                  parent: 'DataSuiteCollector',
                  suite: 'DataSuite',
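
To see why the `skip-path-normalization` argument is needed, here is a minimal sketch of the conflict. The two helpers are hypothetical simplifications, not mypy's `fix_win_path` or `normalize_error_messages`; the sketch also ignores the `os.path.sep` check that the real code performs.

```python
from typing import List


def normalize_slashes(line: str) -> str:
    """Replace every backslash with a forward slash, as the commit message describes."""
    return line.replace('\\', '/')


def prepare_expected(lines: List[str], normalize_output: bool = True) -> List[str]:
    """Normalize expected output unless the test case opted out."""
    if normalize_output:
        return [normalize_slashes(line) for line in lines]
    return list(lines)
```

With normalization on, a path such as `tmp\main.py` becomes `tmp/main.py`, which is the intended effect. The same rewrite also turns the text `Literal['a\b']` into `Literal['a/b']`, which no longer matches what the checker printed, so such a test case needs to opt out.
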

mypy/test/testcheck.py

Lines changed: 3 additions & 1 deletion
@@ -80,6 +80,7 @@
     'check-ctypes.test',
     'check-dataclasses.test',
     'check-final.test',
+    'check-literal.test',
 ]


@@ -177,7 +178,8 @@ def run_case_once(self, testcase: DataDrivenTestCase,
             assert sys.path[0] == plugin_dir
             del sys.path[0]

-        a = normalize_error_messages(a)
+        if testcase.normalize_output:
+            a = normalize_error_messages(a)

         # Make sure error messages match
         if incremental_step == 0:

mypy/test/testcmdline.py

Lines changed: 9 additions & 4 deletions
@@ -85,14 +85,19 @@ def test_python_cmdline(testcase: DataDrivenTestCase) -> None:
                 actual_output_content = output_file.read().splitlines()
             normalized_output = normalize_file_output(actual_output_content,
                                                       os.path.abspath(test_temp_dir))
-            if testcase.suite.native_sep and os.path.sep == '\\':
-                normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
-            normalized_output = normalize_error_messages(normalized_output)
+            # We always normalize things like timestamp, but only handle operating-system
+            # specific things if requested.
+            if testcase.normalize_output:
+                if testcase.suite.native_sep and os.path.sep == '\\':
+                    normalized_output = [fix_cobertura_filename(line)
+                                         for line in normalized_output]
+                normalized_output = normalize_error_messages(normalized_output)
             assert_string_arrays_equal(expected_content.splitlines(), normalized_output,
                                        'Output file {} did not match its expected output'.format(
                                            path))
     else:
-        out = normalize_error_messages(err + out)
+        if testcase.normalize_output:
+            out = normalize_error_messages(err + out)
         obvious_result = 1 if out else 0
         if obvious_result != result:
             out.append('== Return code: {}'.format(result))

mypy/test/testmerge.py

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
                 # Verify that old AST nodes are removed from the expression type map.
                 assert expr not in new_types

-        a = normalize_error_messages(a)
+        if testcase.normalize_output:
+            a = normalize_error_messages(a)

         assert_string_arrays_equal(
             testcase.output, a,
