2
2
3
3
import itertools
4
4
import fnmatch
5
+ from contextlib import contextmanager
5
6
6
7
from typing import (
7
- Dict , Set , List , cast , Tuple , TypeVar , Union , Optional , NamedTuple
8
+ Dict , Set , List , cast , Tuple , TypeVar , Union , Optional , NamedTuple , Iterator
8
9
)
9
10
10
11
from mypy .errors import Errors , report_internal_error
35
36
from mypy .sametypes import is_same_type
36
37
from mypy .messages import MessageBuilder
37
38
import mypy .checkexpr
38
- from mypy .checkmember import map_type_from_supertype , bind_self
39
+ from mypy .checkmember import map_type_from_supertype , bind_self , erase_to_bound
39
40
from mypy import messages
40
41
from mypy .subtypes import (
41
- is_subtype , is_equivalent , is_proper_subtype , is_more_precise , restrict_subtype_away
42
+ is_subtype , is_equivalent , is_proper_subtype , is_more_precise , restrict_subtype_away ,
43
+ is_subtype_ignoring_tvars
42
44
)
43
45
from mypy .maptype import map_instance_to_supertype
44
46
from mypy .semanal import fill_typevars , set_callable_name , refers_to_fullname
65
67
[
66
68
('node' , FuncItem ),
67
69
('context_type_name' , Optional [str ]), # Name of the surrounding class (for error messages)
68
- ('class_type ' , Optional [Type ]), # And its type (from class_context )
70
+ ('active_class ' , Optional [Type ]), # And its type (for selftype handline )
69
71
])
70
72
71
73
@@ -91,19 +93,13 @@ class TypeChecker(NodeVisitor[Type]):
91
93
# Helper for type checking expressions
92
94
expr_checker = None # type: mypy.checkexpr.ExpressionChecker
93
95
94
- # Class context for checking overriding of a method of the form
95
- # def foo(self: T) -> T
96
- # We need to pass the current class definition for instantiation of T
97
- class_context = None # type: List[Type]
98
-
96
+ scope = None # type: Scope
99
97
# Stack of function return types
100
98
return_types = None # type: List[Type]
101
99
# Type context for type inference
102
100
type_context = None # type: List[Type]
103
101
# Flags; true for dynamically typed functions
104
102
dynamic_funcs = None # type: List[bool]
105
- # Stack of functions being type checked
106
- function_stack = None # type: List[FuncItem]
107
103
# Stack of collections of variables with partial types
108
104
partial_types = None # type: List[Dict[Var, Context]]
109
105
globals = None # type: SymbolTable
@@ -139,13 +135,12 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
139
135
self .path = path
140
136
self .msg = MessageBuilder (errors , modules )
141
137
self .expr_checker = mypy .checkexpr .ExpressionChecker (self , self .msg )
142
- self .class_context = []
138
+ self .scope = Scope ( tree )
143
139
self .binder = ConditionalTypeBinder ()
144
140
self .globals = tree .names
145
141
self .return_types = []
146
142
self .type_context = []
147
143
self .dynamic_funcs = []
148
- self .function_stack = []
149
144
self .partial_types = []
150
145
self .deferred_nodes = []
151
146
self .type_map = {}
@@ -203,7 +198,7 @@ def check_second_pass(self) -> bool:
203
198
todo = self .deferred_nodes
204
199
self .deferred_nodes = []
205
200
done = set () # type: Set[FuncItem]
206
- for node , type_name , class_type in todo :
201
+ for node , type_name , active_class in todo :
207
202
if node in done :
208
203
continue
209
204
# This is useful for debugging:
@@ -212,28 +207,27 @@ def check_second_pass(self) -> bool:
212
207
done .add (node )
213
208
if type_name :
214
209
self .errors .push_type (type_name )
215
- if class_type :
216
- self .class_context .append (class_type )
217
- self .accept (node )
218
- if class_type :
219
- self .class_context .pop ()
210
+
211
+ if active_class :
212
+ with self .scope .push_class (active_class ):
213
+ self .accept (node )
214
+ else :
215
+ self .accept (node )
220
216
if type_name :
221
217
self .errors .pop_type ()
222
218
return True
223
219
224
220
def handle_cannot_determine_type (self , name : str , context : Context ) -> None :
225
- if self .pass_num < LAST_PASS and self .function_stack :
221
+ node = self .scope .top_function ()
222
+ if self .pass_num < LAST_PASS and node is not None :
226
223
# Don't report an error yet. Just defer.
227
- node = self .function_stack [- 1 ]
228
224
if self .errors .type_name :
229
225
type_name = self .errors .type_name [- 1 ]
230
226
else :
231
227
type_name = None
232
- if self .class_context :
233
- class_context_top = self .class_context [- 1 ]
234
- else :
235
- class_context_top = None
236
- self .deferred_nodes .append (DeferredNode (node , type_name , class_context_top ))
228
+ # Shouldn't we freeze the entire scope?
229
+ active_class = self .scope .active_class ()
230
+ self .deferred_nodes .append (DeferredNode (node , type_name , active_class ))
237
231
# Set a marker so that we won't infer additional types in this
238
232
# function. Any inferred types could be bogus, because there's at
239
233
# least one type that we don't know.
@@ -510,7 +504,6 @@ def check_func_item(self, defn: FuncItem,
510
504
if isinstance (defn , FuncDef ):
511
505
fdef = defn
512
506
513
- self .function_stack .append (defn )
514
507
self .dynamic_funcs .append (defn .is_dynamic () and not type_override )
515
508
516
509
if fdef :
@@ -532,7 +525,6 @@ def check_func_item(self, defn: FuncItem,
532
525
self .errors .pop_function ()
533
526
534
527
self .dynamic_funcs .pop ()
535
- self .function_stack .pop ()
536
528
self .current_node_deferred = False
537
529
538
530
def check_func_def (self , defn : FuncItem , typ : CallableType , name : str ) -> None :
@@ -618,14 +610,22 @@ def is_implicit_any(t: Type) -> bool:
618
610
for i in range (len (typ .arg_types )):
619
611
arg_type = typ .arg_types [i ]
620
612
621
- # Refuse covariant parameter type variables
622
- # TODO: check recuresively for inner type variables
623
- if isinstance (arg_type , TypeVarType ):
624
- if i > 0 :
625
- if arg_type .variance == COVARIANT :
626
- self .fail (messages .FUNCTION_PARAMETER_CANNOT_BE_COVARIANT ,
627
- arg_type )
628
- # FIX: if i == 0 and this is not a method then same as above
613
+ ref_type = self .scope .active_class ()
614
+ if (isinstance (defn , FuncDef ) and ref_type is not None and i == 0
615
+ and not defn .is_static
616
+ and typ .arg_kinds [0 ] not in [nodes .ARG_STAR , nodes .ARG_STAR2 ]):
617
+ if defn .is_class or defn .name () == '__new__' :
618
+ ref_type = mypy .types .TypeType (ref_type )
619
+ erased = erase_to_bound (arg_type )
620
+ if not is_subtype_ignoring_tvars (ref_type , erased ):
621
+ self .fail ("The erased type of self '{}' "
622
+ "is not a supertype of its class '{}'"
623
+ .format (erased , ref_type ), defn )
624
+ elif isinstance (arg_type , TypeVarType ):
625
+ # Refuse covariant parameter type variables
626
+ # TODO: check recuresively for inner type variables
627
+ if arg_type .variance == COVARIANT :
628
+ self .fail (messages .FUNCTION_PARAMETER_CANNOT_BE_COVARIANT , arg_type )
629
629
if typ .arg_kinds [i ] == nodes .ARG_STAR :
630
630
# builtins.tuple[T] is typing.Tuple[T, ...]
631
631
arg_type = self .named_generic_type ('builtins.tuple' ,
@@ -644,7 +644,8 @@ def is_implicit_any(t: Type) -> bool:
644
644
645
645
# Type check body in a new scope.
646
646
with self .binder .top_frame_context ():
647
- self .accept (item .body )
647
+ with self .scope .push_function (defn ):
648
+ self .accept (item .body )
648
649
unreachable = self .binder .is_unreachable ()
649
650
650
651
if (self .options .warn_no_return and not unreachable
@@ -890,7 +891,7 @@ def check_method_override_for_base_with_name(
890
891
# The name of the method is defined in the base class.
891
892
892
893
# Construct the type of the overriding method.
893
- typ = bind_self (self .function_type (defn ), self .class_context [ - 1 ] )
894
+ typ = bind_self (self .function_type (defn ), self .scope . active_class () )
894
895
# Map the overridden method type to subtype context so that
895
896
# it can be checked for compatibility.
896
897
original_type = base_attr .type
@@ -903,7 +904,7 @@ def check_method_override_for_base_with_name(
903
904
assert False , str (base_attr .node )
904
905
if isinstance (original_type , FunctionLike ):
905
906
original = map_type_from_supertype (
906
- bind_self (original_type , self .class_context [ - 1 ] ),
907
+ bind_self (original_type , self .scope . active_class () ),
907
908
defn .info , base )
908
909
# Check that the types are compatible.
909
910
# TODO overloaded signatures
@@ -987,9 +988,8 @@ def visit_class_def(self, defn: ClassDef) -> Type:
987
988
old_binder = self .binder
988
989
self .binder = ConditionalTypeBinder ()
989
990
with self .binder .top_frame_context ():
990
- self .class_context .append (fill_typevars (defn .info ))
991
- self .accept (defn .defs )
992
- self .class_context .pop ()
991
+ with self .scope .push_class (fill_typevars (defn .info )):
992
+ self .accept (defn .defs )
993
993
self .binder = old_binder
994
994
if not defn .has_incompatible_baseclass :
995
995
# Otherwise we've already found errors; more errors are not useful
@@ -1528,8 +1528,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type:
1528
1528
return None
1529
1529
1530
1530
def check_return_stmt (self , s : ReturnStmt ) -> None :
1531
- if self .is_within_function ():
1532
- defn = self . function_stack [ - 1 ]
1531
+ defn = self .scope . top_function ()
1532
+ if defn is not None :
1533
1533
if defn .is_generator :
1534
1534
return_type = self .get_generator_return_type (self .return_types [- 1 ],
1535
1535
defn .is_coroutine )
@@ -1546,7 +1546,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
1546
1546
if self .is_unusable_type (return_type ):
1547
1547
# Lambdas are allowed to have a unusable returns.
1548
1548
# Functions returning a value of type None are allowed to have a Void return.
1549
- if isinstance (self .function_stack [ - 1 ] , FuncExpr ) or isinstance (typ , NoneTyp ):
1549
+ if isinstance (self .scope . top_function () , FuncExpr ) or isinstance (typ , NoneTyp ):
1550
1550
return
1551
1551
self .fail (messages .NO_RETURN_VALUE_EXPECTED , s )
1552
1552
else :
@@ -1559,7 +1559,7 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
1559
1559
msg = messages .INCOMPATIBLE_RETURN_VALUE_TYPE )
1560
1560
else :
1561
1561
# Empty returns are valid in Generators with Any typed returns.
1562
- if (self . function_stack [ - 1 ] .is_generator and isinstance (return_type , AnyType )):
1562
+ if (defn .is_generator and isinstance (return_type , AnyType )):
1563
1563
return
1564
1564
1565
1565
if isinstance (return_type , (Void , NoneTyp , AnyType )):
@@ -2335,13 +2335,6 @@ def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]:
2335
2335
return partial_types
2336
2336
return None
2337
2337
2338
- def is_within_function (self ) -> bool :
2339
- """Are we currently type checking within a function?
2340
-
2341
- I.e. not at class body or at the top level.
2342
- """
2343
- return self .return_types != []
2344
-
2345
2338
def is_unusable_type (self , typ : Type ):
2346
2339
"""Is this type an unusable type?
2347
2340
@@ -2773,3 +2766,34 @@ def is_valid_inferred_type_component(typ: Type) -> bool:
2773
2766
if not is_valid_inferred_type_component (item ):
2774
2767
return False
2775
2768
return True
2769
+
2770
+
2771
+ class Scope :
2772
+ # We keep two stacks combined, to maintain the relative order
2773
+ stack = None # type: List[Union[Type, FuncItem, MypyFile]]
2774
+
2775
+ def __init__ (self , module : MypyFile ) -> None :
2776
+ self .stack = [module ]
2777
+
2778
+ def top_function (self ) -> Optional [FuncItem ]:
2779
+ for e in reversed (self .stack ):
2780
+ if isinstance (e , FuncItem ):
2781
+ return e
2782
+ return None
2783
+
2784
+ def active_class (self ) -> Optional [Type ]:
2785
+ if isinstance (self .stack [- 1 ], Type ):
2786
+ return self .stack [- 1 ]
2787
+ return None
2788
+
2789
+ @contextmanager
2790
+ def push_function (self , item : FuncItem ) -> Iterator [None ]:
2791
+ self .stack .append (item )
2792
+ yield
2793
+ self .stack .pop ()
2794
+
2795
+ @contextmanager
2796
+ def push_class (self , t : Type ) -> Iterator [None ]:
2797
+ self .stack .append (t )
2798
+ yield
2799
+ self .stack .pop ()
0 commit comments