Skip to content

Commit 37777b3

Browse files
authored
Predict enum value type for unknown member names (#9443)
It is very common for enums to have homogenous member-value types. In the case where we do not know what enum member we are dealing with, we should sniff for that case and still collapse to a known type if that assumption holds. Handles auto() too, even if you override _get_next_value_.
1 parent b707d29 commit 37777b3

File tree

3 files changed

+164
-20
lines changed

3 files changed

+164
-20
lines changed

mypy/plugins/enums.py

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
we actually bake some of it directly in to the semantic analysis layer (see
1111
semanal_enum.py).
1212
"""
13-
from typing import Optional
13+
from typing import Iterable, Optional, TypeVar
1414
from typing_extensions import Final
1515

1616
import mypy.plugin # To avoid circular imports.
17-
from mypy.types import Type, Instance, LiteralType, get_proper_type
17+
from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type
1818

1919
# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
2020
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
@@ -53,6 +53,56 @@ def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
5353
return str_type.copy_modified(last_known_value=literal_type)
5454

5555

56+
_T = TypeVar('_T')
57+
58+
59+
def _first(it: Iterable[_T]) -> Optional[_T]:
60+
"""Return the first value from any iterable.
61+
62+
Returns ``None`` if the iterable is empty.
63+
"""
64+
for val in it:
65+
return val
66+
return None
67+
68+
69+
def _infer_value_type_with_auto_fallback(
70+
ctx: 'mypy.plugin.AttributeContext',
71+
proper_type: Optional[ProperType]) -> Optional[Type]:
72+
"""Figure out the type of an enum value accounting for `auto()`.
73+
74+
This method is a no-op for a `None` proper_type and also in the case where
75+
the type is not "enum.auto"
76+
"""
77+
if proper_type is None:
78+
return None
79+
if not ((isinstance(proper_type, Instance) and
80+
proper_type.type.fullname == 'enum.auto')):
81+
return proper_type
82+
assert isinstance(ctx.type, Instance), 'An incorrect ctx.type was passed.'
83+
info = ctx.type.type
84+
# Find the first _generate_next_value_ on the mro. We need to know
85+
# if it is `Enum` because `Enum` types say that the return-value of
86+
# `_generate_next_value_` is `Any`. In reality the default `auto()`
87+
# returns an `int` (presumably the `Any` in typeshed is to make it
88+
# easier to subclass and change the returned type).
89+
type_with_gnv = _first(
90+
ti for ti in info.mro if ti.names.get('_generate_next_value_'))
91+
if type_with_gnv is None:
92+
return ctx.default_attr_type
93+
94+
stnode = type_with_gnv.names['_generate_next_value_']
95+
96+
# This should be a `CallableType`
97+
node_type = get_proper_type(stnode.type)
98+
if isinstance(node_type, CallableType):
99+
if type_with_gnv.fullname == 'enum.Enum':
100+
int_type = ctx.api.named_generic_type('builtins.int', [])
101+
return int_type
102+
return get_proper_type(node_type.ret_type)
103+
return ctx.default_attr_type
104+
105+
56106
def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
57107
"""This plugin refines the 'value' attribute in enums to refer to
58108
the original underlying value. For example, suppose we have the
@@ -78,6 +128,32 @@ class SomeEnum:
78128
"""
79129
enum_field_name = _extract_underlying_field_name(ctx.type)
80130
if enum_field_name is None:
131+
# We do not know the enum field name (perhaps it was passed to a
132+
# function and we only know that it _is_ a member). All is not lost
133+
# however, if we can prove that the all of the enum members have the
134+
# same value-type, then it doesn't matter which member was passed in.
135+
# The value-type is still known.
136+
if isinstance(ctx.type, Instance):
137+
info = ctx.type.type
138+
stnodes = (info.get(name) for name in info.names)
139+
# Enums _can_ have methods.
140+
# Omit methods for our value inference.
141+
node_types = (
142+
get_proper_type(n.type) if n else None
143+
for n in stnodes)
144+
proper_types = (
145+
_infer_value_type_with_auto_fallback(ctx, t)
146+
for t in node_types
147+
if t is None or not isinstance(t, CallableType))
148+
underlying_type = _first(proper_types)
149+
if underlying_type is None:
150+
return ctx.default_attr_type
151+
all_same_value_type = all(
152+
proper_type is not None and proper_type == underlying_type
153+
for proper_type in proper_types)
154+
if all_same_value_type:
155+
if underlying_type is not None:
156+
return underlying_type
81157
return ctx.default_attr_type
82158

83159
assert isinstance(ctx.type, Instance)
@@ -86,15 +162,9 @@ class SomeEnum:
86162
if stnode is None:
87163
return ctx.default_attr_type
88164

89-
underlying_type = get_proper_type(stnode.type)
165+
underlying_type = _infer_value_type_with_auto_fallback(
166+
ctx, get_proper_type(stnode.type))
90167
if underlying_type is None:
91-
# TODO: Deduce the inferred type if the user omits adding their own default types.
92-
# TODO: Consider using the return type of `Enum._generate_next_value_` here?
93-
return ctx.default_attr_type
94-
95-
if isinstance(underlying_type, Instance) and underlying_type.type.fullname == 'enum.auto':
96-
# TODO: Deduce the correct inferred type when the user uses 'enum.auto'.
97-
# We should use the same strategy we end up picking up above.
98168
return ctx.default_attr_type
99169

100170
return underlying_type

test-data/unit/check-enum.test

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,76 @@ reveal_type(Truth.true.name) # N: Revealed type is 'Literal['true']?'
5959
reveal_type(Truth.false.value) # N: Revealed type is 'builtins.bool'
6060
[builtins fixtures/bool.pyi]
6161

62+
[case testEnumValueExtended]
63+
from enum import Enum
64+
class Truth(Enum):
65+
true = True
66+
false = False
67+
68+
def infer_truth(truth: Truth) -> None:
69+
reveal_type(truth.value) # N: Revealed type is 'builtins.bool'
70+
[builtins fixtures/bool.pyi]
71+
72+
[case testEnumValueAllAuto]
73+
from enum import Enum, auto
74+
class Truth(Enum):
75+
true = auto()
76+
false = auto()
77+
78+
def infer_truth(truth: Truth) -> None:
79+
reveal_type(truth.value) # N: Revealed type is 'builtins.int'
80+
[builtins fixtures/primitives.pyi]
81+
82+
[case testEnumValueSomeAuto]
83+
from enum import Enum, auto
84+
class Truth(Enum):
85+
true = 8675309
86+
false = auto()
87+
88+
def infer_truth(truth: Truth) -> None:
89+
reveal_type(truth.value) # N: Revealed type is 'builtins.int'
90+
[builtins fixtures/primitives.pyi]
91+
92+
[case testEnumValueExtraMethods]
93+
from enum import Enum, auto
94+
class Truth(Enum):
95+
true = True
96+
false = False
97+
98+
def foo(self) -> str:
99+
return 'bar'
100+
101+
def infer_truth(truth: Truth) -> None:
102+
reveal_type(truth.value) # N: Revealed type is 'builtins.bool'
103+
[builtins fixtures/bool.pyi]
104+
105+
[case testEnumValueCustomAuto]
106+
from enum import Enum, auto
107+
class AutoName(Enum):
108+
109+
# In `typeshed`, this is a staticmethod and has more arguments,
110+
# but I have lied a bit to keep the test stubs lean.
111+
def _generate_next_value_(self) -> str:
112+
return "name"
113+
114+
class Truth(AutoName):
115+
true = auto()
116+
false = auto()
117+
118+
def infer_truth(truth: Truth) -> None:
119+
reveal_type(truth.value) # N: Revealed type is 'builtins.str'
120+
[builtins fixtures/primitives.pyi]
121+
122+
[case testEnumValueInhomogenous]
123+
from enum import Enum
124+
class Truth(Enum):
125+
true = 'True'
126+
false = 0
127+
128+
def cannot_infer_truth(truth: Truth) -> None:
129+
reveal_type(truth.value) # N: Revealed type is 'Any'
130+
[builtins fixtures/bool.pyi]
131+
62132
[case testEnumUnique]
63133
import enum
64134
@enum.unique
@@ -497,8 +567,8 @@ reveal_type(A1.x.value) # N: Revealed type is 'Any'
497567
reveal_type(A1.x._value_) # N: Revealed type is 'Any'
498568
is_x(reveal_type(A2.x.name)) # N: Revealed type is 'Literal['x']'
499569
is_x(reveal_type(A2.x._name_)) # N: Revealed type is 'Literal['x']'
500-
reveal_type(A2.x.value) # N: Revealed type is 'Any'
501-
reveal_type(A2.x._value_) # N: Revealed type is 'Any'
570+
reveal_type(A2.x.value) # N: Revealed type is 'builtins.int'
571+
reveal_type(A2.x._value_) # N: Revealed type is 'builtins.int'
502572
is_x(reveal_type(A3.x.name)) # N: Revealed type is 'Literal['x']'
503573
is_x(reveal_type(A3.x._name_)) # N: Revealed type is 'Literal['x']'
504574
reveal_type(A3.x.value) # N: Revealed type is 'builtins.int'
@@ -519,7 +589,7 @@ reveal_type(B1.x._value_) # N: Revealed type is 'Any'
519589
is_x(reveal_type(B2.x.name)) # N: Revealed type is 'Literal['x']'
520590
is_x(reveal_type(B2.x._name_)) # N: Revealed type is 'Literal['x']'
521591
reveal_type(B2.x.value) # N: Revealed type is 'builtins.int'
522-
reveal_type(B2.x._value_) # N: Revealed type is 'Any'
592+
reveal_type(B2.x._value_) # N: Revealed type is 'builtins.int'
523593
is_x(reveal_type(B3.x.name)) # N: Revealed type is 'Literal['x']'
524594
is_x(reveal_type(B3.x._name_)) # N: Revealed type is 'Literal['x']'
525595
reveal_type(B3.x.value) # N: Revealed type is 'builtins.int'
@@ -540,8 +610,8 @@ reveal_type(C1.x.value) # N: Revealed type is 'Any'
540610
reveal_type(C1.x._value_) # N: Revealed type is 'Any'
541611
is_x(reveal_type(C2.x.name)) # N: Revealed type is 'Literal['x']'
542612
is_x(reveal_type(C2.x._name_)) # N: Revealed type is 'Literal['x']'
543-
reveal_type(C2.x.value) # N: Revealed type is 'Any'
544-
reveal_type(C2.x._value_) # N: Revealed type is 'Any'
613+
reveal_type(C2.x.value) # N: Revealed type is 'builtins.int'
614+
reveal_type(C2.x._value_) # N: Revealed type is 'builtins.int'
545615
is_x(reveal_type(C3.x.name)) # N: Revealed type is 'Literal['x']'
546616
is_x(reveal_type(C3.x._name_)) # N: Revealed type is 'Literal['x']'
547617
reveal_type(C3.x.value) # N: Revealed type is 'builtins.int'
@@ -559,8 +629,8 @@ reveal_type(D1.x.value) # N: Revealed type is 'Any'
559629
reveal_type(D1.x._value_) # N: Revealed type is 'Any'
560630
is_x(reveal_type(D2.x.name)) # N: Revealed type is 'Literal['x']'
561631
is_x(reveal_type(D2.x._name_)) # N: Revealed type is 'Literal['x']'
562-
reveal_type(D2.x.value) # N: Revealed type is 'Any'
563-
reveal_type(D2.x._value_) # N: Revealed type is 'Any'
632+
reveal_type(D2.x.value) # N: Revealed type is 'builtins.int'
633+
reveal_type(D2.x._value_) # N: Revealed type is 'builtins.int'
564634
is_x(reveal_type(D3.x.name)) # N: Revealed type is 'Literal['x']'
565635
is_x(reveal_type(D3.x._name_)) # N: Revealed type is 'Literal['x']'
566636
reveal_type(D3.x.value) # N: Revealed type is 'builtins.int'
@@ -578,8 +648,8 @@ class E3(Parent):
578648

579649
is_x(reveal_type(E2.x.name)) # N: Revealed type is 'Literal['x']'
580650
is_x(reveal_type(E2.x._name_)) # N: Revealed type is 'Literal['x']'
581-
reveal_type(E2.x.value) # N: Revealed type is 'Any'
582-
reveal_type(E2.x._value_) # N: Revealed type is 'Any'
651+
reveal_type(E2.x.value) # N: Revealed type is 'builtins.int'
652+
reveal_type(E2.x._value_) # N: Revealed type is 'builtins.int'
583653
is_x(reveal_type(E3.x.name)) # N: Revealed type is 'Literal['x']'
584654
is_x(reveal_type(E3.x._name_)) # N: Revealed type is 'Literal['x']'
585655
reveal_type(E3.x.value) # N: Revealed type is 'builtins.int'

test-data/unit/lib-stub/enum.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class Enum(metaclass=EnumMeta):
2121
_name_: str
2222
_value_: Any
2323

24+
# In reality, _generate_next_value_ is python3.6 only and has a different signature.
25+
# However, this should be quick and doesn't require additional stubs (e.g. `staticmethod`)
26+
def _generate_next_value_(self) -> Any: pass
27+
2428
class IntEnum(int, Enum):
2529
value: int
2630

@@ -37,4 +41,4 @@ class IntFlag(int, Flag):
3741

3842

3943
class auto(IntFlag):
40-
value: Any
44+
value: Any

0 commit comments

Comments
 (0)