Skip to content

Commit a90841f

Browse files
elazarggvanrossum
authored andcommitted
Sanity checks for declared selftype (#2381)
Fixes #2374.
1 parent 6918ccf commit a90841f

File tree

6 files changed

+171
-65
lines changed

6 files changed

+171
-65
lines changed

mypy/checker.py

+77-53
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import itertools
44
import fnmatch
5+
from contextlib import contextmanager
56

67
from typing import (
7-
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple
8+
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator
89
)
910

1011
from mypy.errors import Errors, report_internal_error
@@ -35,10 +36,11 @@
3536
from mypy.sametypes import is_same_type
3637
from mypy.messages import MessageBuilder
3738
import mypy.checkexpr
38-
from mypy.checkmember import map_type_from_supertype, bind_self
39+
from mypy.checkmember import map_type_from_supertype, bind_self, erase_to_bound
3940
from mypy import messages
4041
from mypy.subtypes import (
41-
is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away
42+
is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away,
43+
is_subtype_ignoring_tvars
4244
)
4345
from mypy.maptype import map_instance_to_supertype
4446
from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname
@@ -65,7 +67,7 @@
6567
[
6668
('node', FuncItem),
6769
('context_type_name', Optional[str]), # Name of the surrounding class (for error messages)
68-
('class_type', Optional[Type]), # And its type (from class_context)
70+
('active_class', Optional[Type]), # And its type (for selftype handline)
6971
])
7072

7173

@@ -91,19 +93,13 @@ class TypeChecker(NodeVisitor[Type]):
9193
# Helper for type checking expressions
9294
expr_checker = None # type: mypy.checkexpr.ExpressionChecker
9395

94-
# Class context for checking overriding of a method of the form
95-
# def foo(self: T) -> T
96-
# We need to pass the current class definition for instantiation of T
97-
class_context = None # type: List[Type]
98-
96+
scope = None # type: Scope
9997
# Stack of function return types
10098
return_types = None # type: List[Type]
10199
# Type context for type inference
102100
type_context = None # type: List[Type]
103101
# Flags; true for dynamically typed functions
104102
dynamic_funcs = None # type: List[bool]
105-
# Stack of functions being type checked
106-
function_stack = None # type: List[FuncItem]
107103
# Stack of collections of variables with partial types
108104
partial_types = None # type: List[Dict[Var, Context]]
109105
globals = None # type: SymbolTable
@@ -139,13 +135,12 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
139135
self.path = path
140136
self.msg = MessageBuilder(errors, modules)
141137
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg)
142-
self.class_context = []
138+
self.scope = Scope(tree)
143139
self.binder = ConditionalTypeBinder()
144140
self.globals = tree.names
145141
self.return_types = []
146142
self.type_context = []
147143
self.dynamic_funcs = []
148-
self.function_stack = []
149144
self.partial_types = []
150145
self.deferred_nodes = []
151146
self.type_map = {}
@@ -203,7 +198,7 @@ def check_second_pass(self) -> bool:
203198
todo = self.deferred_nodes
204199
self.deferred_nodes = []
205200
done = set() # type: Set[FuncItem]
206-
for node, type_name, class_type in todo:
201+
for node, type_name, active_class in todo:
207202
if node in done:
208203
continue
209204
# This is useful for debugging:
@@ -212,28 +207,27 @@ def check_second_pass(self) -> bool:
212207
done.add(node)
213208
if type_name:
214209
self.errors.push_type(type_name)
215-
if class_type:
216-
self.class_context.append(class_type)
217-
self.accept(node)
218-
if class_type:
219-
self.class_context.pop()
210+
211+
if active_class:
212+
with self.scope.push_class(active_class):
213+
self.accept(node)
214+
else:
215+
self.accept(node)
220216
if type_name:
221217
self.errors.pop_type()
222218
return True
223219

224220
def handle_cannot_determine_type(self, name: str, context: Context) -> None:
225-
if self.pass_num < LAST_PASS and self.function_stack:
221+
node = self.scope.top_function()
222+
if self.pass_num < LAST_PASS and node is not None:
226223
# Don't report an error yet. Just defer.
227-
node = self.function_stack[-1]
228224
if self.errors.type_name:
229225
type_name = self.errors.type_name[-1]
230226
else:
231227
type_name = None
232-
if self.class_context:
233-
class_context_top = self.class_context[-1]
234-
else:
235-
class_context_top = None
236-
self.deferred_nodes.append(DeferredNode(node, type_name, class_context_top))
228+
# Shouldn't we freeze the entire scope?
229+
active_class = self.scope.active_class()
230+
self.deferred_nodes.append(DeferredNode(node, type_name, active_class))
237231
# Set a marker so that we won't infer additional types in this
238232
# function. Any inferred types could be bogus, because there's at
239233
# least one type that we don't know.
@@ -510,7 +504,6 @@ def check_func_item(self, defn: FuncItem,
510504
if isinstance(defn, FuncDef):
511505
fdef = defn
512506

513-
self.function_stack.append(defn)
514507
self.dynamic_funcs.append(defn.is_dynamic() and not type_override)
515508

516509
if fdef:
@@ -532,7 +525,6 @@ def check_func_item(self, defn: FuncItem,
532525
self.errors.pop_function()
533526

534527
self.dynamic_funcs.pop()
535-
self.function_stack.pop()
536528
self.current_node_deferred = False
537529

538530
def check_func_def(self, defn: FuncItem, typ: CallableType, name: str) -> None:
@@ -618,14 +610,22 @@ def is_implicit_any(t: Type) -> bool:
618610
for i in range(len(typ.arg_types)):
619611
arg_type = typ.arg_types[i]
620612

621-
# Refuse covariant parameter type variables
622-
# TODO: check recuresively for inner type variables
623-
if isinstance(arg_type, TypeVarType):
624-
if i > 0:
625-
if arg_type.variance == COVARIANT:
626-
self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT,
627-
arg_type)
628-
# FIX: if i == 0 and this is not a method then same as above
613+
ref_type = self.scope.active_class()
614+
if (isinstance(defn, FuncDef) and ref_type is not None and i == 0
615+
and not defn.is_static
616+
and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2]):
617+
if defn.is_class or defn.name() == '__new__':
618+
ref_type = mypy.types.TypeType(ref_type)
619+
erased = erase_to_bound(arg_type)
620+
if not is_subtype_ignoring_tvars(ref_type, erased):
621+
self.fail("The erased type of self '{}' "
622+
"is not a supertype of its class '{}'"
623+
.format(erased, ref_type), defn)
624+
elif isinstance(arg_type, TypeVarType):
625+
# Refuse covariant parameter type variables
626+
# TODO: check recuresively for inner type variables
627+
if arg_type.variance == COVARIANT:
628+
self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, arg_type)
629629
if typ.arg_kinds[i] == nodes.ARG_STAR:
630630
# builtins.tuple[T] is typing.Tuple[T, ...]
631631
arg_type = self.named_generic_type('builtins.tuple',
@@ -644,7 +644,8 @@ def is_implicit_any(t: Type) -> bool:
644644

645645
# Type check body in a new scope.
646646
with self.binder.top_frame_context():
647-
self.accept(item.body)
647+
with self.scope.push_function(defn):
648+
self.accept(item.body)
648649
unreachable = self.binder.is_unreachable()
649650

650651
if (self.options.warn_no_return and not unreachable
@@ -890,7 +891,7 @@ def check_method_override_for_base_with_name(
890891
# The name of the method is defined in the base class.
891892

892893
# Construct the type of the overriding method.
893-
typ = bind_self(self.function_type(defn), self.class_context[-1])
894+
typ = bind_self(self.function_type(defn), self.scope.active_class())
894895
# Map the overridden method type to subtype context so that
895896
# it can be checked for compatibility.
896897
original_type = base_attr.type
@@ -903,7 +904,7 @@ def check_method_override_for_base_with_name(
903904
assert False, str(base_attr.node)
904905
if isinstance(original_type, FunctionLike):
905906
original = map_type_from_supertype(
906-
bind_self(original_type, self.class_context[-1]),
907+
bind_self(original_type, self.scope.active_class()),
907908
defn.info, base)
908909
# Check that the types are compatible.
909910
# TODO overloaded signatures
@@ -987,9 +988,8 @@ def visit_class_def(self, defn: ClassDef) -> Type:
987988
old_binder = self.binder
988989
self.binder = ConditionalTypeBinder()
989990
with self.binder.top_frame_context():
990-
self.class_context.append(fill_typevars(defn.info))
991-
self.accept(defn.defs)
992-
self.class_context.pop()
991+
with self.scope.push_class(fill_typevars(defn.info)):
992+
self.accept(defn.defs)
993993
self.binder = old_binder
994994
if not defn.has_incompatible_baseclass:
995995
# Otherwise we've already found errors; more errors are not useful
@@ -1528,8 +1528,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type:
15281528
return None
15291529

15301530
def check_return_stmt(self, s: ReturnStmt) -> None:
1531-
if self.is_within_function():
1532-
defn = self.function_stack[-1]
1531+
defn = self.scope.top_function()
1532+
if defn is not None:
15331533
if defn.is_generator:
15341534
return_type = self.get_generator_return_type(self.return_types[-1],
15351535
defn.is_coroutine)
@@ -1546,7 +1546,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
15461546
if self.is_unusable_type(return_type):
15471547
# Lambdas are allowed to have a unusable returns.
15481548
# Functions returning a value of type None are allowed to have a Void return.
1549-
if isinstance(self.function_stack[-1], FuncExpr) or isinstance(typ, NoneTyp):
1549+
if isinstance(self.scope.top_function(), FuncExpr) or isinstance(typ, NoneTyp):
15501550
return
15511551
self.fail(messages.NO_RETURN_VALUE_EXPECTED, s)
15521552
else:
@@ -1559,7 +1559,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
15591559
msg=messages.INCOMPATIBLE_RETURN_VALUE_TYPE)
15601560
else:
15611561
# Empty returns are valid in Generators with Any typed returns.
1562-
if (self.function_stack[-1].is_generator and isinstance(return_type, AnyType)):
1562+
if (defn.is_generator and isinstance(return_type, AnyType)):
15631563
return
15641564

15651565
if isinstance(return_type, (Void, NoneTyp, AnyType)):
@@ -2335,13 +2335,6 @@ def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]:
23352335
return partial_types
23362336
return None
23372337

2338-
def is_within_function(self) -> bool:
2339-
"""Are we currently type checking within a function?
2340-
2341-
I.e. not at class body or at the top level.
2342-
"""
2343-
return self.return_types != []
2344-
23452338
def is_unusable_type(self, typ: Type):
23462339
"""Is this type an unusable type?
23472340
@@ -2773,3 +2766,34 @@ def is_valid_inferred_type_component(typ: Type) -> bool:
27732766
if not is_valid_inferred_type_component(item):
27742767
return False
27752768
return True
2769+
2770+
2771+
class Scope:
2772+
# We keep two stacks combined, to maintain the relative order
2773+
stack = None # type: List[Union[Type, FuncItem, MypyFile]]
2774+
2775+
def __init__(self, module: MypyFile) -> None:
2776+
self.stack = [module]
2777+
2778+
def top_function(self) -> Optional[FuncItem]:
2779+
for e in reversed(self.stack):
2780+
if isinstance(e, FuncItem):
2781+
return e
2782+
return None
2783+
2784+
def active_class(self) -> Optional[Type]:
2785+
if isinstance(self.stack[-1], Type):
2786+
return self.stack[-1]
2787+
return None
2788+
2789+
@contextmanager
2790+
def push_function(self, item: FuncItem) -> Iterator[None]:
2791+
self.stack.append(item)
2792+
yield
2793+
self.stack.pop()
2794+
2795+
@contextmanager
2796+
def push_class(self, t: Type) -> Iterator[None]:
2797+
self.stack.append(t)
2798+
yield
2799+
self.stack.pop()

mypy/checkexpr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1604,7 +1604,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
16041604
return AnyType()
16051605
if not self.chk.in_checked_function():
16061606
return AnyType()
1607-
args = self.chk.function_stack[-1].arguments
1607+
args = self.chk.scope.top_function().arguments
16081608
# An empty args with super() is an error; we need something in declared_self
16091609
if not args:
16101610
self.chk.fail('super() requires at least one positional argument', e)

mypy/checkmember.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -614,4 +614,4 @@ def erase_to_bound(t: Type):
614614
if isinstance(t, TypeType):
615615
if isinstance(t.item, TypeVarType):
616616
return TypeType(t.item.upper_bound)
617-
assert not t
617+
return t

mypy/semanal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def prepare_method_signature(self, func: FuncDef) -> None:
329329
elif isinstance(functype, CallableType):
330330
self_type = functype.arg_types[0]
331331
if isinstance(self_type, AnyType):
332-
if func.is_class:
332+
if func.is_class or func.name() == '__new__':
333333
leading_type = self.class_type(self.type)
334334
else:
335335
leading_type = fill_typevars(self.type)

test-data/unit/check-selftype.test

+85
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,88 @@ class A:
248248
class B(A):
249249
def __init__(self, arg: T) -> None:
250250
super(B, self).__init__()
251+
252+
[case testSelfTypeNonsensical]
253+
# flags: --hide-error-context
254+
from typing import TypeVar, Type
255+
256+
T = TypeVar('T', bound=str)
257+
class A:
258+
def foo(self: T) -> T: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.A'
259+
return self
260+
261+
@classmethod
262+
def cfoo(cls: Type[T]) -> T: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.A]'
263+
return cls()
264+
265+
Q = TypeVar('Q', bound='B')
266+
class B:
267+
def foo(self: Q) -> Q:
268+
return self
269+
270+
@classmethod
271+
def cfoo(cls: Type[Q]) -> Q:
272+
return cls()
273+
274+
class C:
275+
def foo(self: C) -> C: return self
276+
277+
@classmethod
278+
def cfoo(cls: Type[C]) -> C:
279+
return cls()
280+
281+
class D:
282+
def foo(self: str) -> str: # E: The erased type of self 'builtins.str' is not a supertype of its class '__main__.D'
283+
return self
284+
285+
@staticmethod
286+
def bar(self: str) -> str:
287+
return self
288+
289+
@classmethod
290+
def cfoo(cls: Type[str]) -> str: # E: The erased type of self 'Type[builtins.str]' is not a supertype of its class 'Type[__main__.D]'
291+
return cls()
292+
293+
[builtins fixtures/classmethod.pyi]
294+
295+
[case testSelfTypeLambdaDefault]
296+
# flags: --hide-error-context
297+
from typing import Callable
298+
class C:
299+
@classmethod
300+
def foo(cls,
301+
arg: Callable[[int], str] = lambda a: ''
302+
) -> None:
303+
pass
304+
305+
def bar(self,
306+
arg: Callable[[int], str] = lambda a: ''
307+
) -> None:
308+
pass
309+
[builtins fixtures/classmethod.pyi]
310+
311+
[case testSelfTypeNew]
312+
# flags: --hide-error-context
313+
from typing import TypeVar, Type
314+
315+
T = TypeVar('T', bound=A)
316+
class A:
317+
def __new__(cls: Type[T]) -> T:
318+
return cls()
319+
320+
class B:
321+
def __new__(cls: Type[T]) -> T: # E: The erased type of self 'Type[__main__.A]' is not a supertype of its class 'Type[__main__.B]'
322+
return cls()
323+
324+
class C:
325+
def __new__(cls: Type[C]) -> C:
326+
return cls()
327+
328+
class D:
329+
def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
330+
return cls
331+
332+
class E:
333+
def __new__(cls) -> E:
334+
reveal_type(cls) # E: Revealed type is 'def () -> __main__.E'
335+
return cls()

0 commit comments

Comments
 (0)