Skip to content

Commit 9a87303

Browse files
Define the EXPERIMENTAL_PREFER_TRACING flag and the traceable option | feat(torchlib) (#1176)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #1178 * #1177 * __->__ #1176 As an effort described in #1095, this PR - adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag to allow the tracer to trace a function when possible. - defined the `traceable` option in the torch_op decorator to mark a function as `traceable`. --------- Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent 9e74858 commit 9a87303

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

onnxscript/function_libs/torch_lib/_flags.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@ def _load_boolean_flag(
3939
"TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
4040
this_will="make initializers as inputs to the model graph",
4141
)
42+
EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
43+
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
44+
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
45+
)

onnxscript/function_libs/torch_lib/registration.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def torch_op(
9999
trace_only: bool = False,
100100
private: bool = False,
101101
complex: bool = False,
102+
traceable: bool = False,
102103
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
103104
"""Register a torch op.
104105
@@ -112,6 +113,18 @@ def torch_op(
112113
private: Whether the function is private (not directly exposed). It should
113114
be true for all functions with names starting with "_".
114115
complex: Whether the function expects complex-valued inputs.
116+
traceable: Whether the function can also be traced. This is an **experimental** flag.
117+
A function is traceable if it can both be scripted and traced to produce
118+
the same result for a given input. Specifically:
119+
120+
- A function _can_ be tagged with traceable if its if branches (if any)
121+
can be statically evaluated.
122+
- A function _should_ be tagged with traceable if it contains if branches
123+
and/or CastLike nodes so that they can be evaluated away with the
124+
EXPERIMENTAL_PREFER_TRACING on.
125+
- A function without if branches or CastLike nodes _should not_ be tagged
126+
with traceable because inlining will do the same thing.
127+
- A function with `@graph` defined for a `Scan` op is not traceable yet.
115128
"""
116129
if registry is None:
117130
registry = default_registry
@@ -128,6 +141,7 @@ def wrapper(
128141
else:
129142
assert isinstance(func, FunctionType)
130143
processed_func = onnxscript.script(opset=custom_opset)(func)
144+
processed_func.experimental_traceable = traceable
131145

132146
assert registry is not None
133147
for name_ in _check_and_normalize_names(name):

onnxscript/values.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ def __init__(
479479
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
480480
self._op_schema: Optional[onnx.defs.OpSchema] = None
481481

482+
# Experimental fields
483+
self.experimental_traceable = False
484+
482485
@property
483486
@deprecation.deprecated(
484487
since="0.1",

0 commit comments

Comments
 (0)