5
5
import inspect
6
6
import textwrap
7
7
import types
8
- from typing import Optional
8
+ import typing
9
+ from typing import Optional , Tuple
9
10
10
11
import onnx
11
12
12
13
import onnxscript
13
14
from onnxscript import converter as ons_converter
14
15
from onnxscript ._internal import version_utils
15
16
17
+ if typing .TYPE_CHECKING :
18
+ from onnxscript import irbuilder
19
+
16
20
_ONNX_OP_SCHEMA_WRITABLE = not version_utils .onnx_older_than ("1.14" )
17
21
18
22
@@ -33,7 +37,7 @@ def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
33
37
return src , f_ast
34
38
35
39
36
- class TraceOnlyFunction :
40
+ class TraceOnlyFunction ( onnxscript . values . OpLike ) :
37
41
"""TraceOnlyFunction.
38
42
39
43
Attributes:
@@ -44,9 +48,11 @@ class TraceOnlyFunction:
44
48
def __init__ (self , opset : onnxscript .values .Opset , func : types .FunctionType ):
45
49
self ._opset = opset
46
50
self ._func = func
47
- self ._opschema : Optional [onnx .defs .OpSchema ] = None
48
51
# Set the signature of the class to function's
49
52
self .__signature__ = inspect .signature (func )
53
+ # Cached computed fields
54
+ self ._opschema : Optional [onnx .defs .OpSchema ] = None
55
+ self ._param_schemas : Optional [Tuple [onnxscript .values .ParamSchema , ...]] = None
50
56
51
57
def __call__ (self , * args , ** kwargs ):
52
58
return self ._func (* args , ** kwargs )
@@ -72,13 +78,32 @@ def opset(self) -> onnxscript.values.Opset:
72
78
@property
73
79
def opschema (self ) -> Optional [onnx .defs .OpSchema ]:
74
80
"""Return the opschema."""
75
-
76
81
if self ._opschema is not None :
77
82
return self ._opschema
78
-
79
83
if not _ONNX_OP_SCHEMA_WRITABLE :
80
84
return None
81
85
86
+ # FIXME(justinchuby): outputs are empty. Need to fix.
87
+ self ._opschema = onnxscript .values .op_schema_from_function_ir (
88
+ self ._function_ir (), self ._opset
89
+ )
90
+
91
+ return self ._opschema
92
+
93
+ def param_schemas (self ) -> tuple [onnxscript .values .ParamSchema , ...]:
94
+ """Generate param_schemas for the TraceOnlyFunction."""
95
+ if self ._param_schemas is None :
96
+ self ._param_schemas = onnxscript .values .param_schemas_from_function_ir (
97
+ self ._function_ir ()
98
+ )
99
+
100
+ return self ._param_schemas
101
+
102
+ def _function_ir (self ) -> irbuilder .IRFunction :
103
+ """Return the IRFunction of the function.
104
+
105
+ This IRFunction contains only the function signature.
106
+ """
82
107
src , func_ast = _get_src_and_ast (self ._func )
83
108
module = inspect .getmodule (self ._func )
84
109
closure = inspect .getclosurevars (self ._func )
@@ -90,9 +115,4 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]:
90
115
source = src ,
91
116
)
92
117
93
- function_ir = converter .translate_function_signature (func_ast )
94
-
95
- # FIXME(justinchuby): outputs are empty. Need to fix.
96
- self ._opschema = onnxscript .values .op_schema_from_function_ir (function_ir , self ._opset )
97
-
98
- return self ._opschema
118
+ return converter .translate_function_signature (func_ast )
0 commit comments