diff --git a/mypy/message_registry.py b/mypy/message_registry.py
index 7827a2818be9..e00aca2869bd 100644
--- a/mypy/message_registry.py
+++ b/mypy/message_registry.py
@@ -270,3 +270,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
 CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"'
 MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern'
 CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"'
+
+DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
+    '"alias" argument to dataclass field must be a string literal'
+)
diff --git a/mypy/plugin.py b/mypy/plugin.py
index 00a2af82969f..cf124b45d04f 100644
--- a/mypy/plugin.py
+++ b/mypy/plugin.py
@@ -297,6 +297,11 @@ def parse_bool(self, expr: Expression) -> bool | None:
         """Parse True/False literals."""
         raise NotImplementedError
 
+    @abstractmethod
+    def parse_str_literal(self, expr: Expression) -> str | None:
+        """Parse string literals; return None if no literal value can be found."""
+        raise NotImplementedError
+
     @abstractmethod
     def fail(
         self,
diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py
index 872765847073..6b1062d6457f 100644
--- a/mypy/plugins/dataclasses.py
+++ b/mypy/plugins/dataclasses.py
@@ -5,6 +5,7 @@
 from typing import Optional
 from typing_extensions import Final
 
+from mypy import errorcodes, message_registry
 from mypy.expandtype import expand_type
 from mypy.nodes import (
     ARG_NAMED,
@@ -77,6 +78,7 @@ class DataclassAttribute:
     def __init__(
         self,
         name: str,
+        alias: str | None,
         is_in_init: bool,
         is_init_var: bool,
         has_default: bool,
@@ -87,6 +89,7 @@ def __init__(
         kw_only: bool,
     ) -> None:
         self.name = name
+        self.alias = alias
         self.is_in_init = is_in_init
         self.is_init_var = is_init_var
         self.has_default = has_default
@@ -121,12 +124,13 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
         return self.type
 
     def to_var(self, current_info: TypeInfo) -> Var:
-        return Var(self.name, self.expand_type(current_info))
+        return Var(self.alias or self.name, self.expand_type(current_info))
 
     def serialize(self) -> JsonDict:
         assert self.type
         return {
             "name": self.name,
+            "alias": self.alias,
             "is_in_init": self.is_in_init,
             "is_init_var": self.is_init_var,
             "has_default": self.has_default,
@@ -495,7 +499,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
             # Ensure that something like x: int = field() is rejected
             # after an attribute with a default.
             if has_field_call:
-                has_default = "default" in field_args or "default_factory" in field_args
+                has_default = (
+                    "default" in field_args
+                    or "default_factory" in field_args
+                    # alias for default_factory defined in PEP 681
+                    or "factory" in field_args
+                )
 
             # All other assignments are already type checked.
             elif not isinstance(stmt.rvalue, TempNode):
@@ -511,7 +520,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
             # kw_only value from the decorator parameter.
             field_kw_only_param = field_args.get("kw_only")
             if field_kw_only_param is not None:
-                is_kw_only = bool(self._api.parse_bool(field_kw_only_param))
+                value = self._api.parse_bool(field_kw_only_param)
+                if value is not None:
+                    is_kw_only = value
+                else:
+                    self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue)
 
             if sym.type is None and node.is_final and node.is_inferred:
                 # This is a special case, assignment like x: Final = 42 is classified
@@ -529,9 +542,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
                 )
                 node.type = AnyType(TypeOfAny.from_error)
 
+            alias = None
+            if "alias" in field_args:
+                alias = self._api.parse_str_literal(field_args["alias"])
+                if alias is None:
+                    self._api.fail(
+                        message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL,
+                        stmt.rvalue,
+                        code=errorcodes.LITERAL_REQ,
+                    )
+
             current_attr_names.add(lhs.name)
             found_attrs[lhs.name] = DataclassAttribute(
                 name=lhs.name,
+                alias=alias,
                 is_in_init=is_in_init,
                 is_init_var=is_init_var,
                 has_default=has_default,
@@ -624,6 +648,12 @@ def _is_kw_only_type(self, node: Type | None) -> bool:
         return node_type.type.fullname == "dataclasses.KW_ONLY"
 
     def _add_dataclass_fields_magic_attribute(self) -> None:
+        # Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
+        # classes.
+        # It would be nice if this condition were reified rather than using an `is` check.
+        if self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
+            return
+
         attr_name = "__dataclass_fields__"
         any_type = AnyType(TypeOfAny.explicit)
         field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type
@@ -657,6 +687,12 @@ def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Express
                         # the best we can do for now is not to fail.
                         # TODO: we can infer what's inside `**` and try to collect it.
                         message = 'Unpacking **kwargs in "field()" is not supported'
+                    elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
+                        # dataclasses.field can only be used with keyword args, but this
+                        # restriction is only enforced for the *standardized* arguments to
+                        # dataclass_transform field specifiers. If this is not a
+                        # dataclasses.dataclass class, we can just skip positional args safely.
+                        continue
                     else:
                         message = '"field()" does not accept positional arguments'
                     self._api.fail(message, expr)
diff --git a/mypy/semanal.py b/mypy/semanal.py
index 8c16b0addd45..d2fd92499679 100644
--- a/mypy/semanal.py
+++ b/mypy/semanal.py
@@ -236,7 +236,7 @@
     remove_dups,
     type_constructors,
 )
-from mypy.typeops import function_type, get_type_vars
+from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type
 from mypy.types import (
     ASSERT_TYPE_NAMES,
     DATACLASS_TRANSFORM_NAMES,
@@ -6462,6 +6462,17 @@ def parse_bool(self, expr: Expression) -> bool | None:
                 return False
         return None
 
+    def parse_str_literal(self, expr: Expression) -> str | None:
+        """Attempt to find the string literal value of the given expression. Returns `None` if no
+        literal value can be found."""
+        if isinstance(expr, StrExpr):
+            return expr.value
+        if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.type is not None:
+            values = try_getting_str_literals_from_type(expr.node.type)
+            if values is not None and len(values) == 1:
+                return values[0]
+        return None
+
     def set_future_import_flags(self, module_name: str) -> None:
         if module_name in FUTURE_IMPORTS:
             self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name])
@@ -6482,7 +6493,9 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
             # field_specifiers is currently the only non-boolean argument; check for it first so
             # so the rest of the block can fail through to handling booleans
             if name == "field_specifiers":
-                self.fail('"field_specifiers" support is currently unimplemented', call)
+                parameters.field_specifiers = self.parse_dataclass_transform_field_specifiers(
+                    value
+                )
                 continue
 
             boolean = require_bool_literal_argument(self, value, name)
@@ -6502,6 +6515,24 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
 
         return parameters
 
+    def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]:
+        """Collect the full names of the callables listed in "field_specifiers".
+
+        Report an error and return an empty tuple unless `arg` is a tuple literal
+        that contains only references.
+        """
+        if not isinstance(arg, TupleExpr):
+            self.fail('"field_specifiers" argument must be a tuple literal', arg)
+            return ()
+
+        names = []
+        for specifier in arg.items:
+            if not isinstance(specifier, RefExpr):
+                self.fail('"field_specifiers" must only contain identifiers', specifier)
+                return ()
+            names.append(specifier.fullname)
+        return tuple(names)
+
 
 def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
     if isinstance(sig, CallableType):
diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test
index bc8fe1ecf58c..2a7fad1da992 100644
--- a/test-data/unit/check-dataclass-transform.test
+++ b/test-data/unit/check-dataclass-transform.test
@@ -210,6 +210,125 @@ Foo(5)
 [typing fixtures/typing-full.pyi]
 [builtins fixtures/dataclasses.pyi]
 
+[case testDataclassTransformFieldSpecifierRejectMalformed]
+# flags: --python-version 3.11
+from typing import dataclass_transform, Any, Callable, Final, Type
+
+def some_type() -> Type: ...
+def some_function() -> Callable[[], None]: ...
+
+def field(*args, **kwargs): ...
+def fields_tuple() -> tuple[type | Callable[..., Any], ...]: return (field,)
+CONSTANT: Final = (field,)
+
+@dataclass_transform(field_specifiers=(some_type(),))  # E: "field_specifiers" must only contain identifiers
+def bad_dataclass1() -> None: ...
+@dataclass_transform(field_specifiers=(some_function(),))  # E: "field_specifiers" must only contain identifiers
+def bad_dataclass2() -> None: ...
+@dataclass_transform(field_specifiers=CONSTANT)  # E: "field_specifiers" argument must be a tuple literal
+def bad_dataclass3() -> None: ...
+@dataclass_transform(field_specifiers=fields_tuple())  # E: "field_specifiers" argument must be a tuple literal
+def bad_dataclass4() -> None: ...
+
+[typing fixtures/typing-full.pyi]
+[builtins fixtures/dataclasses.pyi]
+
+[case testDataclassTransformFieldSpecifierParams]
+# flags: --python-version 3.11
+from typing import dataclass_transform, Any, Callable, Type, Final
+
+def field(
+    *,
+    init: bool = True,
+    kw_only: bool = False,
+    alias: str | None = None,
+    default: Any | None = None,
+    default_factory: Callable[[], Any] | None = None,
+    factory: Callable[[], Any] | None = None,
+): ...
+@dataclass_transform(field_specifiers=(field,))
+def my_dataclass(cls: Type) -> Type:
+    return cls
+
+B: Final = 'b_'
+@my_dataclass
+class Foo:
+    a: int = field(alias='a_')
+    b: int = field(alias=B)
+    # cannot be passed as a positional
+    kwonly: int = field(kw_only=True, default=0)
+    # Safe to omit from constructor, error to pass
+    noinit: int = field(init=False, default=1)
+    # It should be safe to call the constructor without passing any of these
+    unused1: int = field(default=0)
+    unused2: int = field(factory=lambda: 0)
+    unused3: int = field(default_factory=lambda: 0)
+
+Foo(a=5, b_=1)  # E: Unexpected keyword argument "a" for "Foo"
+Foo(a_=1, b_=1, noinit=1)  # E: Unexpected keyword argument "noinit" for "Foo"
+Foo(1, 2, 3)  # E: Too many positional arguments for "Foo"
+foo = Foo(1, 2, kwonly=3)
+reveal_type(foo.noinit)  # N: Revealed type is "builtins.int"
+reveal_type(foo.unused1)  # N: Revealed type is "builtins.int"
+Foo(a_=5, b_=1, unused1=2, unused2=3, unused3=4)
+
+def some_str() -> str: ...
+def some_bool() -> bool: ...
+@my_dataclass
+class Bad:
+    bad1: int = field(alias=some_str())  # E: "alias" argument to dataclass field must be a string literal
+    bad2: int = field(kw_only=some_bool())  # E: "kw_only" argument must be a boolean literal
+
+# this metadata should only exist for dataclasses.dataclass classes
+Foo.__dataclass_fields__  # E: "Type[Foo]" has no attribute "__dataclass_fields__"
+
+[typing fixtures/typing-full.pyi]
+[builtins fixtures/dataclasses.pyi]
+
+[case testDataclassTransformFieldSpecifierExtraArgs]
+# flags: --python-version 3.11
+from typing import dataclass_transform
+
+def field(extra1, *, kw_only=False, extra2=0): ...
+@dataclass_transform(field_specifiers=(field,))
+def my_dataclass(cls):
+    return cls
+
+@my_dataclass
+class Good:
+    a: int = field(5)
+    b: int = field(5, extra2=1)
+    c: int = field(5, kw_only=True)
+
+@my_dataclass
+class Bad:
+    a: int = field(kw_only=True)  # E: Missing positional argument "extra1" in call to "field"
+
+[typing fixtures/typing-full.pyi]
+[builtins fixtures/dataclasses.pyi]
+
+[case testDataclassTransformMultipleFieldSpecifiers]
+# flags: --python-version 3.11
+from typing import dataclass_transform
+
+def field1(*, default: int) -> int: ...
+def field2(*, default: str) -> str: ...
+
+@dataclass_transform(field_specifiers=(field1, field2))
+def my_dataclass(cls): return cls
+
+@my_dataclass
+class Foo:
+    a: int = field1(default=0)
+    b: str = field2(default='hello')
+
+reveal_type(Foo)  # N: Revealed type is "def (a: builtins.int =, b: builtins.str =) -> __main__.Foo"
+Foo()
+Foo(a=1, b='bye')
+
+[typing fixtures/typing-full.pyi]
+[builtins fixtures/dataclasses.pyi]
+
 [case testDataclassTransformOverloadsDecoratorOnOverload]
 # flags: --python-version 3.11
 from typing import dataclass_transform, overload, Any, Callable, Type, Literal