Skip to content

Define the EXPERIMENTAL_PREFER_TRACING flag and the traceable option | feat(torchlib) #1176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
75551b5
Define the EXPERIMENTAL_PREFER_TRACING flag
justinchuby Nov 22, 2023
9af95b9
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 22, 2023
f2b66c8
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 22, 2023
6f5e1a0
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 22, 2023
3c318ad
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 22, 2023
9feb804
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 23, 2023
567184c
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 23, 2023
ea8796f
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 23, 2023
698587c
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 23, 2023
34277ef
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 23, 2023
78d642a
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 23, 2023
7f2210f
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 23, 2023
688f677
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 23, 2023
eeb1ff7
Update base for Update on "Define the EXPERIMENTAL_PREFER_TRACING fla…
justinchuby Nov 27, 2023
b48a098
Update on "Define the EXPERIMENTAL_PREFER_TRACING flag and the tracea…
justinchuby Nov 27, 2023
567560c
Merge branch 'main' into gh/justinchuby/44/head
titaiwangms Nov 28, 2023
a28f765
Merge branch 'main' into gh/justinchuby/44/head
justinchuby Nov 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ def _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
this_will="make initializers as inputs to the model graph",
)
EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
)
14 changes: 14 additions & 0 deletions onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def torch_op(
trace_only: bool = False,
private: bool = False,
complex: bool = False,
traceable: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
"""Register a torch op.

Expand All @@ -112,6 +113,18 @@ def torch_op(
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function expects complex-valued inputs.
traceable: Whether the function can also be traced. This is an **experimental** flag.
A function is traceable if it can both be scripted and traced to produce
the same result for a given input. Specifically:

- A function _can_ be tagged with traceable if its if branches (if any)
can be statically evaluated.
- A function _should_ be tagged with traceable if it contains if branches
and/or CastLike nodes so that they can be evaluated away with the
EXPERIMENTAL_PREFER_TRACING on.
- A function without if branches or CastLike nodes _should not_ be tagged
with traceable because inlining will do the same thing.
- A function with `@graph` defined for a `Scan` op is not traceable yet.
"""
if registry is None:
registry = default_registry
Expand All @@ -128,6 +141,7 @@ def wrapper(
else:
assert isinstance(func, FunctionType)
processed_func = onnxscript.script(opset=custom_opset)(func)
processed_func.experimental_traceable = traceable

assert registry is not None
for name_ in _check_and_normalize_names(name):
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __init__(
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
self._op_schema: Optional[onnx.defs.OpSchema] = None

# Experimental fields
self.experimental_traceable = False

@property
@deprecation.deprecated(
since="0.1",
Expand Down