|
13 | 13 | Any,
|
14 | 14 | Callable,
|
15 | 15 | ClassVar,
|
| 16 | + Generic, |
16 | 17 | Optional,
|
17 | 18 | Protocol,
|
18 | 19 | Sequence,
|
| 20 | + TypeVar, |
19 | 21 | _GenericAlias,
|
20 | 22 | )
|
21 | 23 |
|
22 | 24 | import onnx
|
23 | 25 | import onnx.defs
|
| 26 | +from typing_extensions import ParamSpec |
24 | 27 |
|
25 | 28 | from onnxscript import converter as converter_module
|
26 | 29 | from onnxscript import irbuilder, sourceinfo, type_annotation
|
27 | 30 | from onnxscript._internal import ast_utils, deprecation
|
28 | 31 | from onnxscript.ir import _schemas
|
29 | 32 |
|
| 33 | +_R = TypeVar("_R") |
| 34 | +_P = ParamSpec("_P") |
| 35 | + |
| 36 | + |
30 | 37 | _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
|
31 | 38 | onnx.defs.OpSchema.AttrType.FLOAT: float,
|
32 | 39 | onnx.defs.OpSchema.AttrType.INT: int,
|
@@ -464,7 +471,7 @@ def _op_schema_from_function_ir(
|
464 | 471 | )
|
465 | 472 |
|
466 | 473 |
|
467 |
| -class OnnxFunction(Op): |
| 474 | +class OnnxFunction(Op, Generic[_P, _R]): |
468 | 475 | """Represents an ONNX op for which a function-body has been defined in onnxscript.
|
469 | 476 |
|
470 | 477 | Attributes:
|
@@ -566,12 +573,12 @@ def fun(*args, **kwargs):
|
566 | 573 |
|
567 | 574 | return fun
|
568 | 575 |
|
569 |
| - def __call__(self, *args, **kwargs): |
| 576 | + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: |
570 | 577 | """Implements an eager-mode execution of an onnxscript function."""
|
571 | 578 | # FIXME(after #225): Move import to the top of the file.
|
572 | 579 | from onnxscript import evaluator # pylint: disable=import-outside-toplevel
|
573 | 580 |
|
574 |
| - return evaluator.default().eval_function(self, args, kwargs) |
| 581 | + return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] |
575 | 582 |
|
576 | 583 | def __repr__(self) -> str:
|
577 | 584 | return f"{self.__class__.__name__}({self.function!r})"
|
|
0 commit comments