Skip to content

Commit 56618b9

Browse files
authored
Support union type syntax in runtime contexts (#10770)
Support the X | Y syntax (PEP 604) in type aliases, casts, type applications and base classes. These are only available in Python 3.10 mode and probably won't work in stubs when targeting earlier Python versions yet. Work on #9880.
1 parent 178df79 commit 56618b9

11 files changed

+86
-45
lines changed

mypy/exprtotype.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from mypy.nodes import (
66
Expression, NameExpr, MemberExpr, IndexExpr, RefExpr, TupleExpr, IntExpr, FloatExpr, UnaryExpr,
7-
ComplexExpr, ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
7+
ComplexExpr, ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr, OpExpr,
88
get_member_expr_fullname
99
)
1010
from mypy.fastparse import parse_type_string
1111
from mypy.types import (
1212
Type, UnboundType, TypeList, EllipsisType, AnyType, CallableArgument, TypeOfAny,
13-
RawExpressionType, ProperType
13+
RawExpressionType, ProperType, UnionType
1414
)
15+
from mypy.options import Options
1516

1617

1718
class TypeTranslationError(Exception):
@@ -29,7 +30,9 @@ def _extract_argument_name(expr: Expression) -> Optional[str]:
2930
raise TypeTranslationError()
3031

3132

32-
def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> ProperType:
33+
def expr_to_unanalyzed_type(expr: Expression,
34+
options: Optional[Options] = None,
35+
_parent: Optional[Expression] = None) -> ProperType:
3336
"""Translate an expression to the corresponding type.
3437
3538
The result is not semantically analyzed. It can be UnboundType or TypeList.
@@ -53,7 +56,7 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
5356
else:
5457
raise TypeTranslationError()
5558
elif isinstance(expr, IndexExpr):
56-
base = expr_to_unanalyzed_type(expr.base, expr)
59+
base = expr_to_unanalyzed_type(expr.base, options, expr)
5760
if isinstance(base, UnboundType):
5861
if base.args:
5962
raise TypeTranslationError()
@@ -69,14 +72,20 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
6972
# of the Annotation definition and only returning the type information,
7073
# losing all the annotations.
7174

72-
return expr_to_unanalyzed_type(args[0], expr)
75+
return expr_to_unanalyzed_type(args[0], options, expr)
7376
else:
74-
base.args = tuple(expr_to_unanalyzed_type(arg, expr) for arg in args)
77+
base.args = tuple(expr_to_unanalyzed_type(arg, options, expr) for arg in args)
7578
if not base.args:
7679
base.empty_tuple_index = True
7780
return base
7881
else:
7982
raise TypeTranslationError()
83+
elif (isinstance(expr, OpExpr)
84+
and expr.op == '|'
85+
and options
86+
and options.python_version >= (3, 10)):
87+
return UnionType([expr_to_unanalyzed_type(expr.left, options),
88+
expr_to_unanalyzed_type(expr.right, options)])
8089
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
8190
c = expr.callee
8291
names = []
@@ -109,19 +118,19 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
109118
if typ is not default_type:
110119
# Two types
111120
raise TypeTranslationError()
112-
typ = expr_to_unanalyzed_type(arg, expr)
121+
typ = expr_to_unanalyzed_type(arg, options, expr)
113122
continue
114123
else:
115124
raise TypeTranslationError()
116125
elif i == 0:
117-
typ = expr_to_unanalyzed_type(arg, expr)
126+
typ = expr_to_unanalyzed_type(arg, options, expr)
118127
elif i == 1:
119128
name = _extract_argument_name(arg)
120129
else:
121130
raise TypeTranslationError()
122131
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
123132
elif isinstance(expr, ListExpr):
124-
return TypeList([expr_to_unanalyzed_type(t, expr) for t in expr.items],
133+
return TypeList([expr_to_unanalyzed_type(t, options, expr) for t in expr.items],
125134
line=expr.line, column=expr.column)
126135
elif isinstance(expr, StrExpr):
127136
return parse_type_string(expr.value, 'builtins.str', expr.line, expr.column,
@@ -133,7 +142,7 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
133142
return parse_type_string(expr.value, 'builtins.unicode', expr.line, expr.column,
134143
assume_str_is_unicode=True)
135144
elif isinstance(expr, UnaryExpr):
136-
typ = expr_to_unanalyzed_type(expr.expr)
145+
typ = expr_to_unanalyzed_type(expr.expr, options)
137146
if isinstance(typ, RawExpressionType):
138147
if isinstance(typ.literal_value, int) and expr.op == '-':
139148
typ.literal_value *= -1

mypy/plugins/attrs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
552552
type_arg = _get_argument(rvalue, 'type')
553553
if type_arg and not init_type:
554554
try:
555-
un_type = expr_to_unanalyzed_type(type_arg)
555+
un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options)
556556
except TypeTranslationError:
557557
ctx.api.fail('Invalid argument to type', type_arg)
558558
else:

mypy/semanal.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ class Foo(Bar, Generic[T]): ...
12671267
self.analyze_type_expr(base_expr)
12681268

12691269
try:
1270-
base = expr_to_unanalyzed_type(base_expr)
1270+
base = expr_to_unanalyzed_type(base_expr, self.options)
12711271
except TypeTranslationError:
12721272
# This error will be caught later.
12731273
continue
@@ -1373,7 +1373,7 @@ def get_all_bases_tvars(self,
13731373
for i, base_expr in enumerate(base_type_exprs):
13741374
if i not in removed:
13751375
try:
1376-
base = expr_to_unanalyzed_type(base_expr)
1376+
base = expr_to_unanalyzed_type(base_expr, self.options)
13771377
except TypeTranslationError:
13781378
# This error will be caught later.
13791379
continue
@@ -2101,7 +2101,7 @@ def should_wait_rhs(self, rv: Expression) -> bool:
21012101
return self.should_wait_rhs(rv.callee)
21022102
return False
21032103

2104-
def can_be_type_alias(self, rv: Expression) -> bool:
2104+
def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool:
21052105
"""Is this a valid r.h.s. for an alias definition?
21062106
21072107
Note: this function should be only called for expressions where self.should_wait_rhs()
@@ -2113,6 +2113,13 @@ def can_be_type_alias(self, rv: Expression) -> bool:
21132113
return True
21142114
if self.is_none_alias(rv):
21152115
return True
2116+
if allow_none and isinstance(rv, NameExpr) and rv.fullname == 'builtins.None':
2117+
return True
2118+
if (isinstance(rv, OpExpr)
2119+
and rv.op == '|'
2120+
and self.can_be_type_alias(rv.left, allow_none=True)
2121+
and self.can_be_type_alias(rv.right, allow_none=True)):
2122+
return True
21162123
return False
21172124

21182125
def is_type_ref(self, rv: Expression, bare: bool = False) -> bool:
@@ -3195,7 +3202,7 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]:
31953202
result: List[Type] = []
31963203
for node in items:
31973204
try:
3198-
analyzed = self.anal_type(expr_to_unanalyzed_type(node),
3205+
analyzed = self.anal_type(expr_to_unanalyzed_type(node, self.options),
31993206
allow_placeholder=True)
32003207
if analyzed is None:
32013208
# Type variables are special: we need to place them in the symbol table
@@ -3638,7 +3645,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
36383645
return
36393646
# Translate first argument to an unanalyzed type.
36403647
try:
3641-
target = expr_to_unanalyzed_type(expr.args[0])
3648+
target = expr_to_unanalyzed_type(expr.args[0], self.options)
36423649
except TypeTranslationError:
36433650
self.fail('Cast target is not a type', expr)
36443651
return
@@ -3696,7 +3703,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
36963703
return
36973704
# Translate first argument to an unanalyzed type.
36983705
try:
3699-
target = expr_to_unanalyzed_type(expr.args[0])
3706+
target = expr_to_unanalyzed_type(expr.args[0], self.options)
37003707
except TypeTranslationError:
37013708
self.fail('Argument 1 to _promote is not a type', expr)
37023709
return
@@ -3892,7 +3899,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]
38923899
items = [index]
38933900
for item in items:
38943901
try:
3895-
typearg = expr_to_unanalyzed_type(item)
3902+
typearg = expr_to_unanalyzed_type(item, self.options)
38963903
except TypeTranslationError:
38973904
self.fail('Type expected within [...]', expr)
38983905
return None
@@ -4199,7 +4206,7 @@ def lookup_qualified(self, name: str, ctx: Context,
41994206

42004207
def lookup_type_node(self, expr: Expression) -> Optional[SymbolTableNode]:
42014208
try:
4202-
t = expr_to_unanalyzed_type(expr)
4209+
t = expr_to_unanalyzed_type(expr, self.options)
42034210
except TypeTranslationError:
42044211
return None
42054212
if isinstance(t, UnboundType):
@@ -4919,7 +4926,7 @@ def expr_to_analyzed_type(self,
49194926
assert info.tuple_type, "NamedTuple without tuple type"
49204927
fallback = Instance(info, [])
49214928
return TupleType(info.tuple_type.items, fallback=fallback)
4922-
typ = expr_to_unanalyzed_type(expr)
4929+
typ = expr_to_unanalyzed_type(expr, self.options)
49234930
return self.anal_type(typ, report_invalid_types=report_invalid_types,
49244931
allow_placeholder=allow_placeholder)
49254932

mypy/semanal_namedtuple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C
356356
self.fail("Invalid NamedTuple() field name", item)
357357
return None
358358
try:
359-
type = expr_to_unanalyzed_type(type_node)
359+
type = expr_to_unanalyzed_type(type_node, self.options)
360360
except TypeTranslationError:
361361
self.fail('Invalid field type', type_node)
362362
return None

mypy/semanal_newtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def check_newtype_args(self, name: str, call: CallExpr,
160160
# Check second argument
161161
msg = "Argument 2 to NewType(...) must be a valid type"
162162
try:
163-
unanalyzed_type = expr_to_unanalyzed_type(args[1])
163+
unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options)
164164
except TypeTranslationError:
165165
self.fail(msg, context)
166166
return None, False

mypy/semanal_typeddict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def parse_typeddict_fields_with_types(
290290
self.fail_typeddict_arg("Invalid TypedDict() field name", name_context)
291291
return [], [], False
292292
try:
293-
type = expr_to_unanalyzed_type(field_type_expr)
293+
type = expr_to_unanalyzed_type(field_type_expr, self.options)
294294
except TypeTranslationError:
295295
self.fail_typeddict_arg('Invalid field type', field_type_expr)
296296
return [], [], False

mypy/typeanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def analyze_type_alias(node: Expression,
8080
Return None otherwise. 'node' must have been semantically analyzed.
8181
"""
8282
try:
83-
type = expr_to_unanalyzed_type(node)
83+
type = expr_to_unanalyzed_type(node, options)
8484
except TypeTranslationError:
8585
api.fail('Invalid type alias: expression is not a valid type', node)
8686
return None

test-data/unit/check-classes.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3154,7 +3154,7 @@ def foo(arg: Type[Any]):
31543154
from typing import Type, Any
31553155
def foo(arg: Type[Any]):
31563156
reveal_type(arg.__str__) # N: Revealed type is "def () -> builtins.str"
3157-
reveal_type(arg.mro()) # N: Revealed type is "builtins.list[builtins.type]"
3157+
reveal_type(arg.mro()) # N: Revealed type is "builtins.list[builtins.type[Any]]"
31583158
[builtins fixtures/type.pyi]
31593159
[out]
31603160

test-data/unit/check-union-or-syntax.test

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ def f(x: int | str) -> int | str:
1111
reveal_type(f) # N: Revealed type is "def (x: Union[builtins.int, builtins.str]) -> Union[builtins.int, builtins.str]"
1212
[builtins fixtures/tuple.pyi]
1313

14-
1514
[case testUnionOrSyntaxWithThreeBuiltinsTypes]
1615
# flags: --python-version 3.10
1716
def f(x: int | str | float) -> int | str | float:
@@ -21,7 +20,6 @@ def f(x: int | str | float) -> int | str | float:
2120
return x
2221
reveal_type(f) # N: Revealed type is "def (x: Union[builtins.int, builtins.str, builtins.float]) -> Union[builtins.int, builtins.str, builtins.float]"
2322

24-
2523
[case testUnionOrSyntaxWithTwoTypes]
2624
# flags: --python-version 3.10
2725
class A: pass
@@ -33,7 +31,6 @@ def f(x: A | B) -> A | B:
3331
return x
3432
reveal_type(f) # N: Revealed type is "def (x: Union[__main__.A, __main__.B]) -> Union[__main__.A, __main__.B]"
3533

36-
3734
[case testUnionOrSyntaxWithThreeTypes]
3835
# flags: --python-version 3.10
3936
class A: pass
@@ -46,34 +43,29 @@ def f(x: A | B | C) -> A | B | C:
4643
return x
4744
reveal_type(f) # N: Revealed type is "def (x: Union[__main__.A, __main__.B, __main__.C]) -> Union[__main__.A, __main__.B, __main__.C]"
4845

49-
5046
[case testUnionOrSyntaxWithLiteral]
5147
# flags: --python-version 3.10
5248
from typing_extensions import Literal
5349
reveal_type(Literal[4] | str) # N: Revealed type is "Any"
5450
[builtins fixtures/tuple.pyi]
5551

56-
5752
[case testUnionOrSyntaxWithBadOperator]
5853
# flags: --python-version 3.10
5954
x: 1 + 2 # E: Invalid type comment or annotation
6055

61-
6256
[case testUnionOrSyntaxWithBadOperands]
6357
# flags: --python-version 3.10
6458
x: int | 42 # E: Invalid type: try using Literal[42] instead?
6559
y: 42 | int # E: Invalid type: try using Literal[42] instead?
6660
z: str | 42 | int # E: Invalid type: try using Literal[42] instead?
6761

68-
6962
[case testUnionOrSyntaxWithGenerics]
7063
# flags: --python-version 3.10
7164
from typing import List
7265
x: List[int | str]
7366
reveal_type(x) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"
7467
[builtins fixtures/list.pyi]
7568

76-
7769
[case testUnionOrSyntaxWithQuotedFunctionTypes]
7870
# flags: --python-version 3.4
7971
from typing import Union
@@ -87,47 +79,79 @@ def g(x: "int | str | None") -> "int | None":
8779
return 42
8880
reveal_type(g) # N: Revealed type is "def (x: Union[builtins.int, builtins.str, None]) -> Union[builtins.int, None]"
8981

90-
9182
[case testUnionOrSyntaxWithQuotedVariableTypes]
9283
# flags: --python-version 3.6
9384
y: "int | str" = 42
9485
reveal_type(y) # N: Revealed type is "Union[builtins.int, builtins.str]"
9586

96-
9787
[case testUnionOrSyntaxWithTypeAliasWorking]
9888
# flags: --python-version 3.10
99-
from typing import Union
100-
T = Union[int, str]
89+
T = int | str
10190
x: T
10291
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"
92+
S = list[int] | str | None
93+
y: S
94+
reveal_type(y) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.str, None]"
95+
U = str | None
96+
z: U
97+
reveal_type(z) # N: Revealed type is "Union[builtins.str, None]"
98+
99+
def f(): pass
103100

101+
X = int | str | f()
102+
b: X # E: Variable "__main__.X" is not valid as a type \
103+
# N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
104+
[builtins fixtures/type.pyi]
104105

105-
[case testUnionOrSyntaxWithTypeAliasNotAllowed]
106+
[case testUnionOrSyntaxWithinRuntimeContextNotAllowed]
106107
# flags: --python-version 3.9
107108
from __future__ import annotations
108-
T = int | str # E: Unsupported left operand type for | ("Type[int]")
109+
from typing import List
110+
T = int | str # E: Invalid type alias: expression is not a valid type \
111+
# E: Unsupported left operand type for | ("Type[int]")
112+
class C(List[int | str]): # E: Type expected within [...] \
113+
# E: Invalid base class "List"
114+
pass
115+
C()
109116
[builtins fixtures/tuple.pyi]
110117

118+
[case testUnionOrSyntaxWithinRuntimeContextNotAllowed2]
119+
# flags: --python-version 3.9
120+
from __future__ import annotations
121+
from typing import cast
122+
cast(str | int, 'x') # E: Cast target is not a type
123+
[builtins fixtures/tuple.pyi]
124+
[typing fixtures/typing-full.pyi]
111125

112126
[case testUnionOrSyntaxInComment]
113127
# flags: --python-version 3.6
114128
x = 1 # type: int | str
115129

116-
117130
[case testUnionOrSyntaxFutureImport]
118131
# flags: --python-version 3.7
119132
from __future__ import annotations
120133
x: int | None
121134
[builtins fixtures/tuple.pyi]
122135

123-
124136
[case testUnionOrSyntaxMissingFutureImport]
125137
# flags: --python-version 3.9
126138
x: int | None # E: X | Y syntax for unions requires Python 3.10
127139

128-
129140
[case testUnionOrSyntaxInStubFile]
130141
# flags: --python-version 3.6
131142
from lib import x
132143
[file lib.pyi]
133144
x: int | None
145+
146+
[case testUnionOrSyntaxInMiscRuntimeContexts]
147+
# flags: --python-version 3.10
148+
from typing import cast
149+
150+
class C(list[int | None]):
151+
pass
152+
153+
def f() -> object: pass
154+
155+
reveal_type(cast(str | None, f())) # N: Revealed type is "Union[builtins.str, None]"
156+
reveal_type(list[str | None]()) # N: Revealed type is "builtins.list[Union[builtins.str, None]]"
157+
[builtins fixtures/type.pyi]

0 commit comments

Comments
 (0)