Skip to content

Commit ceace79

Browse files
committed
Create TracedOnnxFunction
1 parent a3454d6 commit ceace79

File tree

4 files changed

+229
-99
lines changed

4 files changed

+229
-99
lines changed

onnxscript/_internal/ast_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Utilities for working with Python ASTs."""
2+
3+
import ast
4+
import inspect
5+
import textwrap
6+
import types
7+
8+
9+
def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
10+
try:
11+
src = inspect.getsource(f)
12+
except OSError as e:
13+
raise RuntimeError(
14+
f"Decorator script does not work on dynamically "
15+
f"compiled function {f.__name__}."
16+
) from e
17+
src = textwrap.dedent(src)
18+
top_level_ast = ast.parse(src)
19+
assert isinstance(top_level_ast, ast.Module)
20+
assert len(top_level_ast.body) == 1
21+
f_ast = top_level_ast.body[0]
22+
assert isinstance(f_ast, ast.FunctionDef)
23+
return src, f_ast

onnxscript/main.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,14 @@
88
import ast
99
import inspect
1010
import sys
11-
import textwrap
1211
import types
1312
from typing import Any, Callable, Optional, Sequence, cast
1413

1514
import onnx.helper
1615

1716
import onnxscript
1817
from onnxscript import converter, irbuilder, values
19-
20-
21-
def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
22-
try:
23-
src = inspect.getsource(f)
24-
except OSError as e:
25-
raise RuntimeError(
26-
f"Decorator script does not work on dynamically "
27-
f"compiled function {f.__name__}."
28-
) from e
29-
src = textwrap.dedent(src)
30-
top_level_ast = ast.parse(src)
31-
assert isinstance(top_level_ast, ast.Module)
32-
assert len(top_level_ast.body) == 1
33-
f_ast = top_level_ast.body[0]
34-
assert isinstance(f_ast, ast.FunctionDef)
35-
return src, f_ast
36-
37-
38-
def get_ast(f: types.FunctionType) -> ast.FunctionDef:
39-
_, f_ast = get_src_and_ast(f)
40-
return f_ast
18+
from onnxscript._internal import ast_utils
4119

4220

4321
def script_check(
@@ -104,7 +82,7 @@ def transform(f: types.FunctionType) -> onnxscript.OnnxFunction:
10482
if not inspect.isfunction(f):
10583
raise TypeError("The ONNXScript decorator should be applied to functions only.")
10684

107-
src, f_ast = get_src_and_ast(f) # pylint: disable=redefined-outer-name
85+
src, f_ast = ast_utils.get_src_and_ast(f) # pylint: disable=redefined-outer-name
10886
# The script should be compiled using the globals/locals at the definition site.
10987
# This allows the script to reference names defined outside the script,
11088
# which is used for a few different purposes.

0 commit comments

Comments
 (0)