Skip to content

Commit 80e62a4

Browse files
committed
Auto OpSchema for trace_only functions | feat(torchlib)
ghstack-source-id: b25bb51 Pull Request resolved: #674 Signed-off-by: Justin Chu <[email protected]>
1 parent ceace79 commit 80e62a4

File tree

4 files changed

+122
-6
lines changed

4 files changed

+122
-6
lines changed

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Callable, Optional
77

88
import onnxscript
9+
from onnxscript.function_libs.torch_lib import tracing
910

1011

1112
class OverloadedFunction:
@@ -67,7 +68,7 @@ def torch_op(
6768
registry: Optional[Registry] = None,
6869
trace_only: bool = False,
6970
private: bool = False,
70-
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction | Callable[..., Any]]:
71+
) -> Callable[[FunctionType], onnxscript.OnnxFunction | tracing.TraceOnlyFunction]:
7172
"""Register a torch op.
7273
7374
Args:
@@ -81,12 +82,14 @@ def torch_op(
8182
if registry is None:
8283
registry = default_registry
8384

84-
def wrapper(func: Callable[..., Any]) -> onnxscript.OnnxFunction | Callable[..., Any]:
85+
def wrapper(func: FunctionType) -> onnxscript.OnnxFunction | tracing.TraceOnlyFunction:
86+
# Compile the function
87+
custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1)
88+
89+
processed_func: onnxscript.OnnxFunction | tracing.TraceOnlyFunction
8590
if trace_only:
86-
processed_func = func
91+
processed_func = tracing.TraceOnlyFunction(custom_opset, func)
8792
else:
88-
# Compile the function
89-
custom_opset = onnxscript.values.Opset(domain="onnxscript.atenlib", version=1)
9093
assert isinstance(func, FunctionType)
9194
processed_func = onnxscript.script(opset=custom_opset)(func)
9295

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Module for trace_only functions."""
2+
from __future__ import annotations
3+
4+
import ast
5+
import inspect
6+
import textwrap
7+
import types
8+
from typing import Optional
9+
10+
import onnx
11+
12+
import onnxscript
13+
from onnxscript import converter as ons_converter
14+
from onnxscript._internal import version_utils
15+
16+
_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")
17+
18+
19+
def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
20+
try:
21+
src = inspect.getsource(f)
22+
except OSError as e:
23+
raise RuntimeError(
24+
f"Decorator script does not work on dynamically "
25+
f"compiled function {f.__name__}."
26+
) from e
27+
src = textwrap.dedent(src)
28+
top_level_ast = ast.parse(src)
29+
assert isinstance(top_level_ast, ast.Module)
30+
assert len(top_level_ast.body) == 1
31+
f_ast = top_level_ast.body[0]
32+
assert isinstance(f_ast, ast.FunctionDef)
33+
return src, f_ast
34+
35+
36+
class TraceOnlyFunction:
37+
"""TraceOnlyFunction.
38+
39+
Attributes:
40+
name: Name of the op. E.g. "aten::add".
41+
func: Function.
42+
"""
43+
44+
def __init__(self, opset: onnxscript.values.Opset, func: types.FunctionType):
45+
self._opset = opset
46+
self._func = func
47+
self._opschema: Optional[onnx.defs.OpSchema] = None
48+
# Set the signature of the class to function's
49+
self.__signature__ = inspect.signature(func)
50+
51+
def __call__(self, *args, **kwargs):
52+
return self._func(*args, **kwargs)
53+
54+
def __repr__(self):
55+
return f"TraceOnlyFunction({self!r})"
56+
57+
@property
58+
def name(self) -> str:
59+
"""Return the name of the op."""
60+
return self._func.__name__
61+
62+
@property
63+
def source(self) -> str:
64+
"""Return the source of the op."""
65+
return inspect.getsource(self._func)
66+
67+
@property
68+
def opset(self) -> onnxscript.values.Opset:
69+
"""Return the opset."""
70+
return self._opset
71+
72+
@property
73+
def opschema(self) -> Optional[onnx.defs.OpSchema]:
74+
"""Return the opschema."""
75+
76+
if self._opschema is not None:
77+
return self._opschema
78+
79+
if not _ONNX_OP_SCHEMA_WRITABLE:
80+
return None
81+
82+
src, func_ast = _get_src_and_ast(self._func)
83+
module = inspect.getmodule(self._func)
84+
closure = inspect.getclosurevars(self._func)
85+
global_names = module.__dict__.copy()
86+
global_names.update(closure.nonlocals)
87+
converter = ons_converter.Converter(
88+
opset=self._opset,
89+
global_names=global_names,
90+
source=src,
91+
)
92+
93+
function_ir = converter.translate_function_signature(func_ast)
94+
95+
# FIXME(justinchuby): outputs are empty. Need to fix.
96+
self._opschema = onnxscript.values.op_schema_from_function_ir(function_ir, self._opset)
97+
98+
return self._opschema

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,19 @@ def test_script_function_has_op_schema(self, _, func_with_wrangler):
119119
self.assertIsNotNone(schema)
120120
self.assertEqual(schema.name, func.name)
121121

122+
@parameterized.parameterized.expand(
123+
list(ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.items())
124+
)
125+
@unittest.skipIf(
126+
version_utils.onnx_older_than("1.15"),
127+
"OpSchema is not writable before ONNX 1.15",
128+
)
129+
def test_trace_only_function_has_op_schema(self, _, func_with_wrangler):
130+
func, _ = _split_function_and_wrangler(func_with_wrangler)
131+
schema = func.opschema
132+
self.assertIsNotNone(schema)
133+
self.assertEqual(schema.name, func.name)
134+
122135

123136
def run_test_output_match(
124137
test_suite: unittest.TestCase,

onnxscript/values.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ class OnnxFunction(Op):
431431

432432
def __init__(
433433
self,
434-
opset: Opset,
434+
opset: Optional[Opset],
435435
pyfun: types.FunctionType,
436436
irfun: irbuilder.IRFunction,
437437
source: str,
@@ -455,6 +455,8 @@ def __init__(
455455
self.kwargs = kwargs
456456
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
457457
self._opschema: Optional[onnx.defs.OpSchema] = None
458+
# Set the signature of the class to function's
459+
self.__signature__ = inspect.signature(pyfun)
458460

459461
@property
460462
def opschema(self) -> Optional[onnx.defs.OpSchema]:

0 commit comments

Comments
 (0)