From d7c0bd0ea4507d2a020db110546ffe88172da39c Mon Sep 17 00:00:00 2001 From: Annie Wu Date: Mon, 12 Dec 2022 16:56:15 -0500 Subject: [PATCH 1/2] implemented error check for functions not called on dataclass instances --- mypy/plugins/dataclasses.py | 18 ++++++++++++++++++ mypy/plugins/default.py | 5 ++++- test-data/unit/lib-stub/dataclasses.pyi | 2 ++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 75496d5e56f9..ffdb76b9ce9e 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -4,6 +4,7 @@ from typing import Optional from typing_extensions import Final +import mypy from mypy.expandtype import expand_type from mypy.nodes import ( @@ -630,6 +631,23 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: transformer = DataclassTransformer(ctx) return transformer.transform() +def dataclass_check_type_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Will raise an error if the function is not called on a dataclass + instance. Examples of these functions are replace() and field() functions. + + This callback is called in default.py + """ + if ctx != None: + if len(ctx.arg_types) != 0: + if len(ctx.arg_types[0]) != 0: + #this means there's no dataclass_tag in the metadata, so the argument is not a dataclass + if 'dataclass' not in (ctx.arg_types[0][0].type.metadata): + ctx.api.msg.fail( + "Calling replace on a non-dataclass.", + ctx.context + ) + return ctx.default_return_type + def _collect_field_args( expr: Expression, ctx: ClassDefContext diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 04971868e8f4..a708fbb6c7aa 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -37,12 +37,15 @@ class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: - from mypy.plugins import ctypes, singledispatch + from mypy.plugins import ctypes, singledispatch, dataclasses if fullname == "ctypes.Array": return ctypes.array_constructor_callback elif fullname == "functools.singledispatch": return singledispatch.create_singledispatch_function_callback + #these functions should only be called on dataclass instances + elif fullname == "dataclasses.replace" or fullname == "dataclasses.field": + return dataclasses.dataclass_check_type_callback return None def get_method_signature_hook( diff --git a/test-data/unit/lib-stub/dataclasses.pyi b/test-data/unit/lib-stub/dataclasses.pyi index bd33b459266c..9c4bb80a46cd 100644 --- a/test-data/unit/lib-stub/dataclasses.pyi +++ b/test-data/unit/lib-stub/dataclasses.pyi @@ -30,5 +30,7 @@ def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...,) -> Any: ... +@overload +def replace(obj: _T, **changes: Any) -> _T: ... class Field(Generic[_T]): pass From ac7346b46dc1538aef99e5b551364099ff5cf5d0 Mon Sep 17 00:00:00 2001 From: Asyer Yonas Date: Mon, 12 Dec 2022 19:57:31 -0500 Subject: [PATCH 2/2] added test to make sure astuple, fields, and replace are called on dataclass instances --- test-data/unit/check-dataclasses.test | 45 +++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index c248f8db8585..8d4ce4732eea 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -1,3 +1,48 @@ +[case testNoCrashOnReplace] +def replace(__obj: _T, **changes: Any) -> _T: ... +# flags: --python-version 3.7 +from dataclasses import dataclass +from dataclasses import replace +@dataclass +class D: + pass +class C: + pass + +replace(C()) # E: replace() should be called on dataclass instances +[builtins fixtures/dataclasses.pyi] + + +[case testNoCrashOnAstuple] +def replace(__obj: _T, **changes: Any) -> _T: ... +# flags: --python-version 3.7 +from dataclasses import dataclass +from dataclasses import astuple + +@dataclass +class D: + pass +class C: + pass + +astuple(C()) # E: astuple() should be called on dataclass instances +[builtins fixtures/dataclasses.pyi] + +[case testNoCrashOnFields] +def replace(__obj: _T, **changes: Any) -> _T: ... +# flags: --python-version 3.7 +from dataclasses import dataclass +from dataclasses import fields + +@dataclass +class D: + pass +class C: + pass + +fields(C()) # E: astuple() should be called on dataclass instances +[builtins fixtures/dataclasses.pyi] + [case testDataclassesBasic] # flags: --python-version 3.7 from dataclasses import dataclass