Skip to content

Commit ae48580

Browse files
committed
Make classmethod's first argument be Type[...]
Currently the first argument to `__new__` and classmethods is a callable type that is constructed during semantic analysis by typechecker code (!) that looks for the `__init__`/`__new__` methods. This causes a number of problems, including not being able to call `object.__new__` in a subclass's `__new__` if it took arguments (#4190) and giving the wrong type if `__init__` appeared after the class method (#1727). Taking a `Type` instead lets us solve those problems, and postpone computing the callable version of the type until typechecking if it is needed. This also lets us drop a bunch of plugin code that tries to fix up the types of its cls arguments post-hoc, sometimes incorrectly (#5263). Fixes #1727. Fixes #4190. Fixes #5263.
1 parent a88afb2 commit ae48580

15 files changed

+99
-82
lines changed

mypy/checker.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,8 +867,6 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])
867867
self.fail(msg, defn)
868868
if note:
869869
self.note(note, defn)
870-
if defn.is_class and isinstance(arg_type, CallableType):
871-
arg_type.is_classmethod_class = True
872870
elif isinstance(arg_type, TypeVarType):
873871
# Refuse covariant parameter type variables
874872
# TODO: check recursively for inner type variables

mypy/checkexpr.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -602,16 +602,16 @@ def check_call(self, callee: Type, args: List[Expression],
602602
return callee.ret_type, callee
603603

604604
if (callee.is_type_obj() and callee.type_object().is_abstract
605-
# Exceptions for Type[...] and classmethod first argument
606-
and not callee.from_type_type and not callee.is_classmethod_class
605+
# Exception for Type[...]
606+
and not callee.from_type_type
607607
and not callee.type_object().fallback_to_any):
608608
type = callee.type_object()
609609
self.msg.cannot_instantiate_abstract_class(
610610
callee.type_object().name(), type.abstract_attributes,
611611
context)
612612
elif (callee.is_type_obj() and callee.type_object().is_protocol
613-
# Exceptions for Type[...] and classmethod first argument
614-
and not callee.from_type_type and not callee.is_classmethod_class):
613+
# Exception for Type[...]
614+
and not callee.from_type_type):
615615
self.chk.fail('Cannot instantiate protocol class "{}"'
616616
.format(callee.type_object().name()), context)
617617

@@ -737,6 +737,9 @@ def analyze_type_type_callee(self, item: Type, context: Context) -> Type:
737737
for c in callee.items()])
738738
if callee:
739739
return callee
740+
# We support Type of namedtuples but not of tuples in general
741+
if isinstance(item, TupleType) and item.fallback.type.fullname() != 'builtins.tuple':
742+
return self.analyze_type_type_callee(item.fallback, context)
740743

741744
self.msg.unsupported_type_type(item, context)
742745
return AnyType(TypeOfAny.from_error)
@@ -1133,9 +1136,7 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
11331136
caller_type.is_type_obj() and
11341137
(caller_type.type_object().is_abstract or caller_type.type_object().is_protocol) and
11351138
isinstance(callee_type.item, Instance) and
1136-
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
1137-
# ...except for classmethod first argument
1138-
not caller_type.is_classmethod_class):
1139+
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol)):
11391140
self.msg.concrete_only_call(callee_type, context)
11401141
elif not is_subtype(caller_type, callee_type):
11411142
if self.chk.should_suppress_optional_error([caller_type, callee_type]):

mypy/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mypy.tvar_scope import TypeVarScope
1515
from mypy.types import (
1616
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, TypeVarType,
17-
AnyType, TypeList, UnboundType, TypeOfAny
17+
AnyType, TypeList, UnboundType, TypeOfAny, TypeType,
1818
)
1919
from mypy.messages import MessageBuilder
2020
from mypy.options import Options
@@ -93,7 +93,7 @@ def anal_type(self, t: Type, *,
9393
raise NotImplementedError
9494

9595
@abstractmethod
96-
def class_type(self, info: TypeInfo) -> Type:
96+
def class_type(self, self_type: Type) -> Type:
9797
raise NotImplementedError
9898

9999
@abstractmethod

mypy/plugins/attrs.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -495,23 +495,6 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute],
495495
[attribute.argument(ctx) for attribute in attributes if attribute.init],
496496
NoneTyp()
497497
)
498-
for stmt in ctx.cls.defs.body:
499-
# The type of classmethods will be wrong because it's based on the parent's __init__.
500-
# Set it correctly.
501-
if isinstance(stmt, Decorator) and stmt.func.is_class:
502-
func_type = stmt.func.type
503-
if isinstance(func_type, CallableType):
504-
func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info)
505-
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
506-
func_type = stmt.type
507-
if isinstance(func_type, Overloaded):
508-
class_type = ctx.api.class_type(ctx.cls.info)
509-
for item in func_type.items():
510-
item.arg_types[0] = class_type
511-
if stmt.impl is not None:
512-
assert isinstance(stmt.impl, Decorator)
513-
if isinstance(stmt.impl.func.type, CallableType):
514-
stmt.impl.func.type.arg_types[0] = class_type
515498

516499

517500
class MethodAdder:

mypy/plugins/dataclasses.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,6 @@ def transform(self) -> None:
9292
args=[attr.to_argument(info) for attr in attributes if attr.is_in_init],
9393
return_type=NoneTyp(),
9494
)
95-
for stmt in self._ctx.cls.defs.body:
96-
# Fix up the types of classmethods since, by default,
97-
# they will be based on the parent class' init.
98-
if isinstance(stmt, Decorator) and stmt.func.is_class:
99-
func_type = stmt.func.type
100-
if isinstance(func_type, CallableType):
101-
func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info)
102-
if isinstance(stmt, OverloadedFuncDef) and stmt.is_class:
103-
func_type = stmt.type
104-
if isinstance(func_type, Overloaded):
105-
class_type = ctx.api.class_type(ctx.cls.info)
106-
for item in func_type.items():
107-
item.arg_types[0] = class_type
108-
if stmt.impl is not None:
109-
assert isinstance(stmt.impl, Decorator)
110-
if isinstance(stmt.impl.func.type, CallableType):
111-
stmt.impl.func.type.arg_types[0] = class_type
11295

11396
# Add an eq method, but only if the class doesn't already have one.
11497
if decorator_arguments['eq'] and info.get('__eq__') is None:

mypy/semanal.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from mypy.types import (
6969
FunctionLike, UnboundType, TypeVarDef, TupleType, UnionType, StarType, function_type,
7070
CallableType, Overloaded, Instance, Type, AnyType,
71-
TypeTranslator, TypeOfAny
71+
TypeTranslator, TypeOfAny, TypeType,
7272
)
7373
from mypy.nodes import implicit_module_attrs
7474
from mypy.typeanal import (
@@ -479,10 +479,9 @@ def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
479479
elif isinstance(functype, CallableType):
480480
self_type = functype.arg_types[0]
481481
if isinstance(self_type, AnyType):
482+
leading_type = fill_typevars(info) # type: Type
482483
if func.is_class or func.name() in ('__new__', '__init_subclass__'):
483-
leading_type = self.class_type(info)
484-
else:
485-
leading_type = fill_typevars(info)
484+
leading_type = self.class_type(leading_type)
486485
func.type = replace_implicit_first_type(functype, leading_type)
487486

488487
def set_original_def(self, previous: Optional[Node], new: FuncDef) -> bool:
@@ -775,8 +774,6 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]:
775774
# that were already set in build_namedtuple_typeinfo.
776775
nt_names = named_tuple_info.names
777776
named_tuple_info.names = SymbolTable()
778-
# This is needed for the cls argument to classmethods to get bound correctly.
779-
named_tuple_info.names['__new__'] = nt_names['__new__']
780777

781778
self.enter_class(named_tuple_info)
782779

@@ -1343,15 +1340,8 @@ def object_type(self) -> Instance:
13431340
def str_type(self) -> Instance:
13441341
return self.named_type('__builtins__.str')
13451342

1346-
def class_type(self, info: TypeInfo) -> Type:
1347-
# Construct a function type whose fallback is cls.
1348-
from mypy import checkmember # To avoid import cycle.
1349-
leading_type = checkmember.type_object_type(info, self.builtin_type)
1350-
if isinstance(leading_type, Overloaded):
1351-
# Overloaded __init__ is too complex to handle. Plus it's stubs only.
1352-
return AnyType(TypeOfAny.special_form)
1353-
else:
1354-
return leading_type
1343+
def class_type(self, self_type: Type) -> Type:
1344+
return TypeType.make_normalized(self_type)
13551345

13561346
def named_type(self, qualified_name: str, args: Optional[List[Type]] = None) -> Instance:
13571347
sym = self.lookup_qualified(qualified_name, Context())

mypy/types.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,6 @@ def __init__(self,
706706
column: int = -1,
707707
is_ellipsis_args: bool = False,
708708
implicit: bool = False,
709-
is_classmethod_class: bool = False,
710709
special_sig: Optional[str] = None,
711710
from_type_type: bool = False,
712711
bound_args: Sequence[Optional[Type]] = (),
@@ -730,7 +729,6 @@ def __init__(self,
730729
self.variables = variables
731730
self.is_ellipsis_args = is_ellipsis_args
732731
self.implicit = implicit
733-
self.is_classmethod_class = is_classmethod_class
734732
self.special_sig = special_sig
735733
self.from_type_type = from_type_type
736734
if not bound_args:
@@ -780,7 +778,6 @@ def copy_modified(self,
780778
is_ellipsis_args=(
781779
is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args),
782780
implicit=implicit if implicit is not _dummy else self.implicit,
783-
is_classmethod_class=self.is_classmethod_class,
784781
special_sig=special_sig if special_sig is not _dummy else self.special_sig,
785782
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
786783
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
@@ -816,9 +813,6 @@ def is_kw_arg(self) -> bool:
816813
def is_type_obj(self) -> bool:
817814
return self.fallback.type.is_metaclass()
818815

819-
def is_concrete_type_obj(self) -> bool:
820-
return self.is_type_obj() and self.is_classmethod_class
821-
822816
def type_object(self) -> mypy.nodes.TypeInfo:
823817
assert self.is_type_obj()
824818
ret = self.ret_type
@@ -990,7 +984,6 @@ def serialize(self) -> JsonDict:
990984
'variables': [v.serialize() for v in self.variables],
991985
'is_ellipsis_args': self.is_ellipsis_args,
992986
'implicit': self.implicit,
993-
'is_classmethod_class': self.is_classmethod_class,
994987
'bound_args': [(None if t is None else t.serialize())
995988
for t in self.bound_args],
996989
'def_extras': dict(self.def_extras),
@@ -1009,7 +1002,6 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
10091002
variables=[TypeVarDef.deserialize(v) for v in data['variables']],
10101003
is_ellipsis_args=data['is_ellipsis_args'],
10111004
implicit=data['implicit'],
1012-
is_classmethod_class=data['is_classmethod_class'],
10131005
bound_args=[(None if t is None else deserialize_type(t))
10141006
for t in data['bound_args']],
10151007
def_extras=data['def_extras']

test-data/unit/check-attr.test

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,19 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"
367367

368368
[builtins fixtures/list.pyi]
369369

370+
[case testAttrsGenericClassmethod]
371+
from typing import TypeVar, Generic, Optional
372+
import attr
373+
T = TypeVar('T')
374+
@attr.s(auto_attribs=True)
375+
class A(Generic[T]):
376+
x: Optional[T]
377+
@classmethod
378+
def clsmeth(cls) -> None:
379+
reveal_type(cls) # E: Revealed type is 'Type[__main__.A[T`1]]'
380+
381+
[builtins fixtures/classmethod.pyi]
382+
370383
[case testAttrsForwardReference]
371384
import attr
372385
@attr.s(auto_attribs=True)
@@ -416,7 +429,7 @@ class A:
416429
b: str = attr.ib()
417430
@classmethod
418431
def new(cls) -> A:
419-
reveal_type(cls) # E: Revealed type is 'def (a: builtins.int, b: builtins.str) -> __main__.A'
432+
reveal_type(cls) # E: Revealed type is 'Type[__main__.A]'
420433
return cls(6, 'hello')
421434
@classmethod
422435
def bad(cls) -> A:
@@ -451,7 +464,7 @@ class A:
451464

452465
@classmethod
453466
def foo(cls, x: Union[int, str]) -> Union[int, str]:
454-
reveal_type(cls) # E: Revealed type is 'def (a: Any, b: Any =) -> __main__.A'
467+
reveal_type(cls) # E: Revealed type is 'Type[__main__.A]'
455468
reveal_type(cls.other()) # E: Revealed type is 'builtins.str'
456469
return x
457470

test-data/unit/check-class-namedtuple.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def f(a: Type[N]):
388388
a()
389389
[builtins fixtures/list.pyi]
390390
[out]
391-
main:8: error: Unsupported type Type["N"]
391+
main:9: error: Too few arguments for "N"
392392

393393
[case testNewNamedTupleWithDefaults]
394394
# flags: --fast-parser --python-version 3.6
@@ -587,7 +587,7 @@ class XMethBad(NamedTuple):
587587
class MagicalFields(NamedTuple):
588588
x: int
589589
def __slots__(self) -> None: pass # E: Cannot overwrite NamedTuple attribute "__slots__"
590-
def __new__(cls) -> None: pass # E: Name '__new__' already defined (possibly by an import)
590+
def __new__(cls) -> None: pass # E: Cannot overwrite NamedTuple attribute "__new__"
591591
def _source(self) -> int: pass # E: Cannot overwrite NamedTuple attribute "_source"
592592
__annotations__ = {'x': float} # E: NamedTuple field name cannot start with an underscore: __annotations__ \
593593
# E: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]" \
@@ -640,7 +640,7 @@ class HasClassMethod(NamedTuple):
640640

641641
@classmethod
642642
def new(cls, f: str) -> 'HasClassMethod':
643-
reveal_type(cls) # E: Revealed type is 'def (x: builtins.str) -> Tuple[builtins.str, fallback=__main__.HasClassMethod]'
643+
reveal_type(cls) # E: Revealed type is 'Type[Tuple[builtins.str, fallback=__main__.HasClassMethod]]'
644644
reveal_type(HasClassMethod) # E: Revealed type is 'def (x: builtins.str) -> Tuple[builtins.str, fallback=__main__.HasClassMethod]'
645645
return cls(x=f)
646646

test-data/unit/check-classes.test

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,19 @@ class B(A):
339339
def __new__(cls) -> int:
340340
return 1
341341

342+
[case testOverride__new__AndCallObject]
343+
from typing import TypeVar, Generic
344+
345+
class A:
346+
def __new__(cls, x: int) -> 'A':
347+
return object.__new__(cls)
348+
349+
T = TypeVar('T')
350+
class B(Generic[T]):
351+
def __new__(cls, foo: T) -> 'B[T]':
352+
return object.__new__(cls)
353+
[builtins fixtures/__new__.pyi]
354+
342355
[case testInnerFunctionNotOverriding]
343356
class A:
344357
def f(self) -> int: pass
@@ -1828,7 +1841,7 @@ class Num1:
18281841

18291842
class Num2(Num1):
18301843
@overload
1831-
def __add__(self, other: Num2) -> Num2: ...
1844+
def __add__(self, other: Num2) -> Num2: ...
18321845
@overload
18331846
def __add__(self, other: Num1) -> Num2: ...
18341847
def __add__(self, other): pass
@@ -3001,7 +3014,7 @@ def f(a: Type[N]):
30013014
a()
30023015
[builtins fixtures/list.pyi]
30033016
[out]
3004-
main:3: error: Unsupported type Type["N"]
3017+
main:4: error: Too few arguments for "N"
30053018

30063019
[case testTypeUsingTypeCJoin]
30073020
from typing import Type
@@ -5117,3 +5130,15 @@ class C:
51175130
def x(self) -> int: pass
51185131
[builtins fixtures/property.pyi]
51195132
[out]
5133+
5134+
[case testClassMethodBeforeInit]
5135+
class Foo(object):
5136+
@classmethod
5137+
def bar(cls):
5138+
# type: () -> Foo
5139+
return cls("bar")
5140+
5141+
def __init__(self, baz):
5142+
# type: (str) -> None
5143+
self.baz = baz
5144+
[builtins fixtures/classmethod.pyi]

test-data/unit/check-dataclasses.test

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ class A:
249249

250250
@classmethod
251251
def foo(cls, x: Union[int, str]) -> Union[int, str]:
252-
reveal_type(cls) # E: Revealed type is 'def (a: builtins.int, b: builtins.str) -> __main__.A'
252+
reveal_type(cls) # E: Revealed type is 'Type[__main__.A]'
253253
reveal_type(cls.other()) # E: Revealed type is 'builtins.str'
254254
return x
255255

@@ -421,6 +421,23 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in
421421

422422
[builtins fixtures/list.pyi]
423423

424+
[case testDataclassGenericsClassmethod]
425+
# flags: --python-version 3.6
426+
from dataclasses import dataclass
427+
from typing import Generic, TypeVar
428+
429+
T = TypeVar('T')
430+
431+
@dataclass
432+
class A(Generic[T]):
433+
x: T
434+
435+
@classmethod
436+
def foo(cls) -> None:
437+
reveal_type(cls) # E: Revealed type is 'Type[__main__.A[T`1]]'
438+
439+
[builtins fixtures/classmethod.pyi]
440+
424441
[case testDataclassesForwardRefs]
425442
from dataclasses import dataclass
426443

@@ -466,3 +483,16 @@ app.rating
466483
app.database_name # E: "SpecializedApplication" has no attribute "database_name"
467484

468485
[builtins fixtures/list.pyi]
486+
487+
[case testDataclassFactory]
488+
from typing import Type, TypeVar
489+
from dataclasses import dataclass
490+
491+
T = TypeVar('T', bound='A')
492+
493+
@dataclass
494+
class A:
495+
@classmethod
496+
def make(cls: Type[T]) -> T:
497+
return cls()
498+
[builtins fixtures/classmethod.pyi]

test-data/unit/check-overloading.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4072,7 +4072,7 @@ class Wrapper:
40724072

40734073
@classmethod # E: Overloaded function implementation cannot produce return type of signature 1
40744074
def foo(cls, x: Union[int, str]) -> str:
4075-
reveal_type(cls) # E: Revealed type is 'def () -> __main__.Wrapper'
4075+
reveal_type(cls) # E: Revealed type is 'Type[__main__.Wrapper]'
40764076
reveal_type(cls.other()) # E: Revealed type is 'builtins.str'
40774077
return "..."
40784078

0 commit comments

Comments
 (0)