Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions python/paddle/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from collections.abc import Iterable
from typing import Any, Callable, TypeVar, cast

_F = TypeVar("_F", bound=Callable[..., Any])
from typing_extensions import ParamSpec

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")


class DecoratorBase:
Expand All @@ -31,17 +34,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.args = args
self.kwargs = kwargs

def __call__(self, func: _F) -> _F:
def __call__(
self, func: Callable[_InputT, _RetT]
) -> Callable[_InputT, _RetT]:
"""As an entry point for decorative applications"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
# Pretreatment parameters
processed_args, processed_kwargs = self.process(args, kwargs)
return func(*processed_args, **processed_kwargs)

wrapper.__signature__ = inspect.signature(func)
return cast("_F", wrapper)
return cast("Callable[_InputT, _RetT]", wrapper)

def process(
self, args: tuple[Any, ...], kwargs: dict[str, Any]
Expand Down Expand Up @@ -151,9 +156,9 @@ def process(


def param_one_alias(alias_list):
def decorator(func):
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if not kwargs:
return func(*args, **kwargs)
if (alias_list[0] not in kwargs) and (alias_list[1] in kwargs):
Expand All @@ -167,9 +172,9 @@ def wrapper(*args, **kwargs):


def param_two_alias(alias_list1, alias_list2):
def decorator(func):
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if not kwargs:
return func(*args, **kwargs)
if (alias_list1[0] not in kwargs) and (alias_list1[1] in kwargs):
Expand All @@ -185,9 +190,9 @@ def wrapper(*args, **kwargs):


def param_two_alias_one_default(alias_list1, alias_list2, default_param):
def decorator(func):
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if not kwargs:
return func(*args, **kwargs)

Expand Down Expand Up @@ -253,9 +258,9 @@ def process(


def view_decorator():
def decorator(func):
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if ("dtype" in kwargs) and ("shape_or_dtype" not in kwargs):
kwargs["shape_or_dtype"] = kwargs.pop("dtype")
elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
Expand All @@ -282,9 +287,9 @@ def reshape_decorator():
tensor_x.reshape(-1, 1, 3) -> paddle.reshape(tensor_x, -1, 1, 3])
"""

def decorator(func):
def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if ("input" in kwargs) and ("x" not in kwargs):
kwargs["x"] = kwargs.pop("input")
elif len(args) >= 2 and type(args[1]) is int:
Expand Down
Loading