Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
75551b5
Define the EXPERIMENTAL_PREFER_TRACING flag
justinchuby Nov 22, 2023
b513a23
Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACING on | tes…
justinchuby Nov 22, 2023
581e0b4
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
ce12b39
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
e47392f
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 22, 2023
3bec3f8
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
1d83826
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 22, 2023
0e1b0e6
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 22, 2023
ae2388e
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
ffae748
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
599b841
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
c8692c3
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
3c2afff
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
8c1525e
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
ad0245f
Update base for Update on "Create env to test with TORCHLIB_EXPERIMEN…
justinchuby Nov 23, 2023
55bdb27
Update on "Create env to test with TORCHLIB_EXPERIMENTAL_PREFER_TRACI…
justinchuby Nov 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
- py310-torch-nightly
- py310-onnx-weekly
- py310-ort-nightly
- py311-ort-nightly
- py310-experimental-torchlib-tracing
include:
- name: py310
python-version: "3.10"
Expand All @@ -50,6 +52,12 @@ jobs:
- name: py310-ort-nightly
python-version: "3.10"
nox-tag: test-ort-nightly
- name: py311-ort-nightly
python-version: "3.11"
nox-tag: test-ort-nightly
- name: py310-experimental-torchlib-tracing
python-version: "3.10"
nox-tag: test-experimental-torchlib-tracing
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
Expand Down
19 changes: 18 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

COMMON_TEST_DEPENDENCIES = (
"jinja2",
"numpy==1.23.5",
"numpy==1.24.4",
"typing_extensions",
"beartype!=0.16.0",
"types-PyYAML",
Expand Down Expand Up @@ -95,3 +95,20 @@ def test_ort_nightly(session):
session.install(".", "--no-deps")
session.run("pip", "list")
session.run("pytest", "onnxscript", *session.posargs)


@nox.session(tags=["test-experimental-torchlib-tracing"])
def test_experimental_torchlib_tracing(session):
"""Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
session.run(
"pytest",
"onnxscript/tests/function_libs/torch_lib/ops_test.py",
*session.posargs,
env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"},
)
4 changes: 4 additions & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ def _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
this_will="make initializers as inputs to the model graph",
)
EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_PREFER_TRACING",
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
)
3 changes: 3 additions & 0 deletions onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def torch_op(
trace_only: bool = False,
private: bool = False,
complex: bool = False,
traceable: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
"""Register a torch op.

Expand All @@ -112,6 +113,7 @@ def torch_op(
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function expects complex-valued inputs.
traceable: Whether the function can be traced.
"""
if registry is None:
registry = default_registry
Expand All @@ -128,6 +130,7 @@ def wrapper(
else:
assert isinstance(func, FunctionType)
processed_func = onnxscript.script(opset=custom_opset)(func)
processed_func.experimental_traceable = traceable

assert registry is not None
for name_ in _check_and_normalize_names(name):
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __init__(
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
self._op_schema: Optional[onnx.defs.OpSchema] = None

# Experimental fields
self.experimental_traceable = False

@property
@deprecation.deprecated(
since="0.1",
Expand Down