Skip to content

Commit 02edd76

Browse files
committed
Create the OpLike protocol and refactor Op | feat(values)
- Removes `is_single_op` because it is unused. Signed-off-by: Justin Chu <justinchumicrosoft.com> ghstack-source-id: 853c076 Pull Request resolved: #692 Signed-off-by: Justin Chu <[email protected]>
1 parent 80e62a4 commit 02edd76

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

onnxscript/function_libs/torch_lib/tracing.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
import inspect
66
import textwrap
77
import types
8-
from typing import Optional
8+
import typing
9+
from typing import Optional, Tuple
910

1011
import onnx
1112

1213
import onnxscript
1314
from onnxscript import converter as ons_converter
1415
from onnxscript._internal import version_utils
1516

17+
if typing.TYPE_CHECKING:
18+
from onnxscript import irbuilder
19+
1620
_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")
1721

1822

@@ -33,7 +37,7 @@ def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
3337
return src, f_ast
3438

3539

36-
class TraceOnlyFunction:
40+
class TraceOnlyFunction(onnxscript.values.OpLike):
3741
"""TraceOnlyFunction.
3842
3943
Attributes:
@@ -44,9 +48,11 @@ class TraceOnlyFunction:
4448
def __init__(self, opset: onnxscript.values.Opset, func: types.FunctionType):
4549
self._opset = opset
4650
self._func = func
47-
self._opschema: Optional[onnx.defs.OpSchema] = None
4851
# Set the signature of the class to function's
4952
self.__signature__ = inspect.signature(func)
53+
# Cached computed fields
54+
self._opschema: Optional[onnx.defs.OpSchema] = None
55+
self._param_schemas: Optional[Tuple[onnxscript.values.ParamSchema, ...]] = None
5056

5157
def __call__(self, *args, **kwargs):
5258
return self._func(*args, **kwargs)
@@ -72,13 +78,32 @@ def opset(self) -> onnxscript.values.Opset:
7278
@property
7379
def opschema(self) -> Optional[onnx.defs.OpSchema]:
7480
"""Return the opschema."""
75-
7681
if self._opschema is not None:
7782
return self._opschema
78-
7983
if not _ONNX_OP_SCHEMA_WRITABLE:
8084
return None
8185

86+
# FIXME(justinchuby): outputs are empty. Need to fix.
87+
self._opschema = onnxscript.values.op_schema_from_function_ir(
88+
self._function_ir(), self._opset
89+
)
90+
91+
return self._opschema
92+
93+
def param_schemas(self) -> tuple[onnxscript.values.ParamSchema, ...]:
94+
"""Generate param_schemas for the TraceOnlyFunction."""
95+
if self._param_schemas is None:
96+
self._param_schemas = onnxscript.values.param_schemas_from_function_ir(
97+
self._function_ir()
98+
)
99+
100+
return self._param_schemas
101+
102+
def _function_ir(self) -> irbuilder.IRFunction:
103+
"""Return the IRFunction of the function.
104+
105+
This IRFunction contains only the function signature.
106+
"""
82107
src, func_ast = _get_src_and_ast(self._func)
83108
module = inspect.getmodule(self._func)
84109
closure = inspect.getclosurevars(self._func)
@@ -90,9 +115,4 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]:
90115
source=src,
91116
)
92117

93-
function_ir = converter.translate_function_signature(func_ast)
94-
95-
# FIXME(justinchuby): outputs are empty. Need to fix.
96-
self._opschema = onnxscript.values.op_schema_from_function_ir(function_ir, self._opset)
97-
98-
return self._opschema
118+
return converter.translate_function_signature(func_ast)

onnxscript/irbuilder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __str__(self):
202202

203203
args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
204204
domain = self.callee.opset.domain
205-
opname = self.callee.opname
205+
opname = self.callee.name
206206
callee = f"{domain}.{opname}" if (domain != "") else opname
207207
return f"{lhs} = {callee} {attrs}{args}"
208208

@@ -212,7 +212,7 @@ def debug_print(self):
212212

213213
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
214214
n = helper.make_node(
215-
self.callee.opname,
215+
self.callee.name,
216216
[_opt_var_to_str(x) for x in self.args],
217217
[str(x) for x in self.result],
218218
domain=self.callee.opset.domain,

0 commit comments

Comments
 (0)