Skip to content

Commit 9e5a16e

Browse files
committed
Auto OpSchema for trace_only functions | feat(torchlib)
ghstack-source-id: 3b97c3f Pull Request resolved: #674 Signed-off-by: Justin Chu <[email protected]>
1 parent a3454d6 commit 9e5a16e

File tree

7 files changed

+262
-108
lines changed

7 files changed

+262
-108
lines changed

onnxscript/_internal/ast_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Utilities for working with Python ASTs."""
2+
3+
import ast
4+
import inspect
5+
import textwrap
6+
import types
7+
8+
9+
def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
10+
try:
11+
src = inspect.getsource(f)
12+
except OSError as e:
13+
raise RuntimeError(
14+
f"Decorator script does not work on dynamically "
15+
f"compiled function {f.__name__}."
16+
) from e
17+
src = textwrap.dedent(src)
18+
top_level_ast = ast.parse(src)
19+
assert isinstance(top_level_ast, ast.Module)
20+
assert len(top_level_ast.body) == 1
21+
f_ast = top_level_ast.body[0]
22+
assert isinstance(f_ast, ast.FunctionDef)
23+
return src, f_ast

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def torch_op(
6767
registry: Optional[Registry] = None,
6868
trace_only: bool = False,
6969
private: bool = False,
70-
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
70+
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
7171
"""Register a torch op.
7272
7373
Args:
@@ -81,12 +81,16 @@ def torch_op(
8181
if registry is None:
8282
registry = default_registry
8383

84-
def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
84+
def wrapper(
85+
func: FunctionType,
86+
) -> onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction:
87+
# Compile the function
88+
custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1)
89+
90+
processed_func: onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction
8591
if trace_only:
86-
processed_func = func
92+
processed_func = onnxscript.values.TracedOnnxFunction(custom_opset, func)
8793
else:
88-
# Compile the function
89-
custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1)
9094
assert isinstance(func, FunctionType)
9195
processed_func = onnxscript.script(opset=custom_opset)(func)
9296

onnxscript/irbuilder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __str__(self):
202202

203203
args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
204204
domain = self.callee.opset.domain
205-
opname = self.callee.opname
205+
opname = self.callee.name
206206
callee = f"{domain}.{opname}" if (domain != "") else opname
207207
return f"{lhs} = {callee} {attrs}{args}"
208208

@@ -212,7 +212,7 @@ def debug_print(self):
212212

213213
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
214214
n = helper.make_node(
215-
self.callee.opname,
215+
self.callee.name,
216216
[_opt_var_to_str(x) for x in self.args],
217217
[str(x) for x in self.result],
218218
domain=self.callee.opset.domain,

onnxscript/main.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,14 @@
88
import ast
99
import inspect
1010
import sys
11-
import textwrap
1211
import types
1312
from typing import Any, Callable, Optional, Sequence, cast
1413

1514
import onnx.helper
1615

1716
import onnxscript
1817
from onnxscript import converter, irbuilder, values
19-
20-
21-
def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
22-
try:
23-
src = inspect.getsource(f)
24-
except OSError as e:
25-
raise RuntimeError(
26-
f"Decorator script does not work on dynamically "
27-
f"compiled function {f.__name__}."
28-
) from e
29-
src = textwrap.dedent(src)
30-
top_level_ast = ast.parse(src)
31-
assert isinstance(top_level_ast, ast.Module)
32-
assert len(top_level_ast.body) == 1
33-
f_ast = top_level_ast.body[0]
34-
assert isinstance(f_ast, ast.FunctionDef)
35-
return src, f_ast
36-
37-
38-
def get_ast(f: types.FunctionType) -> ast.FunctionDef:
39-
_, f_ast = get_src_and_ast(f)
40-
return f_ast
18+
from onnxscript._internal import ast_utils
4119

4220

4321
def script_check(
@@ -104,7 +82,7 @@ def transform(f: types.FunctionType) -> onnxscript.OnnxFunction:
10482
if not inspect.isfunction(f):
10583
raise TypeError("The ONNXScript decorator should be applied to functions only.")
10684

107-
src, f_ast = get_src_and_ast(f) # pylint: disable=redefined-outer-name
85+
src, f_ast = ast_utils.get_src_and_ast(f)
10886
# The script should be compiled using the globals/locals at the definition site.
10987
# This allows the script to reference names defined outside the script,
11088
# which is used for a few different purposes.

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,19 @@ def test_script_function_has_op_schema(self, _, func_with_wrangler):
119119
self.assertIsNotNone(schema)
120120
self.assertEqual(schema.name, func.name)
121121

122+
@parameterized.parameterized.expand(
123+
list(ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.items())
124+
)
125+
@unittest.skipIf(
126+
version_utils.onnx_older_than("1.15"),
127+
"OpSchema is not writable before ONNX 1.15",
128+
)
129+
def test_trace_only_function_has_op_schema(self, _, func_with_wrangler):
130+
func, _ = _split_function_and_wrangler(func_with_wrangler)
131+
schema = func.opschema
132+
self.assertIsNotNone(schema)
133+
self.assertEqual(schema.name, func.name)
134+
122135

123136
def run_test_output_match(
124137
test_suite: unittest.TestCase,

0 commit comments

Comments
 (0)