Skip to content

Commit 0361971

Browse files
xadupre, titaiwangms, and justinchuby
authored
Fix misleading annotation in the documentation (#2046)
The function does not seem to work in place in all cases.

---------

Co-authored-by: Ti-Tai Wang <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
1 parent ab2dabe commit 0361971

File tree

4 files changed

+25
-20
lines changed

4 files changed

+25
-20
lines changed

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@
8888
"python": (f"https://docs.python.org/{sys.version_info.major}", None),
8989
"matplotlib": ("https://matplotlib.org/stable/", None),
9090
"numpy": ("https://numpy.org/doc/stable/", None),
91+
"onnx": ("https://onnx.ai/onnx/", None),
9192
"onnxruntime": ("https://onnxruntime.ai/docs/api/python/", None),
93+
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
94+
"torch": ("https://pytorch.org/docs/main/", None),
9295
}
9396

9497
# -- Options for Sphinx Gallery ----------------------------------------------

onnxscript/optimizer/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5+
__all__ = [
6+
"fold_constants",
7+
"fold_constants_ir",
8+
"remove_unused_nodes",
9+
"optimize",
10+
"optimize_ir",
11+
"basic_constant_propagation",
12+
]
13+
514
import onnx
615

716
import onnxscript.optimizer._constant_folding as constant_folding
@@ -15,25 +24,17 @@
1524
fold_constants_ir = constant_folding.fold_constants
1625

1726

18-
def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
27+
def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
1928
if isinstance(model, ir.Model):
20-
return optimize_ir(model, *args, **kwargs)
29+
# In that case, this is done inplace.
30+
optimize_ir(model, *args, **kwargs)
31+
return model
2132
else:
2233
return legacy_optimizer.optimize(model, *args, **kwargs)
2334

2435

25-
def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs):
36+
def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
2637
if isinstance(model, ir.Model):
2738
return constant_folding.fold_constants(model, *args, **kwargs)
2839
else:
2940
return legacy_constant_folding.fold_constants(model, *args, **kwargs)
30-
31-
32-
__all__ = [
33-
"fold_constants",
34-
"fold_constants_ir",
35-
"remove_unused_nodes",
36-
"optimize",
37-
"optimize_ir",
38-
"basic_constant_propagation",
39-
]

onnxscript/tools/benchmark/benchmark_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,12 @@ def optimize_model_proto(
450450
begin = time.perf_counter()
451451

452452
if value == "optimize":
453-
model_proto = onnxscript.optimizer.optimize(
454-
model_proto,
453+
model_ir = onnxscript.optimizer.optimize(
454+
ir.from_proto(model_proto),
455455
num_iterations=2,
456456
onnx_shape_inference=False,
457457
)
458+
model_proto = ir.to_proto(model_ir)
458459

459460
elif value == "rewrite":
460461
model_proto = onnxscript.rewriter.rewrite(model_proto)

onnxscript/tools/transformers_models/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ def export_to_onnx(
4242
else:
4343
prog = torch.onnx.dynamo_export(model, *args)
4444
assert prog is not None
45-
model_proto = prog.model_proto
45+
model = prog.model
4646
if optimize:
47-
model_proto = onnxscript.optimizer.optimize(
48-
model_proto,
47+
model = onnxscript.optimizer.optimize(
48+
model,
4949
num_iterations=2,
50-
onnx_shape_inference=True,
5150
)
52-
model_proto = onnxscript.rewriter.rewrite(model_proto)
51+
model = onnxscript.rewriter.rewrite(model)
52+
model_proto = onnxscript.ir.to_proto(model)
5353
model_proto = onnx.inliner.inline_local_functions(model_proto)
5454
return model_proto
5555

0 commit comments

Comments (0)