From 0e19686c85b55a7834f792263a66092d8e639c25 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 20:07:25 +0000 Subject: [PATCH 1/7] [dataclass_transform] support field_specifiers --- mypy/message_registry.py | 4 ++ mypy/plugin.py | 4 ++ mypy/plugins/dataclasses.py | 38 ++++++++++- mypy/semanal.py | 27 +++++++- test-data/unit/check-dataclass-transform.test | 67 +++++++++++++++++++ 5 files changed, 135 insertions(+), 5 deletions(-) diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 7827a2818be9..e00aca2869bd 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -270,3 +270,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage: CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"' MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern' CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"' + +DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = ( + '"alias" argument to dataclass field must be a string literal' +) diff --git a/mypy/plugin.py b/mypy/plugin.py index 00a2af82969f..cf124b45d04f 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -297,6 +297,10 @@ def parse_bool(self, expr: Expression) -> bool | None: """Parse True/False literals.""" raise NotImplementedError + @abstractmethod + def parse_str_literal(self, expr: Expression) -> str | None: + """Parse string literals.""" + @abstractmethod def fail( self, diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 4683b8c1ffaf..0c8fa7027dcc 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -5,6 +5,7 @@ from typing import Optional from typing_extensions import Final +from mypy import errorcodes, message_registry from mypy.expandtype import expand_type from mypy.nodes import ( ARG_NAMED, @@ -75,6 +76,7 @@ class DataclassAttribute: def __init__( self, name: str, + alias: str | None, is_in_init: bool, is_init_var: bool, has_default: bool, @@ -85,6 +87,7 @@ def __init__( kw_only: bool, ) -> None: self.name = name + self.alias = alias self.is_in_init = is_in_init self.is_init_var = is_init_var self.has_default = has_default @@ -119,12 +122,13 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]: return self.type def to_var(self, current_info: TypeInfo) -> Var: - return Var(self.name, self.expand_type(current_info)) + return Var(self.alias or self.name, self.expand_type(current_info)) def serialize(self) -> JsonDict: assert self.type return { "name": self.name, + "alias": self.alias, "is_in_init": self.is_in_init, "is_init_var": self.is_init_var, "has_default": self.has_default, @@ -482,7 +486,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # Ensure that something like x: int = field() is rejected # after an attribute with a default. if has_field_call: - has_default = "default" in field_args or "default_factory" in field_args + has_default = ( + "default" in field_args + or "default_factory" in field_args + # alias for default_factory defined in PEP 681 + or "factory" in field_args + ) # All other assignments are already type checked. elif not isinstance(stmt.rvalue, TempNode): @@ -498,7 +507,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # kw_only value from the decorator parameter. field_kw_only_param = field_args.get("kw_only") if field_kw_only_param is not None: - is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param)) + value = ctx.api.parse_bool(field_kw_only_param) + if value is not None: + is_kw_only = value + else: + ctx.api.fail('"kw_only" argument must be True or False.', stmt.rvalue) if sym.type is None and node.is_final and node.is_inferred: # This is a special case, assignment like x: Final = 42 is classified @@ -516,9 +529,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: ) node.type = AnyType(TypeOfAny.from_error) + alias = None + if "alias" in field_args: + alias = self._ctx.api.parse_str_literal(field_args["alias"]) + if alias is None: + self._ctx.api.fail( + message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL, + stmt.rvalue, + code=errorcodes.LITERAL_REQ, + ) + current_attr_names.add(lhs.name) found_attrs[lhs.name] = DataclassAttribute( name=lhs.name, + alias=alias, is_in_init=is_in_init, is_init_var=is_init_var, has_default=has_default, @@ -609,6 +633,14 @@ def _is_kw_only_type(self, node: Type | None) -> bool: return node_type.type.fullname == "dataclasses.KW_ONLY" def _add_dataclass_fields_magic_attribute(self) -> None: + # Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform + # classes. + # It would be nice if this condition were reified rather than using an `is` check. + # Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform + # classes. + if self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES: + return + attr_name = "__dataclass_fields__" any_type = AnyType(TypeOfAny.explicit) field_type = self._ctx.api.named_type_or_none("dataclasses.Field", [any_type]) or any_type diff --git a/mypy/semanal.py b/mypy/semanal.py index cd5b82f80b1d..bc3cc77252f8 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -235,7 +235,7 @@ remove_dups, type_constructors, ) -from mypy.typeops import function_type, get_type_vars +from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type from mypy.types import ( ASSERT_TYPE_NAMES, DATACLASS_TRANSFORM_NAMES, @@ -6451,6 +6451,15 @@ def parse_bool(self, expr: Expression) -> bool | None: return False return None + def parse_str_literal(self, expr: Expression) -> str | None: + if isinstance(expr, StrExpr): + return expr.value + if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.type is not None: + values = try_getting_str_literals_from_type(expr.node.type) + if values is not None and len(values) == 1: + return values[0] + return None + def set_future_import_flags(self, module_name: str) -> None: if module_name in FUTURE_IMPORTS: self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name]) @@ -6466,7 +6475,9 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp # field_specifiers is currently the only non-boolean argument; check for it first so # so the rest of the block can fail through to handling booleans if name == "field_specifiers": - self.fail('"field_specifiers" support is currently unimplemented', call) + parameters.field_specifiers = self.parse_dataclass_transform_field_specifiers( + value + ) continue boolean = self.parse_bool(value) @@ -6487,6 +6498,18 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp return parameters + def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]: + if not isinstance(arg, TupleExpr): + return tuple() + + names = [] + for specifier in arg.items: + if not isinstance(specifier, RefExpr): + self.fail('"field_specifiers" must only contain identifiers', specifier) + return tuple() + names.append(specifier.fullname) + return tuple(names) + def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 01e8935b0745..216634aaab9f 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -202,3 +202,70 @@ Foo(5) [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformFieldSpecifierRejectMalformed] +# flags: --python-version 3.11 +from typing import dataclass_transform, Callable, Type + +def some_type() -> Type: ... +def some_function() -> Callable[[], None]: ... + +@dataclass_transform(field_specifiers=(some_type(),)) # E: "field_specifiers" must only contain identifiers +def bad_dataclass1() -> None: ... +@dataclass_transform(field_specifiers=(some_function(),)) # E: "field_specifiers" must only contain identifiers +def bad_dataclass2() -> None: ... + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformFieldSpecifierParams] +# flags: --python-version 3.11 +from typing import dataclass_transform, Any, Callable, Type, Final + +def field( + *, + init: bool = True, + kw_only: bool = False, + alias: str | None = None, + default: Any | None = None, + default_factory: Callable[[], Any] | None = None, + factory: Callable[[], Any] | None = None, +): ... +@dataclass_transform(field_specifiers=(field,)) +def my_dataclass(cls: Type) -> Type: + return cls + +B: Final = 'b' +@my_dataclass +class Foo: + a: int = field(alias='a_') + b: int = field(alias=B) + # cannot be passed as a positional + kwonly: int = field(kw_only=True, default=0) + # Safe to omit from constructor, error to pass + noinit: int = field(init=False, default=1) + # It should be safe to call the constructor without passing any of these + unused1: int = field(default=0) + unused2: int = field(factory=lambda: 0) + unused3: int = field(default_factory=lambda: 0) + +Foo(a=5, b=1) # E: Unexpected keyword argument "a" for "Foo" +Foo(a_=1, b=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo" +Foo(1, 2, 3) # E: Too many positional arguments for "Foo" +foo = Foo(a_=5, b=1) +reveal_type(foo.noinit) # N: Revealed type is "builtins.int" +reveal_type(foo.unused1) # N: Revealed type is "builtins.int" +Foo(a_=5, b=1, unused1=2, unused2=3, unused3=4) + +def some_str() -> str: ... +def some_bool() -> bool: ... +@my_dataclass +class Bad: + bad1: int = field(alias=some_str()) # E: "alias" argument to dataclass field must be a string literal + bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be True or False. + +# this metadata should only exist for dataclasses.dataclass classes +Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields__" + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] From cfb861ad7799c5faee03d7f4bef72362c9f9cfee Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Mon, 13 Feb 2023 16:27:54 +0000 Subject: [PATCH 2/7] add a docstring --- mypy/semanal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/semanal.py b/mypy/semanal.py index d87ca024f2b6..dbd9e9064ea2 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -6462,6 +6462,8 @@ def parse_bool(self, expr: Expression) -> bool | None: return None def parse_str_literal(self, expr: Expression) -> str | None: + """Attempt to find the string literal value of the given expression. Returns `None` if no + literal value can be found.""" if isinstance(expr, StrExpr): return expr.value if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.type is not None: From e190e18a9be2898728814bb335dc4451b5dcd90c Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Mon, 13 Feb 2023 16:36:15 +0000 Subject: [PATCH 3/7] update error message formatting --- mypy/plugins/dataclasses.py | 2 +- test-data/unit/check-dataclass-transform.test | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 2f171e30b7b7..7aeb19ee93bf 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -524,7 +524,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if value is not None: is_kw_only = value else: - self._api.fail('"kw_only" argument must be True or False.', stmt.rvalue) + self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue) if sym.type is None and node.is_final and node.is_inferred: # This is a special case, assignment like x: Final = 42 is classified diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 8c59c4eee77b..ab0be02eed72 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -269,7 +269,7 @@ def some_bool() -> bool: ... @my_dataclass class Bad: bad1: int = field(alias=some_str()) # E: "alias" argument to dataclass field must be a string literal - bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be True or False. + bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be a boolean literal # this metadata should only exist for dataclasses.dataclass classes Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields__" From e4c692934a80ad777ab09435300209939594f4b4 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Mon, 13 Feb 2023 17:07:23 +0000 Subject: [PATCH 4/7] allow non-standard positional args for dataclass_transform field specifiers --- mypy/plugins/dataclasses.py | 6 +++++ test-data/unit/check-dataclass-transform.test | 22 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 7aeb19ee93bf..db7aeec84852 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -689,6 +689,12 @@ def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Express # the best we can do for now is not to fail. # TODO: we can infer what's inside `**` and try to collect it. message = 'Unpacking **kwargs in "field()" is not supported' + elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES: + # dataclasses.field can only be used with keyword args, but this + # restriction is only enforced for the *standardized* arguments to + # dataclass_transform field specifiers. If this is not a + # dataclasses.dataclass class, we can just skip positional args safely. + continue else: message = '"field()" does not accept positional arguments' self._api.fail(message, expr) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index ab0be02eed72..1dbb396810f4 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -277,6 +277,28 @@ Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields_ [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] +[case testDataclassTransformFieldSpecifierExtraArgs] +# flags: --python-version 3.11 +from typing import dataclass_transform + +def field(extra1, *, kw_only=False, extra2=0): ... +@dataclass_transform(field_specifiers=(field,)) +def my_dataclass(cls): + return cls + +@my_dataclass +class Good: + a: int = field(5) + b: int = field(5, extra2=1) + c: int = field(5, kw_only=True) + +@my_dataclass +class Bad: + a: int = field(kw_only=True) # E: Missing positional argument "extra1" in call to "field" + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + [case testDataclassTransformOverloadsDecoratorOnOverload] # flags: --python-version 3.11 from typing import dataclass_transform, overload, Any, Callable, Type, Literal From 23e930023817c531ff1353d7139452588d8aa598 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 14 Feb 2023 18:45:34 +0000 Subject: [PATCH 5/7] explicit error when field_specifiers arg is not a tuple literal --- mypy/semanal.py | 1 + test-data/unit/check-dataclass-transform.test | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index be9894b5b333..d2fd92499679 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -6517,6 +6517,7 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]: if not isinstance(arg, TupleExpr): + self.fail('"field_specifiers" argument must be a tuple literal', arg) return tuple() names = [] diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 25d240c95ba2..01efdae5eb1d 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -212,15 +212,23 @@ Foo(5) [case testDataclassTransformFieldSpecifierRejectMalformed] # flags: --python-version 3.11 -from typing import dataclass_transform, Callable, Type +from typing import dataclass_transform, Any, Callable, Final, Type def some_type() -> Type: ... def some_function() -> Callable[[], None]: ... +def field(*args, **kwargs): ... +def fields_tuple() -> tuple[type | Callable[..., Any], ...]: return (field,) +CONSTANT: Final = (field,) + @dataclass_transform(field_specifiers=(some_type(),)) # E: "field_specifiers" must only contain identifiers def bad_dataclass1() -> None: ... @dataclass_transform(field_specifiers=(some_function(),)) # E: "field_specifiers" must only contain identifiers def bad_dataclass2() -> None: ... +@dataclass_transform(field_specifiers=CONSTANT) # E: "field_specifiers" argument must be a tuple literal +def bad_dataclass3() -> None: ... +@dataclass_transform(field_specifiers=fields_tuple()) # E: "field_specifiers" argument must be a tuple literal +def bad_dataclass4() -> None: ... [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] From 3340b3c7fe682acbf3e19f441358170381a6ce20 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 14 Feb 2023 18:47:55 +0000 Subject: [PATCH 6/7] more coverage on field attribute tests --- test-data/unit/check-dataclass-transform.test | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 01efdae5eb1d..17c7772ab9a8 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -250,7 +250,7 @@ def field( def my_dataclass(cls: Type) -> Type: return cls -B: Final = 'b' +B: Final = 'b_' @my_dataclass class Foo: a: int = field(alias='a_') @@ -264,13 +264,13 @@ class Foo: unused2: int = field(factory=lambda: 0) unused3: int = field(default_factory=lambda: 0) -Foo(a=5, b=1) # E: Unexpected keyword argument "a" for "Foo" -Foo(a_=1, b=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo" +Foo(a=5, b_=1) # E: Unexpected keyword argument "a" for "Foo" +Foo(a_=1, b_=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo" Foo(1, 2, 3) # E: Too many positional arguments for "Foo" -foo = Foo(a_=5, b=1) +foo = Foo(1, 2, kwonly=3) reveal_type(foo.noinit) # N: Revealed type is "builtins.int" reveal_type(foo.unused1) # N: Revealed type is "builtins.int" -Foo(a_=5, b=1, unused1=2, unused2=3, unused3=4) +Foo(a_=5, b_=1, unused1=2, unused2=3, unused3=4) def some_str() -> str: ... def some_bool() -> bool: ... From d911373c656f25f1e1557f5ac6fb3f6395251456 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 14 Feb 2023 18:55:58 +0000 Subject: [PATCH 7/7] add test with multiple field specifiers --- test-data/unit/check-dataclass-transform.test | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 17c7772ab9a8..2a7fad1da992 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -307,6 +307,28 @@ class Bad: [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] +[case testDataclassTransformMultipleFieldSpecifiers] +# flags: --python-version 3.11 +from typing import dataclass_transform + +def field1(*, default: int) -> int: ... +def field2(*, default: str) -> str: ... + +@dataclass_transform(field_specifiers=(field1, field2)) +def my_dataclass(cls): return cls + +@my_dataclass +class Foo: + a: int = field1(default=0) + b: str = field2(default='hello') + +reveal_type(Foo) # N: Revealed type is "def (a: builtins.int =, b: builtins.str =) -> __main__.Foo" +Foo() +Foo(a=1, b='bye') + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + [case testDataclassTransformOverloadsDecoratorOnOverload] # flags: --python-version 3.11 from typing import dataclass_transform, overload, Any, Callable, Type, Literal