Commit 6388d95

Merge branch 'main' into justinchu/ser-initializer

2 parents: ae831e6 + 9551e98
55 files changed: +2236 additions, -445 deletions

.github/workflows/main.yaml

Lines changed: 1 addition & 9 deletions

```diff
@@ -31,10 +31,8 @@ jobs:
           - py311-onnx-weekly
           - py311-ort-nightly
           - py311-experimental-torchlib-tracing
-          - py311-experimental-torchlib-onnx-ir
           - py310
           - py39
-          - py38
         include:
           - name: py311
             python-version: "3.11"
@@ -45,9 +43,6 @@ jobs:
           - name: py39
             python-version: "3.9"
             nox-tag: test
-          - name: py38
-            python-version: "3.8"
-            nox-tag: test
           - name: py312-torch-nightly
             python-version: "3.12"
             nox-tag: test-torch-nightly
@@ -63,9 +58,6 @@ jobs:
          - name: py311-experimental-torchlib-tracing
            python-version: "3.11"
            nox-tag: test-experimental-torchlib-tracing
-          - name: py311-experimental-torchlib-onnx-ir
-            python-version: "3.11"
-            nox-tag: test-experimental-torchlib-onnx-ir
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
@@ -105,7 +97,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        transformers: ["4.37.2", "4.41.2"]
+        transformers: ["4.37.2", "4.41.2", "4.42.3"]
         torch: ["release", "nightly"]
         python_version: ["3.11"]
         nox-tag: ["test-dort"]
```

README.md

Lines changed: 89 additions & 0 deletions

````diff
@@ -15,6 +15,16 @@ models using a subset of Python. ONNX Script is:
 * **Debuggable:** allows for eager-mode evaluation that provides for a
   more delightful ONNX model debugging experience.
 
+This repo also covers:
+
+* **ONNX IR:** an in-memory IR that supports the full ONNX spec, designed
+  for graph construction, analysis and transformation.
+* **ONNX Script Optimizer:** provides functionality to optimize an ONNX
+  model by performing optimizations and clean-ups such as constant folding,
+  dead code elimination, etc.
+* **ONNX Rewriter:** provides functionality to replace certain patterns in
+  an ONNX graph with replacement patterns based on user-defined rewrite rules.
+
 Note however that ONNX Script does **not** intend to support the entirety
 of the Python language.
 
@@ -142,6 +152,85 @@ result = Hardmax(v)
 
 More examples can be found in the [docs/examples](docs/examples) directory.
 
+## ONNX IR
+
+An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
+
+### Features
+
+* **Full ONNX spec support:** all valid models representable by ONNX protobuf,
+  and a subset of invalid models (so you can load and fix them).
+* **Low memory footprint:** mmap'ed external tensors; unified interface for
+  ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size
+  limitation. Zero copies.
+* **Straightforward access patterns:** Access value information and traverse the
+  graph topology with ease.
+* **Robust mutation:** Create as many iterators as you like on the graph while mutating it.
+* **Speed:** Performant graph manipulation, serialization/deserialization to Protobuf.
+* **Pythonic and familiar APIs:** Classes define Pythonic APIs and still map to
+  ONNX protobuf concepts in an intuitive way.
+
+## ONNX Script Tools
+
+### ONNX Optimizer
+
+The ONNX Script Optimizer tool provides the user with the functionality to optimize an ONNX model by performing optimizations and clean-ups such as constant folding, dead code elimination, etc. To use the optimizer tool:
+
+```python
+import onnxscript
+
+onnxscript.optimizer.optimize(onnx_model)
+```
+
+For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://onnxscript.ai/tutorial/optimizer/optimize.html)
+
+### ONNX Rewriter
+
+The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on user-defined rewrite rules. The rewriter tool allows two different methods by which patterns in the graph can be rewritten.
+
+### Pattern-based rewriting
+
+For this style of rewriting, the user provides a `target_pattern` that is to be replaced, a `replacement_pattern` and a `match_condition` (pattern rewrite will occur only if the match condition is satisfied). A simple example of how to use the pattern-based rewriting tool is as follows:
+
+```python
+from onnxscript.rewriter import pattern
+
+# The target pattern
+def erf_gelu_pattern(op, x):
+    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
+
+def erf_gelu_pattern_2(op, x):
+    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
+
+# The replacement pattern
+def gelu(op, x: ir.Value):
+    return op.Gelu(x, domain="com.microsoft")
+
+# Create multiple rules
+rule1 = pattern.RewriteRule(
+    erf_gelu_pattern,  # Target Pattern
+    gelu,  # Replacement
+)
+rule2 = pattern.RewriteRule(
+    erf_gelu_pattern_2,  # Target Pattern
+    gelu,  # Replacement
+)
+# Create a Rewrite Rule Set with multiple rules.
+rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
+# Apply rewrites
+model_with_rewrite_applied = onnxscript.rewriter.rewrite(
+    model,  # Original ONNX Model
+    pattern_rewrite_rules=rewrite_rule_set,
+)
+```
+
+For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html)
+
+### Function-based rewriting
+
+This style of rewriting matches a `FUNCTION_KEYWORD` and `PACKAGE_NAME` provided by the user to an existing function within the graph and replaces it with a new function provided by the user.
+
 ## Development Guidelines
 
 Every change impacting the converter or the eager evaluation must be
````
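The function-based rewriting flow the README addition describes (match a function by package name and keyword, then substitute a replacement) can be illustrated with a toy, pure-Python sketch. Everything below is hypothetical for illustration; the real onnxscript rewriter API is different:

```python
# Toy illustration of function-based rewriting: functions in a "graph" are
# keyed by (package, name); any entry matching the package and keyword is
# swapped for a replacement body. All names here are hypothetical.

def rewrite_function(functions, package_name, function_keyword, replacement):
    """Replace every function whose key matches (package_name, function_keyword)."""
    rewritten = {}
    for (pkg, name), body in functions.items():
        if pkg == package_name and function_keyword in name:
            rewritten[(pkg, name)] = replacement
        else:
            rewritten[(pkg, name)] = body
    return rewritten

graph_functions = {
    ("transformers", "gelu_erf"): "erf-based gelu body",
    ("transformers", "layer_norm"): "layer norm body",
}
result = rewrite_function(graph_functions, "transformers", "gelu", "fused Gelu body")
```

Only the gelu entry is rewritten; unrelated functions pass through untouched, which is the essential property of this rewriting style.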

docs/test/test_documentation_examples.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -34,6 +34,9 @@ def do_test_folder(self, folder):
         if tested == 0:
             raise RuntimeError(f"No example was tested in folder {folder}.")
 
+    @unittest.skipIf(
+        sys.platform != "linux", reason="No need to run the documentation on every OS."
+    )
     def test_documentation_examples(self):
         this = os.path.abspath(os.path.dirname(__file__))
         onxc = os.path.normpath(os.path.join(this, "..", ".."))
```
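The guard added above uses `unittest.skipIf` to run the documentation tests only on Linux. A minimal self-contained sketch of the same pattern (class and method names here are invented):

```python
import sys
import unittest

class ExampleDocsTest(unittest.TestCase):
    # Mirrors the guard in the diff above: skip everywhere except Linux.
    @unittest.skipIf(sys.platform != "linux", "No need to run the documentation on every OS.")
    def test_linux_only(self):
        self.assertTrue(True)

# Run the case programmatically; on non-Linux platforms the test is
# recorded as skipped rather than failed.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleDocsTest)
result = unittest.TestResult()
suite.run(result)
```

Skipped tests still count toward `testsRun` and do not mark the run unsuccessful, so CI stays green on the skipped platforms.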

noxfile.py

Lines changed: 4 additions & 25 deletions

```diff
@@ -19,7 +19,7 @@
     'numpy==1.26.4; python_version>="3.9"',
     "packaging",
     "parameterized",
-    "psutil",
+    'psutil; sys_platform != "win32"',
     "pytest-cov",
     "pytest-randomly",
     "pytest-subtests",
@@ -28,13 +28,13 @@
     "pyyaml",
     "types-PyYAML",
     "typing_extensions",
-    "ml_dtypes",
+    "ml-dtypes",
 )
 ONNX = "onnx==1.16"
 ONNX_RUNTIME = "onnxruntime==1.17.1"
 PYTORCH = "torch==2.2.2"
 TORCHVISON = "torchvision==0.17.2"
-TRANSFORMERS = "transformers>=4.37.2"
+TRANSFORMERS = "transformers==4.37.2"
 ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
     "flatbuffers",
     "coloredlogs",
@@ -134,27 +134,6 @@ def test_experimental_torchlib_tracing(session):
     )
 
 
-@nox.session(tags=["test-experimental-torchlib-onnx-ir"])
-def test_experimental_torchlib_onnx_ir(session):
-    """Test TorchLib using the ONNX IR to build graphs."""
-    session.install(
-        *COMMON_TEST_DEPENDENCIES,
-        PYTORCH,
-        TORCHVISON,
-        ONNX,
-        *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
-    )
-    session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
-    session.install(".", "--no-deps")
-    session.run("pip", "list")
-    session.run(
-        "pytest",
-        "tests/function_libs/torch_lib/ops_test.py",
-        *session.posargs,
-        env={"TORCHLIB_EXPERIMENTAL_USE_IR": "1"},
-    )
-
-
 @nox.session(tags=["test-dort"])
 def test_dort(session):
     """Test the conversion of a couple of models from transformers."""
@@ -163,7 +142,7 @@ def test_dort(session):
     )
     torch_version, transformers_version = session.posargs
 
-    if torch_version == "nighly":
+    if torch_version == "nightly":
         session.install(
             "--pre",
             "torch",
```

onnxscript/_internal/ast_utils.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -8,18 +8,18 @@
 import inspect
 import sys
 import textwrap
-import types
+from typing import Callable
 
 PY_VERSION_GE_39 = sys.version_info >= (3, 9)
 
 
-def get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
+def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
     try:
-        src = inspect.getsource(f)
+        src = inspect.getsource(func)
     except OSError as e:
         raise RuntimeError(
             f"Decorator script does not work on dynamically "
-            f"compiled function {f.__name__}."
+            f"compiled function {func.__name__}."
         ) from e
     src = textwrap.dedent(src)
     top_level_ast = ast.parse(src)
```
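The change above widens `get_src_and_ast` to accept any `Callable` (its source is fetched with `inspect.getsource`, which fails with `OSError` for dynamically compiled functions). The dedent-and-parse step it performs can be sketched on a plain source string, minus the onnxscript specifics:

```python
import ast
import textwrap

def parse_function_source(src: str) -> ast.FunctionDef:
    """Dedent a source snippet and return the AST node of its first function,
    mirroring what get_src_and_ast does after inspect.getsource succeeds."""
    src = textwrap.dedent(src)
    top_level_ast = ast.parse(src)
    node = top_level_ast.body[0]
    assert isinstance(node, ast.FunctionDef)
    return node

# Indented source, as inspect.getsource would return for a nested function.
node = parse_function_source(
    """
    def sample(x):
        return x + 1
    """
)
```

`textwrap.dedent` is what makes this work for methods and nested functions, whose source lines carry a common leading indent that `ast.parse` would reject.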

onnxscript/_internal/runtime_typing.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -17,9 +17,11 @@
 T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any])
 
 try:
-    from beartype import beartype as checked
+    from beartype import beartype as _beartype_decorator
     from beartype import roar as _roar
 
+    checked = typing.cast(typing.Callable[[T], T], _beartype_decorator)
+
     # Beartype warns when we import from typing because the types are deprecated
     # in Python 3.9. But there will be a long time until we can move to using
     # the native container types for type annotations (when 3.9 is the lowest
```
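The change above wraps the beartype decorator in `typing.cast` so static type checkers see `@checked` as an identity transform on callables (preserving the decorated function's signature). A minimal sketch of the same optional-dependency pattern, assuming beartype may or may not be installed:

```python
import typing

T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any])

try:
    from beartype import beartype as _beartype_decorator

    # Cast so type checkers treat `checked` as Callable[[T], T]: the
    # decorated function keeps its original static type.
    checked = typing.cast(typing.Callable[[T], T], _beartype_decorator)
except ImportError:
    # No-op fallback when beartype is unavailable.
    def checked(func: T) -> T:
        return func

@checked
def add(a: int, b: int) -> int:
    return a + b
```

Either branch yields a decorator that leaves `add` callable as before, so the rest of the codebase does not need to care whether runtime type checking is active.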

onnxscript/_internal/version_utils.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -2,6 +2,11 @@
 # Licensed under the MIT License.
 """Version utils for testing."""
 
+from __future__ import annotations
+
+import warnings
+from typing import Callable, Sequence
+
 import packaging.version
 
 
@@ -25,6 +30,19 @@ def torch_older_than(version: str) -> bool:
     )
 
 
+def transformers_older_than(version: str) -> bool | None:
+    """Returns True if the transformers version is older than the given version."""
+    try:
+        import transformers  # pylint: disable=import-outside-toplevel
+    except ImportError:
+        return None
+
+    return (
+        packaging.version.parse(transformers.__version__).release
+        < packaging.version.parse(version).release
+    )
+
+
 def is_onnxruntime_training() -> bool:
     """Returns True if the onnxruntime is onnxruntime-training."""
     try:
@@ -74,3 +92,27 @@ def has_transformers():
         return True  # noqa
     except ImportError:
         return False
+
+
+def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable:  # type: ignore[arg-type]
+    """Catches warnings.
+
+    Args:
+        warns: warnings to ignore
+
+    Returns:
+        decorated function
+    """
+
+    def wrapper(fct):
+        if warns is None:
+            raise AssertionError(f"warns cannot be None for '{fct}'.")
+
+        def call_f(self):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", warns)  # type: ignore[arg-type]
+                return fct(self)
+
+        return call_f
+
+    return wrapper
```
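The `ignore_warnings` helper added above is a decorator factory for test methods: it wraps the method so the given warning categories are suppressed for the duration of the call. A self-contained sketch of how it is used (the decorator body below is copied from the diff; the test class is invented for illustration):

```python
import warnings

def ignore_warnings(warns):
    """Decorator factory: run a test method with the given warnings suppressed."""
    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")

        def call_f(self):
            # catch_warnings restores the previous filters on exit, so the
            # suppression is scoped to this one method call.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(self)

        return call_f
    return wrapper

class FakeTest:
    @ignore_warnings(DeprecationWarning)
    def test_noisy(self):
        warnings.warn("old API", DeprecationWarning)  # silenced by the decorator
        return "ok"

result = FakeTest().test_noisy()
```

Because `call_f` takes `self`, the helper is intended for methods on test classes rather than free functions.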

onnxscript/backend/onnx_export_test.py

Lines changed: 26 additions & 15 deletions

```diff
@@ -4,8 +4,10 @@
 
 import dataclasses
 import importlib
+import os
 import pathlib
 import re
+import sys
 import unittest
 from typing import Pattern
 
@@ -89,6 +91,17 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
     skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"),
 )
 
+if sys.platform == "win32":
+    SKIP_TESTS = (
+        *SKIP_TESTS,
+        skip(r"^test_gemm_beta", "cannot import module, import_module does not work"),
+        skip(
+            r"^test_averagepool_2d_default",
+            "cannot import module, import_module does not work",
+        ),
+        skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"),
+    )
+
 
 def load_function(obj):
     return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",))
@@ -106,16 +119,24 @@ def run_function(obj, *inputs):
 def extract_functions(name: str, content: str, test_folder: pathlib.Path):
     if not test_folder.exists():
         test_folder.mkdir(exist_ok=True, parents=True)
-        init = test_folder / "__init__.py"
-        init.touch(exist_ok=True)
-    file = test_folder / f"{name}.py"
-    file.write_text(content, encoding="utf-8")
+        init = str(test_folder / "__init__.py")
+        with open(init, "w", encoding="utf-8") as f:
+            f.write("\n")
+    filename = str(test_folder / f"{name}.py")
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write(content + "\n")
+    assert os.path.exists(
+        filename
+    ), f"{filename!r} ({os.path.abspath(filename)!r} does not exist."
     import_name = f"tests.{test_folder.parts[-1]}.{name}"
     try:
         mod = importlib.import_module(import_name)
     except (SyntaxError, ImportError) as e:
         raise AssertionError(
-            f"Unable to import {import_name!r} (file: {file!r})\n----\n{content}"
+            f"Unable to import {import_name!r} (e={e}) (file: {filename!r}, "
+            f"absolute path: {os.path.abspath(filename)!r}, "
+            f"current folder: {os.getcwd()}"
+            f"\n---- CONTENT --\n{content}"
        ) from e
     functions = {
         k: v for k, v in mod.__dict__.items() if isinstance(v, onnxscript.OnnxFunction)
@@ -265,16 +286,6 @@ def _load_function(_):
         return session
 
     def _run_function(obj, *inputs):
-        print(" run ONNX")
-        for i, inp in enumerate(inputs):
-            if inp is None:
-                print(f" input {i}: None")
-            else:
-                print(
-                    f" input {i}: "
-                    f"dtype={inp.dtype!r} shape={inp.shape!r}"
-                    f"{inp.ravel().tolist()!r}"
-                )
         try:
             return run_function(obj, *inputs)
         except Exception as e:
```
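`extract_functions` above writes a generated module (plus a package `__init__.py`) to disk and then imports it by dotted name with `importlib.import_module`. The core write-then-import pattern can be sketched with stdlib pieces only (all paths and names below are temporary stand-ins, not the onnxscript test layout):

```python
import importlib
import os
import sys
import tempfile

# Write a generated module into a fresh package directory, then import it
# by name, mirroring the flow in extract_functions.
tmp = tempfile.mkdtemp()
pkg = os.path.join(tmp, "genpkg")
os.makedirs(pkg, exist_ok=True)

# The package needs an __init__.py so import_module can resolve it.
with open(os.path.join(pkg, "__init__.py"), "w", encoding="utf-8") as f:
    f.write("\n")
with open(os.path.join(pkg, "generated.py"), "w", encoding="utf-8") as f:
    f.write("VALUE = 42\n")

sys.path.insert(0, tmp)  # make the parent of the package importable
mod = importlib.import_module("genpkg.generated")
```

The Windows-only skips added in the same diff suggest this import step can still fail on some platforms, which is why the commit also expands the error message with the absolute path and current working directory.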

onnxscript/function_libs/torch_lib/_flags.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -54,4 +54,5 @@ def _load_boolean_flag(
 EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
     "TORCHLIB_EXPERIMENTAL_USE_IR",
     this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
+    deprecated=True,
 )
```
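`_load_boolean_flag` reads an environment variable into a module-level boolean feature flag, and this commit marks the `TORCHLIB_EXPERIMENTAL_USE_IR` flag as deprecated. A hedged sketch of such a loader follows; the real onnxscript implementation differs, and `load_boolean_flag` here is an invented name:

```python
import os
import warnings

def load_boolean_flag(name: str, *, this_will: str, deprecated: bool = False) -> bool:
    """Read a boolean flag from the environment; warn if a deprecated flag is set."""
    state = os.getenv(name) == "1"
    if state and deprecated:
        warnings.warn(
            f"Experimental flag {name} is deprecated; it {this_will}.",
            DeprecationWarning,
            stacklevel=2,
        )
    return state

# Simulate a user enabling the (now deprecated) flag.
os.environ["TORCHLIB_EXPERIMENTAL_USE_IR"] = "1"
flag = load_boolean_flag(
    "TORCHLIB_EXPERIMENTAL_USE_IR",
    this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
    deprecated=True,
)
```

Warning only when the flag is both deprecated and actually set keeps the deprecation notice out of the logs of users who never opted in.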
