Skip to content

Commit 8f71f1a

Browse files
authored
Annotate script() with ParamSpec for more accurate typing (#2178)
This pull request introduces type parameterization using `TypeVar` and `ParamSpec` to enhance type safety and flexibility in the `onnxscript` module. ### Type Parameterization Enhancements: * [`onnxscript/main.py`](diffhunk://#diff-1f9f494aa46ce42a47c4191deb91aaded1b216b24accc43685561281507e7ca8L9-R20): Introduced `_R` and `_P` type variables, and updated the `script` decorator and `transform` function signatures to use `Callable[_P, _R]` for better type inference. [[1]](diffhunk://#diff-1f9f494aa46ce42a47c4191deb91aaded1b216b24accc43685561281507e7ca8L9-R20) [[2]](diffhunk://#diff-1f9f494aa46ce42a47c4191deb91aaded1b216b24accc43685561281507e7ca8L42-R46) [[3]](diffhunk://#diff-1f9f494aa46ce42a47c4191deb91aaded1b216b24accc43685561281507e7ca8L78-R82) * [`onnxscript/values.py`](diffhunk://#diff-9625fb4ad20b7aa13388f751c0fde3a809f2b6c91023413a02ea2249f9071248R16-R36): Added `Generic`, `TypeVar`, and `ParamSpec` imports, and updated the `OnnxFunction` class to inherit from `Generic[_P, _R]`. Modified the `__call__` method to use `_P.args` and `_P.kwargs` for improved type checking. [[1]](diffhunk://#diff-9625fb4ad20b7aa13388f751c0fde3a809f2b6c91023413a02ea2249f9071248R16-R36) [[2]](diffhunk://#diff-9625fb4ad20b7aa13388f751c0fde3a809f2b6c91023413a02ea2249f9071248L467-R474) [[3]](diffhunk://#diff-9625fb4ad20b7aa13388f751c0fde3a809f2b6c91023413a02ea2249f9071248L569-R581)
1 parent 634148e commit 8f71f1a

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

onnxscript/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
import ast
77
import inspect
88
import sys
9-
from typing import Any, Callable, Optional, Sequence
9+
from typing import Any, Callable, Optional, Sequence, TypeVar
1010

1111
import onnx.helper
12+
from typing_extensions import ParamSpec
1213

1314
import onnxscript
1415
from onnxscript import converter, irbuilder, values
1516
from onnxscript._internal import ast_utils
1617

18+
_R = TypeVar("_R")
19+
_P = ParamSpec("_P")
20+
1721

1822
def script_check(
1923
f: ast.FunctionDef,
@@ -39,7 +43,7 @@ def script(
3943
opset: Optional[values.Opset] = None,
4044
default_opset: Optional[values.Opset] = None,
4145
**kwargs: Any,
42-
) -> Callable[[Callable], onnxscript.OnnxFunction]:
46+
) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]:
4347
"""Main decorator. Declares a function as an onnx function.
4448
4549
Args:
@@ -75,7 +79,7 @@ def log2(x):
7579
"Script parameter must be an opset. Did you use @script instead of @script()?"
7680
)
7781

78-
def transform(f: Callable) -> onnxscript.OnnxFunction:
82+
def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]:
7983
if not inspect.isfunction(f):
8084
raise TypeError("The ONNXScript decorator should be applied to functions only.")
8185

onnxscript/values.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,27 @@
1313
Any,
1414
Callable,
1515
ClassVar,
16+
Generic,
1617
Optional,
1718
Protocol,
1819
Sequence,
20+
TypeVar,
1921
_GenericAlias,
2022
)
2123

2224
import onnx
2325
import onnx.defs
26+
from typing_extensions import ParamSpec
2427

2528
from onnxscript import converter as converter_module
2629
from onnxscript import irbuilder, sourceinfo, type_annotation
2730
from onnxscript._internal import ast_utils, deprecation
2831
from onnxscript.ir import _schemas
2932

33+
_R = TypeVar("_R")
34+
_P = ParamSpec("_P")
35+
36+
3037
_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
3138
onnx.defs.OpSchema.AttrType.FLOAT: float,
3239
onnx.defs.OpSchema.AttrType.INT: int,
@@ -464,7 +471,7 @@ def _op_schema_from_function_ir(
464471
)
465472

466473

467-
class OnnxFunction(Op):
474+
class OnnxFunction(Op, Generic[_P, _R]):
468475
"""Represents an ONNX op for which a function-body has been defined in onnxscript.
469476
470477
Attributes:
@@ -566,12 +573,12 @@ def fun(*args, **kwargs):
566573

567574
return fun
568575

569-
def __call__(self, *args, **kwargs):
576+
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
570577
"""Implements an eager-mode execution of an onnxscript function."""
571578
# FIXME(after #225): Move import to the top of the file.
572579
from onnxscript import evaluator # pylint: disable=import-outside-toplevel
573580

574-
return evaluator.default().eval_function(self, args, kwargs)
581+
return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value]
575582

576583
def __repr__(self) -> str:
577584
return f"{self.__class__.__name__}({self.function!r})"

0 commit comments

Comments
 (0)