diff --git a/mypy/plugin.py b/mypy/plugin.py index 75b370857536..f6bd02f92342 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -4,7 +4,6 @@ from functools import partial from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar, Dict -import mypy.plugins.attrs from mypy.nodes import ( Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr, ClassDef, TypeInfo, SymbolTableNode, MypyFile @@ -302,13 +301,18 @@ def get_method_hook(self, fullname: str def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in mypy.plugins.attrs.attr_class_makers: - return mypy.plugins.attrs.attr_class_maker_callback - elif fullname in mypy.plugins.attrs.attr_dataclass_makers: + from mypy.plugins import attrs + from mypy.plugins import dataclasses + + if fullname in attrs.attr_class_makers: + return attrs.attr_class_maker_callback + elif fullname in attrs.attr_dataclass_makers: return partial( - mypy.plugins.attrs.attr_class_maker_callback, + attrs.attr_class_maker_callback, auto_attribs_default=True ) + elif fullname in dataclasses.dataclass_makers: + return dataclasses.dataclass_class_maker_callback return None diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 4a7d97f3d28b..02947385ca47 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -11,6 +11,9 @@ is_class_var, TempNode, Decorator, MemberExpr, Expression, FuncDef, Block, PassStmt, SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef ) +from mypy.plugins.common import ( + _get_argument, _get_bool_argument, _get_decorator_bool_argument +) from mypy.types import ( Type, AnyType, TypeOfAny, CallableType, NoneTyp, TypeVarDef, TypeVarType, Overloaded, Instance, UnionType, FunctionLike @@ -468,67 +471,6 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute], func_type.arg_types[0] = ctx.api.class_type(ctx.cls.info) -def _get_decorator_bool_argument( - ctx: 'mypy.plugin.ClassDefContext', - name: str, - default: bool) -> bool: - """Return the bool argument for the decorator. - - This handles both @attr.s(...) and @attr.s - """ - if isinstance(ctx.reason, CallExpr): - return _get_bool_argument(ctx, ctx.reason, name, default) - else: - return default - - -def _get_bool_argument(ctx: 'mypy.plugin.ClassDefContext', expr: CallExpr, - name: str, default: bool) -> bool: - """Return the boolean value for an argument to a call or the default if it's not found.""" - attr_value = _get_argument(expr, name) - if attr_value: - ret = ctx.api.parse_bool(attr_value) - if ret is None: - ctx.api.fail('"{}" argument must be True or False.'.format(name), expr) - return default - return ret - return default - - -def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: - """Return the expression for the specific argument.""" - # To do this we use the CallableType of the callee to find the FormalArgument, - # then walk the actual CallExpr looking for the appropriate argument. - # - # Note: I'm not hard-coding the index so that in the future we can support other - # attrib and class makers. - callee_type = None - if (isinstance(call.callee, RefExpr) - and isinstance(call.callee.node, (Var, FuncBase)) - and call.callee.node.type): - callee_node_type = call.callee.node.type - if isinstance(callee_node_type, Overloaded): - # We take the last overload. - callee_type = callee_node_type.items()[-1] - elif isinstance(callee_node_type, CallableType): - callee_type = callee_node_type - - if not callee_type: - return None - - argument = callee_type.argument_by_name(name) - if not argument: - return None - assert argument.name - - for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): - if argument.pos is not None and not attr_name and i == argument.pos: - return attr_value - if attr_name == argument.name: - return attr_value - return None - - class MethodAdder: """Helper to add methods to a TypeInfo. diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py new file mode 100644 index 000000000000..dc808852043a --- /dev/null +++ b/mypy/plugins/common.py @@ -0,0 +1,110 @@ +from typing import List, Optional + +from mypy.nodes import ( + ARG_OPT, ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase, + FuncDef, PassStmt, RefExpr, SymbolTableNode, Var +) +from mypy.plugin import ClassDefContext +from mypy.semanal import set_callable_name +from mypy.types import CallableType, Overloaded, Type, TypeVarDef +from mypy.typevars import fill_typevars + + +def _get_decorator_bool_argument( + ctx: ClassDefContext, + name: str, + default: bool, +) -> bool: + """Return the bool argument for the decorator. + + This handles both @decorator(...) and @decorator. + """ + if isinstance(ctx.reason, CallExpr): + return _get_bool_argument(ctx, ctx.reason, name, default) + else: + return default + + +def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, + name: str, default: bool) -> bool: + """Return the boolean value for an argument to a call or the + default if it's not found. + """ + attr_value = _get_argument(expr, name) + if attr_value: + ret = ctx.api.parse_bool(attr_value) + if ret is None: + ctx.api.fail('"{}" argument must be True or False.'.format(name), expr) + return default + return ret + return default + + +def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: + """Return the expression for the specific argument.""" + # To do this we use the CallableType of the callee to find the FormalArgument, + # then walk the actual CallExpr looking for the appropriate argument. + # + # Note: I'm not hard-coding the index so that in the future we can support other + # attrib and class makers. + callee_type = None + if (isinstance(call.callee, RefExpr) + and isinstance(call.callee.node, (Var, FuncBase)) + and call.callee.node.type): + callee_node_type = call.callee.node.type + if isinstance(callee_node_type, Overloaded): + # We take the last overload. + callee_type = callee_node_type.items()[-1] + elif isinstance(callee_node_type, CallableType): + callee_type = callee_node_type + + if not callee_type: + return None + + argument = callee_type.argument_by_name(name) + if not argument: + return None + assert argument.name + + for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): + if argument.pos is not None and not attr_name and i == argument.pos: + return attr_value + if attr_name == argument.name: + return attr_value + return None + + +def _add_method( + ctx: ClassDefContext, + name: str, + args: List[Argument], + return_type: Type, + self_type: Optional[Type] = None, + tvar_def: Optional[TypeVarDef] = None, +) -> None: + """Adds a new method to a class. + """ + info = ctx.cls.info + self_type = self_type or fill_typevars(info) + function_type = ctx.api.named_type('__builtins__.function') + + args = [Argument(Var('self'), self_type, None, ARG_POS)] + args + arg_types, arg_names, arg_kinds = [], [], [] + for arg in args: + assert arg.type_annotation, 'All arguments must be fully typed.' + arg_types.append(arg.type_annotation) + arg_names.append(arg.variable.name()) + arg_kinds.append(arg.kind) + + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) + if tvar_def: + signature.variables = [tvar_def] + + func = FuncDef(name, args, Block([PassStmt()])) + func.info = info + func.type = set_callable_name(signature, func) + func._fullname = info.fullname() + '.' + name + func.line = info.line + + info.names[name] = SymbolTableNode(MDEF, func) + info.defn.defs.body.append(func) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py new file mode 100644 index 000000000000..a2c19f6f5145 --- /dev/null +++ b/mypy/plugins/dataclasses.py @@ -0,0 +1,324 @@ +from collections import OrderedDict +from typing import Dict, List, Optional, Set, Tuple, cast + +from mypy.nodes import ( + ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, Block, CallExpr, + Context, Decorator, Expression, FuncDef, JsonDict, NameExpr, + SymbolTableNode, TempNode, TypeInfo, Var, +) +from mypy.plugin import ClassDefContext +from mypy.plugins.common import _add_method, _get_decorator_bool_argument +from mypy.types import ( + CallableType, Instance, NoneTyp, Type, TypeVarDef, TypeVarType, + deserialize_type +) +from mypy.typevars import fill_typevars + +# The set of decorators that generate dataclasses. +dataclass_makers = { + 'dataclass', + 'dataclasses.dataclass', +} + + +class DataclassAttribute: + def __init__( + self, + name: str, + is_in_init: bool, + is_init_var: bool, + has_default: bool, + line: int, + column: int, + ) -> None: + self.name = name + self.is_in_init = is_in_init + self.is_init_var = is_init_var + self.has_default = has_default + self.line = line + self.column = column + + def to_argument(self, info: TypeInfo) -> Argument: + return Argument( + variable=self.to_var(info), + type_annotation=info[self.name].type, + initializer=None, + kind=ARG_OPT if self.has_default else ARG_POS, + ) + + def to_var(self, info: TypeInfo) -> Var: + return Var(self.name, info[self.name].type) + + def serialize(self) -> JsonDict: + return { + 'name': self.name, + 'is_in_init': self.is_in_init, + 'is_init_var': self.is_init_var, + 'has_default': self.has_default, + 'line': self.line, + 'column': self.column, + } + + @classmethod + def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'DataclassAttribute': + return cls(**data) + + +class DataclassTransformer: + def __init__(self, ctx: ClassDefContext) -> None: + self._ctx = ctx + + def transform(self) -> None: + """Apply all the necessary transformations to the underlying + dataclass so as to ensure it is fully type checked according + to the rules in PEP 557. + """ + ctx = self._ctx + info = self._ctx.cls.info + attributes = self.collect_attributes() + decorator_arguments = { + 'init': _get_decorator_bool_argument(self._ctx, 'init', True), + 'eq': _get_decorator_bool_argument(self._ctx, 'eq', True), + 'order': _get_decorator_bool_argument(self._ctx, 'order', False), + 'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False), + } + + if decorator_arguments['init']: + _add_method( + ctx, + '__init__', + args=[attr.to_argument(info) for attr in attributes if attr.is_in_init], + return_type=NoneTyp(), + ) + for stmt in self._ctx.cls.defs.body: + # Fix up the types of classmethods since, by default, + # they will be based on the parent class' init. + if isinstance(stmt, Decorator) and stmt.func.is_class: + func_type = stmt.func.type + if isinstance(func_type, CallableType): + func_type.arg_types[0] = self._ctx.api.class_type(self._ctx.cls.info) + + # Add an eq method, but only if the class doesn't already have one. + if decorator_arguments['eq'] and info.get('__eq__') is None: + for method_name in ['__eq__', '__ne__']: + # The TVar is used to enforce that "other" must have + # the same type as self (covariant). Note the + # "self_type" parameter to _add_method. + obj_type = ctx.api.named_type('__builtins__.object') + cmp_tvar_def = TypeVarDef('T', 'T', 1, [], obj_type) + cmp_other_type = TypeVarType(cmp_tvar_def) + cmp_return_type = ctx.api.named_type('__builtins__.bool') + + _add_method( + ctx, + method_name, + args=[Argument(Var('other', cmp_other_type), cmp_other_type, None, ARG_POS)], + return_type=cmp_return_type, + self_type=cmp_other_type, + tvar_def=cmp_tvar_def, + ) + + # Add <, >, <=, >=, but only if the class has an eq method. + if decorator_arguments['order']: + if not decorator_arguments['eq']: + ctx.api.fail('eq must be True if order is True', ctx.cls) + + for method_name in ['__lt__', '__gt__', '__le__', '__ge__']: + # Like for __eq__ and __ne__, we want "other" to match + # the self type. + obj_type = ctx.api.named_type('__builtins__.object') + order_tvar_def = TypeVarDef('T', 'T', 1, [], obj_type) + order_other_type = TypeVarType(order_tvar_def) + order_return_type = ctx.api.named_type('__builtins__.bool') + order_args = [ + Argument(Var('other', order_other_type), order_other_type, None, ARG_POS) + ] + + existing_method = info.get(method_name) + if existing_method is not None: + assert existing_method.node + ctx.api.fail( + 'You may not have a custom %s method when order=True' % method_name, + existing_method.node, + ) + + _add_method( + ctx, + method_name, + args=order_args, + return_type=order_return_type, + self_type=order_other_type, + tvar_def=order_tvar_def, + ) + + if decorator_arguments['frozen']: + self._freeze(attributes) + + # Remove init-only vars from the class. + for attr in attributes: + if attr.is_init_var: + del info.names[attr.name] + + info.metadata['dataclass'] = { + 'attributes': OrderedDict((attr.name, attr.serialize()) for attr in attributes), + 'frozen': decorator_arguments['frozen'], + } + + def collect_attributes(self) -> List[DataclassAttribute]: + """Collect all attributes declared in the dataclass and its parents. + + All assignments of the form + + a: SomeType + b: SomeOtherType = ... + + are collected. + """ + # First, collect attributes belonging to the current class. + ctx = self._ctx + cls = self._ctx.cls + attrs = [] # type: List[DataclassAttribute] + known_attrs = set() # type: Set[str] + for stmt in cls.defs.body: + # Any assignment that doesn't use the new type declaration + # syntax can be ignored out of hand. + if not (isinstance(stmt, AssignmentStmt) and stmt.new_syntax): + continue + + # a: int, b: str = 1, 'foo' is not supported syntax so we + # don't have to worry about it. + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr): + continue + + node = cls.info.names[lhs.name].node + assert isinstance(node, Var) + + # x: ClassVar[int] is ignored by dataclasses. + if node.is_classvar: + continue + + # x: InitVar[int] is turned into x: int and is removed from the class. + is_init_var = False + if ( + isinstance(node.type, Instance) and + node.type.type.fullname() == 'dataclasses.InitVar' + ): + is_init_var = True + node.type = node.type.args[0] + + has_field_call, field_args = _collect_field_args(stmt.rvalue) + + is_in_init_param = field_args.get('init') + if is_in_init_param is None: + is_in_init = True + else: + is_in_init = bool(ctx.api.parse_bool(is_in_init_param)) + + has_default = False + # Ensure that something like x: int = field() is rejected + # after an attribute with a default. + if has_field_call: + has_default = 'default' in field_args or 'default_factory' in field_args + + # All other assignments are already type checked. + elif not isinstance(stmt.rvalue, TempNode): + has_default = True + + known_attrs.add(lhs.name) + attrs.append(DataclassAttribute( + name=lhs.name, + is_in_init=is_in_init, + is_init_var=is_init_var, + has_default=has_default, + line=stmt.line, + column=stmt.column, + )) + + # Next, collect attributes belonging to any class in the MRO + # as long as those attributes weren't already collected. This + # makes it possible to overwrite attributes in subclasses. + super_attrs = [] + init_method = cls.info.get_method('__init__') + for info in cls.info.mro[1:-1]: + if 'dataclass' not in info.metadata: + continue + + for name, data in info.metadata['dataclass']['attributes'].items(): + if name not in known_attrs: + attr = DataclassAttribute.deserialize(info, data) + if attr.is_init_var and isinstance(init_method, FuncDef): + # InitVars are removed from classes so, in order for them to be inherited + # properly, we need to re-inject them into subclasses' sym tables here. + # To do that, we look 'em up from the parents' __init__. These variables + # are subsequently removed from the sym table at the end of + # DataclassTransformer.transform. + for arg, arg_name in zip(init_method.arguments, init_method.arg_names): + if arg_name == attr.name: + cls.info.names[attr.name] = SymbolTableNode(MDEF, arg.variable) + + known_attrs.add(name) + super_attrs.append(attr) + + all_attrs = super_attrs + attrs + + # Ensure that arguments without a default don't follow + # arguments that have a default. + found_default = False + for attr in all_attrs: + # If we find any attribute that is_in_init but that + # doesn't have a default after one that does have one, + # then that's an error. + if found_default and attr.is_in_init and not attr.has_default: + ctx.api.fail( + 'Attributes without a default cannot follow attributes with one', + Context(line=attr.line, column=attr.column), + ) + + found_default = found_default or attr.has_default + + return all_attrs + + def _freeze(self, attributes: List[DataclassAttribute]) -> None: + """Converts all attributes to @property methods in order to + emulate frozen classes. + """ + info = self._ctx.cls.info + for attr in attributes: + sym_node = info.names.get(attr.name) + if sym_node is not None: + var = sym_node.node + assert isinstance(var, Var) + var.is_property = True + else: + var = attr.to_var(info) + var.info = info + var.is_property = True + var._fullname = info.fullname() + '.' + var.name() + info.names[var.name()] = SymbolTableNode(MDEF, var) + + +def dataclass_class_maker_callback(ctx: ClassDefContext) -> None: + """Hooks into the class typechecking process to add support for dataclasses. + """ + transformer = DataclassTransformer(ctx) + transformer.transform() + + +def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]: + """Returns a tuple where the first value represents whether or not + the expression is a call to dataclass.field and the second is a + dictionary of the keyword arguments that field() was called with. + """ + if ( + isinstance(expr, CallExpr) and + isinstance(expr.callee, NameExpr) and + expr.callee.fullname == 'dataclasses.field' + ): + # field() only takes keyword arguments. + args = {} + for name, arg in zip(expr.arg_names, expr.args): + assert name is not None + args[name] = arg + return True, args + return False, {} diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 9f9a3e60905a..a8526c7d634b 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -77,6 +77,7 @@ 'check-custom-plugin.test', 'check-default-plugin.test', 'check-attr.test', + 'check-dataclasses.test', ] diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test new file mode 100644 index 000000000000..aa8bad16f505 --- /dev/null +++ b/test-data/unit/check-dataclasses.test @@ -0,0 +1,432 @@ +[case testDataclassesBasic] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Person: + name: str + age: int + + def summary(self): + return "%s is %d years old." % (self.name, self.age) + +reveal_type(Person) # E: Revealed type is 'def (name: builtins.str, age: builtins.int) -> __main__.Person' +Person('John', 32) +Person('Jonh', 21, None) # E: Too many arguments for "Person" + +[builtins fixtures/list.pyi] + +[case testDataclassesBasicInheritance] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Mammal: + age: int + +@dataclass +class Person(Mammal): + name: str + + def summary(self): + return "%s is %d years old." % (self.name, self.age) + +reveal_type(Person) # E: Revealed type is 'def (age: builtins.int, name: builtins.str) -> __main__.Person' +Mammal(10) +Person(32, 'John') +Person(21, 'Jonh', None) # E: Too many arguments for "Person" + +[builtins fixtures/list.pyi] + +[case testDataclassesDeepInheritance] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class A: + a: int + +@dataclass +class B(A): + b: int + +@dataclass +class C(B): + c: int + +@dataclass +class D(C): + d: int + +reveal_type(A) # E: Revealed type is 'def (a: builtins.int) -> __main__.A' +reveal_type(B) # E: Revealed type is 'def (a: builtins.int, b: builtins.int) -> __main__.B' +reveal_type(C) # E: Revealed type is 'def (a: builtins.int, b: builtins.int, c: builtins.int) -> __main__.C' +reveal_type(D) # E: Revealed type is 'def (a: builtins.int, b: builtins.int, c: builtins.int, d: builtins.int) -> __main__.D' + +[builtins fixtures/list.pyi] + +[case testDataclassesOverriding] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Mammal: + age: int + +@dataclass +class Person(Mammal): + name: str + age: int + +reveal_type(Person) # E: Revealed type is 'def (name: builtins.str, age: builtins.int) -> __main__.Person' +Person('John', 32) +Person('John', 21, None) # E: Too many arguments for "Person" + +[builtins fixtures/list.pyi] + +[case testDataclassesFreezing] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass(frozen=True) +class Person: + name: str + +john = Person('John') +john.name = 'Ben' # E: Property "name" defined in "Person" is read-only + +[builtins fixtures/list.pyi] + +[case testDataclassesFields] +# flags: --python-version 3.6 +from dataclasses import dataclass, field + +@dataclass +class Person: + name: str + age: int = field(default=0, init=False) + +reveal_type(Person) # E: Revealed type is 'def (name: builtins.str) -> __main__.Person' +john = Person('John') +john.age = 'invalid' # E: Incompatible types in assignment (expression has type "str", variable has type "int") +john.age = 24 + +[builtins fixtures/list.pyi] + +[case testDataclassesBadInit] +# flags: --python-version 3.6 +from dataclasses import dataclass, field + +@dataclass +class Person: + name: str + age: int = field(init=None) # E: No overload variant of "field" matches argument type "None" + +[builtins fixtures/list.pyi] + +[case testDataclassesMultiInit] +# flags: --python-version 3.6 +from dataclasses import dataclass, field +from typing import List + +@dataclass +class Person: + name: str + age: int = field(init=False) + friend_names: List[str] = field(init=True) + enemy_names: List[str] + +reveal_type(Person) # E: Revealed type is 'def (name: builtins.str, friend_names: builtins.list[builtins.str], enemy_names: builtins.list[builtins.str]) -> __main__.Person' + +[builtins fixtures/list.pyi] + +[case testDataclassesMultiInitDefaults] +# flags: --python-version 3.6 +from dataclasses import dataclass, field +from typing import List, Optional + +@dataclass +class Person: + name: str + age: int = field(init=False) + friend_names: List[str] = field(init=True) + enemy_names: List[str] + nickname: Optional[str] = None + +reveal_type(Person) # E: Revealed type is 'def (name: builtins.str, friend_names: builtins.list[builtins.str], enemy_names: builtins.list[builtins.str], nickname: Union[builtins.str, None] =) -> __main__.Person' + +[builtins fixtures/list.pyi] + +[case testDataclassesDefaults] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Application: + name: str = 'Unnamed' + rating: int = 0 + +reveal_type(Application) # E: Revealed type is 'def (name: builtins.str =, rating: builtins.int =) -> __main__.Application' +app = Application() + +[builtins fixtures/list.pyi] + +[case testDataclassesDefaultFactories] +# flags: --python-version 3.6 +from dataclasses import dataclass, field + +@dataclass +class Application: + name: str = 'Unnamed' + rating: int = field(default_factory=int) + rating_count: int = field() # E: Attributes without a default cannot follow attributes with one + +[builtins fixtures/list.pyi] + +[case testDataclassesDefaultFactoryTypeChecking] +# flags: --python-version 3.6 +from dataclasses import dataclass, field + +@dataclass +class Application: + name: str = 'Unnamed' + rating: int = field(default_factory=str) # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +[builtins fixtures/list.pyi] + +[case testDataclassesDefaultOrdering] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Application: + name: str = 'Unnamed' + rating: int # E: Attributes without a default cannot follow attributes with one + +[builtins fixtures/list.pyi] + +[case testDataclassesClassmethods] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Application: + name: str + + @classmethod + def parse(cls, request: str) -> "Application": + return cls(name='...') + +app = Application.parse('') + +[builtins fixtures/list.pyi] +[builtins fixtures/classmethod.pyi] + +[case testDataclassesClassVars] +# flags: --python-version 3.6 +from dataclasses import dataclass +from typing import ClassVar + +@dataclass +class Application: + name: str + + COUNTER: ClassVar[int] = 0 + +reveal_type(Application) # E: Revealed type is 'def (name: builtins.str) -> __main__.Application' +application = Application("example") +application.COUNTER = 1 # E: Cannot assign to class variable "COUNTER" via instance +Application.COUNTER = 1 + +[builtins fixtures/list.pyi] + +[case testDataclassEquality] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Application: + name: str + rating: int + +app1 = Application("example-1", 5) +app2 = Application("example-2", 5) +app1 == app2 +app1 != app2 +app1 == None # E: Unsupported operand types for == ("Application" and "None") + +[builtins fixtures/list.pyi] + +[case testDataclassCustomEquality] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass +class Application: + name: str + rating: int + + def __eq__(self, other: 'Application') -> bool: + ... + +app1 = Application("example-1", 5) +app2 = Application("example-2", 5) +app1 == app2 +app1 != app2 # E: Unsupported left operand type for != ("Application") +app1 == None # E: Unsupported operand types for == ("Application" and "None") + +class SpecializedApplication(Application): + ... + +app1 == SpecializedApplication("example-3", 5) + +[builtins fixtures/list.pyi] + +[case testDataclassOrdering] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass(order=True) +class Application: + name: str + rating: int + +app1 = Application('example-1', 5) +app2 = Application('example-2', 5) +app1 < app2 +app1 > app2 +app1 <= app2 +app1 >= app2 +app1 < 5 # E: Unsupported operand types for < ("Application" and "int") +app1 > 5 # E: Unsupported operand types for > ("Application" and "int") +app1 <= 5 # E: Unsupported operand types for <= ("Application" and "int") +app1 >= 5 # E: Unsupported operand types for >= ("Application" and "int") + +class SpecializedApplication(Application): + ... + +app3 = SpecializedApplication('example-3', 5) +app1 < app3 +app1 > app3 +app1 <= app3 +app1 >= app3 + +[builtins fixtures/list.pyi] + +[case testDataclassOrderingWithoutEquality] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass(eq=False, order=True) # E: eq must be True if order is True +class Application: + ... + +[builtins fixtures/list.pyi] + +[case testDataclassOrderingWithCustomMethods] +# flags: --python-version 3.6 +from dataclasses import dataclass + +@dataclass(order=True) +class Application: + def __lt__(self, other: 'Application') -> bool: # E: You may not have a custom __lt__ method when order=True + ... + +[builtins fixtures/list.pyi] + +[case testDataclassDefaultsInheritance] +# flags: --python-version 3.6 +from dataclasses import dataclass +from typing import Optional + +@dataclass(order=True) +class Application: + id: Optional[int] + name: str + +@dataclass +class SpecializedApplication(Application): + rating: int = 0 + +reveal_type(SpecializedApplication) # E: Revealed type is 'def (id: Union[builtins.int, None], name: builtins.str, rating: builtins.int =) -> __main__.SpecializedApplication' + +[builtins fixtures/list.pyi] + +[case testDataclassGenerics] +# flags: --python-version 3.6 +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar + +T = TypeVar('T') + +@dataclass +class A(Generic[T]): + x: T + y: T + z: List[T] + + def foo(self) -> List[T]: + return [self.x, self.y] + + def bar(self) -> T: + return self.z[0] + + def problem(self) -> T: + return self.z # E: Incompatible return value type (got "List[T]", expected "T") + +reveal_type(A) # E: Revealed type is 'def [T] (x: T`1, y: T`1, z: builtins.list[T`1]) -> __main__.A[T`1]' +A(1, 2, ["a", "b"]) # E: Cannot infer type argument 1 of "A" +a = A(1, 2, [1, 2]) +reveal_type(a) # E: Revealed type is '__main__.A[builtins.int*]' +reveal_type(a.x) # E: Revealed type is 'builtins.int*' +reveal_type(a.y) # E: Revealed type is 'builtins.int*' +reveal_type(a.z) # E: Revealed type is 'builtins.list[builtins.int*]' +s: str = a.bar() # E: Incompatible types in assignment (expression has type "int", variable has type "str") + +[builtins fixtures/list.pyi] + +[case testDataclassesForwardRefs] +from dataclasses import dataclass + +@dataclass +class A: + b: 'B' + +@dataclass +class B: + x: int + +reveal_type(A) # E: Revealed type is 'def (b: __main__.B) -> __main__.A' +A(b=B(42)) +A(b=42) # E: Argument "b" to "A" has incompatible type "int"; expected "B" + +[builtins fixtures/list.pyi] + + +[case testDataclassesInitVars] +from dataclasses import InitVar, dataclass + +@dataclass +class Application: + name: str + database_name: InitVar[str] + +reveal_type(Application) # E: Revealed type is 'def (name: builtins.str, database_name: builtins.str) -> __main__.Application' +app = Application("example", 42) # E: Argument 2 to "Application" has incompatible type "int"; expected "str" +app = Application("example", "apps") +app.name +app.database_name # E: "Application" has no attribute "database_name" + + +@dataclass +class SpecializedApplication(Application): + rating: int + +reveal_type(SpecializedApplication) # E: Revealed type is 'def (name: builtins.str, database_name: builtins.str, rating: builtins.int) -> __main__.SpecializedApplication' +app = SpecializedApplication("example", "apps", "five") # E: Argument 3 to "SpecializedApplication" has incompatible type "str"; expected "int" +app = SpecializedApplication("example", "apps", 5) +app.name +app.rating +app.database_name # E: "SpecializedApplication" has no attribute "database_name" + +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 1c1871922b43..3f537922d9e3 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -4362,6 +4362,321 @@ import b [stale] [rechecked b] +[case testIncrementalDataclassesSubclassingCached] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + e: str = 'e' + +a = B(5, [5], 'foo') +a.a = 6 +a._b = [2] +a.c = 'yo' +a._d = 22 +a.e = 'hi' + +[file a.py] +from dataclasses import dataclass, field +from typing import ClassVar, List + +@dataclass +class A: + a: int + _b: List[int] + c: str = '18' + _d: int = field(default=False) + E = 7 + F: ClassVar[int] = 22 + +[builtins fixtures/list.pyi] +[out1] +[out2] + +[case testIncrementalDataclassesSubclassingCachedType] +import b + +[file b.py] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + pass + +[file b.py.2] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + pass + +reveal_type(B) + +[file a.py] +from dataclasses import dataclass + +@dataclass +class A: + x: int + +[builtins fixtures/list.pyi] +[out1] +[out2] +tmp/b.py:8: error: Revealed type is 'def (x: builtins.int) -> b.B' + +[case testIncrementalDataclassesArguments] +import b + +[file b.py] +from a import Frozen, NoInit, NoCmp + +[file b.py.2] +from a import Frozen, NoInit, NoCmp + +f = Frozen(5) +f.x = 6 + +g = NoInit() + +Frozen(1) < Frozen(2) +Frozen(1) <= Frozen(2) +Frozen(1) > Frozen(2) +Frozen(1) >= Frozen(2) + +NoCmp(1) < NoCmp(2) +NoCmp(1) <= NoCmp(2) +NoCmp(1) > NoCmp(2) +NoCmp(1) >= NoCmp(2) + +[file a.py] +from dataclasses import dataclass + +@dataclass(frozen=True, order=True) +class Frozen: + x: int + +@dataclass(init=False) +class NoInit: + x: int + +@dataclass(order=False) +class NoCmp: + x: int + +[builtins fixtures/list.pyi] +[out1] +[out2] +tmp/b.py:4: error: Property "x" defined in "Frozen" is read-only +tmp/b.py:13: error: Unsupported left operand type for < ("NoCmp") +tmp/b.py:14: error: Unsupported left operand type for <= ("NoCmp") +tmp/b.py:15: error: Unsupported left operand type for > ("NoCmp") +tmp/b.py:16: error: Unsupported left operand type for >= ("NoCmp") + +[case testIncrementalDataclassesDunder] +import b + +[file b.py] +from a import A + +[file b.py.2] +from a import A + +reveal_type(A) +reveal_type(A.__eq__) +reveal_type(A.__ne__) +reveal_type(A.__lt__) +reveal_type(A.__le__) +reveal_type(A.__gt__) +reveal_type(A.__ge__) + +A(1) < A(2) +A(1) <= A(2) +A(1) > A(2) +A(1) >= A(2) +A(1) == A(2) +A(1) != A(2) + +A(1) < 1 +A(1) <= 1 +A(1) > 1 +A(1) >= 1 +A(1) == 1 +A(1) != 1 + +1 < A(1) +1 <= A(1) +1 > A(1) +1 >= A(1) +1 == A(1) +1 != A(1) + +[file a.py] +from dataclasses import dataclass + +@dataclass(order=True) +class A: + a: int + +[builtins fixtures/attr.pyi] +[out1] +[out2] +tmp/b.py:3: error: Revealed type is 'def (a: builtins.int) -> a.A' +tmp/b.py:4: error: Revealed type is 'def (builtins.object, builtins.object) -> builtins.bool' +tmp/b.py:5: error: Revealed type is 'def (builtins.object, builtins.object) -> builtins.bool' +tmp/b.py:6: error: Revealed type is 'def [T] (self: T`1, other: T`1) -> builtins.bool' +tmp/b.py:7: error: Revealed type is 'def [T] (self: T`1, other: T`1) -> builtins.bool' +tmp/b.py:8: error: Revealed type is 'def [T] (self: T`1, other: T`1) -> builtins.bool' +tmp/b.py:9: error: Revealed type is 'def [T] (self: T`1, other: T`1) -> builtins.bool' +tmp/b.py:18: error: Unsupported operand types for < ("A" and "int") +tmp/b.py:19: error: Unsupported operand types for <= ("A" and "int") +tmp/b.py:20: error: Unsupported operand types for > ("A" and "int") +tmp/b.py:21: error: Unsupported operand types for >= ("A" and "int") +tmp/b.py:25: error: Unsupported operand types for > ("A" and "int") +tmp/b.py:26: error: Unsupported operand types for >= ("A" and "int") +tmp/b.py:27: error: Unsupported operand types for < ("A" and "int") +tmp/b.py:28: error: Unsupported operand types for <= ("A" and "int") + +[case testIncrementalDataclassesSubclassModified] +from b import B +B(5, 'foo') + +[file a.py] +from dataclasses import dataclass + +@dataclass +class A: + x: int + +[file b.py] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + y: str + +[file b.py.2] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + y: int + +[builtins fixtures/list.pyi] +[out1] +[out2] +main:2: error: Argument 2 to "B" has incompatible type "str"; expected "int" +[rechecked b] + +[case testIncrementalDataclassesSubclassModifiedErrorFirst] +from b import B +B(5, 'foo') + +[file a.py] +from dataclasses import dataclass + +@dataclass +class A: + x: int + +[file b.py] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + y: int + +[file b.py.2] +from a import A +from dataclasses import dataclass + +@dataclass +class B(A): + y: str + +[builtins fixtures/list.pyi] +[out1] +main:2: error: Argument 2 to "B" has incompatible type "str"; expected "int" + +[out2] +[rechecked b] + +[case testIncrementalDataclassesThreeFiles] +from c import C +C(5, 'foo', True) + +[file a.py] +from dataclasses import dataclass + +@dataclass +class A: + a: int + +[file b.py] +from dataclasses import dataclass + +@dataclass +class B: + b: str + +[file b.py.2] +from dataclasses import dataclass + +@dataclass +class B: + b: str + c: str + +[file c.py] +from a import A +from b import B +from dataclasses import dataclass + +@dataclass +class C(A, B): + c: bool + +[builtins fixtures/list.pyi] +[out1] +[out2] +tmp/c.py:7: error: Incompatible types in assignment (expression has type "bool", base class "B" defined the type as "str") + +[case testIncrementalDataclassesThreeRuns] +from a import A +A(5) + +[file a.py] +from dataclasses import dataclass + +@dataclass +class A: + a: int + +[file a.py.2] +from dataclasses import dataclass + +@dataclass +class A: + a: str + +[file a.py.3] +from dataclasses import dataclass + +@dataclass +class A: + a: int = 6 + +[builtins fixtures/list.pyi] +[out1] +[out2] +main:2: error: Argument 1 to "A" has incompatible type "int"; expected "str" +[out3] + [case testParentPatchingMess] # flags: --ignore-missing-imports --follow-imports=skip # cmd: mypy -m d d.k d.k.a d.k.v t diff --git a/test-data/unit/lib-stub/dataclasses.pyi b/test-data/unit/lib-stub/dataclasses.pyi new file mode 100644 index 000000000000..160cfcd066ba --- /dev/null +++ b/test-data/unit/lib-stub/dataclasses.pyi @@ -0,0 +1,30 @@ +from typing import Any, Callable, Generic, Mapping, Optional, TypeVar, overload, Type + +_T = TypeVar('_T') + +class InitVar(Generic[_T]): + ... + + +@overload +def dataclass(_cls: Type[_T]) -> Type[_T]: ... + +@overload +def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ..., + unsafe_hash: bool = ..., frozen: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ... + + +@overload +def field(*, default: _T, + init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., + metadata: Optional[Mapping[str, Any]] = ...) -> _T: ... + +@overload +def field(*, default_factory: Callable[[], _T], + init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., + metadata: Optional[Mapping[str, Any]] = ...) -> _T: ... + +@overload +def field(*, + init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., + metadata: Optional[Mapping[str, Any]] = ...) -> Any: ...