Skip to content

Implementation for validation checks on functions that can only be called on dataclasses #14286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Optional
from typing_extensions import Final
import mypy

from mypy.expandtype import expand_type
from mypy.nodes import (
Expand Down Expand Up @@ -630,6 +631,23 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
transformer = DataclassTransformer(ctx)
return transformer.transform()

def dataclass_check_type_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Will raise an error if the function is not called on a dataclass
instance. Examples of these functions are replace() and field() functions.

This callback is called in default.py
"""
if ctx != None:
if len(ctx.arg_types) != 0:
if len(ctx.arg_types[0]) != 0:
#this means there's no dataclass_tag in the metadata, so the argument is not a dataclass
if 'dataclass' not in (ctx.arg_types[0][0].type.metadata):
ctx.api.msg.fail(
"Calling replace on a non-dataclass.",
ctx.context
)
return ctx.default_return_type


def _collect_field_args(
expr: Expression, ctx: ClassDefContext
Expand Down
5 changes: 4 additions & 1 deletion mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""

def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
from mypy.plugins import ctypes, singledispatch
from mypy.plugins import ctypes, singledispatch, dataclasses

if fullname == "ctypes.Array":
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
#these functions should only be called on dataclass instances
elif fullname == "dataclasses.replace" or fullname == "dataclasses.field":
return dataclasses.dataclass_check_type_callback
return None

def get_method_signature_hook(
Expand Down
45 changes: 45 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
@@ -1,3 +1,48 @@
[case testNoCrashOnReplace]
def replace(__obj: _T, **changes: Any) -> _T: ...
# flags: --python-version 3.7
from dataclasses import dataclass
from dataclasses import replace
@dataclass
class D:
pass
class C:
pass

replace(C()) # E: replace() should be called on dataclass instances
[builtins fixtures/dataclasses.pyi]


[case testNoCrashOnAstuple]
def replace(__obj: _T, **changes: Any) -> _T: ...
# flags: --python-version 3.7
from dataclasses import dataclass
from dataclasses import astuple

@dataclass
class D:
pass
class C:
pass

astuple(C()) # E: astuple() should be called on dataclass instances
[builtins fixtures/dataclasses.pyi]

[case testNoCrashOnFields]
def replace(__obj: _T, **changes: Any) -> _T: ...
# flags: --python-version 3.7
from dataclasses import dataclass
from dataclasses import fields

@dataclass
class D:
pass
class C:
pass

fields(C()) # E: astuple() should be called on dataclass instances
[builtins fixtures/dataclasses.pyi]

[case testDataclassesBasic]
# flags: --python-version 3.7
from dataclasses import dataclass
Expand Down
2 changes: 2 additions & 0 deletions test-data/unit/lib-stub/dataclasses.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,7 @@ def field(*,
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...,) -> Any: ...

@overload
def replace(obj: _T, **changes: Any) -> _T: ...

class Field(Generic[_T]): pass