Skip to content

Commit a78bf43

Browse files
authored
Export version_converter and support model proto (#2251)
* Added `version_converter` to the list of public modules in `onnxscript/__init__.py`, allowing it to be used as onnxscript.version_converter. * Updated the `convert_version` function in `onnxscript/version_converter/__init__.py` to support both `ir.Model` and `onnx.ModelProto` as input types.
1 parent 9910215 commit a78bf43

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

onnxscript/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"ir",
88
"optimizer",
99
"rewriter",
10+
"version_converter",
1011
"export_onnx_lib",
1112
"OnnxFunction",
1213
"TracedOnnxFunction",
@@ -123,7 +124,7 @@
123124

124125
# isort: on
125126

126-
from . import ir, optimizer, rewriter
127+
from . import ir, optimizer, rewriter, version_converter
127128
from ._internal.utils import external_tensor
128129
from .values import OnnxFunction, TracedOnnxFunction
129130

onnxscript/version_converter/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
147147
return ir.passes.PassResult(model, True)
148148

149149

150-
def convert_version(model: ir.Model, target_version: int, fallback=False) -> None:
150+
def convert_version(
151+
model: ir.Model | onnx.ModelProto, target_version: int, fallback=None
152+
) -> None:
151153
"""Convert the model to the specified ONNX opset version.
152154
153155
Args:
@@ -156,4 +158,17 @@ def convert_version(model: ir.Model, target_version: int, fallback=False) -> Non
156158
fallback: Whether to fallback to the onnx version converter if the
157159
target version is not supported. Default is False.
158160
"""
161+
if isinstance(model, onnx.ModelProto):
162+
model_proto = model
163+
model = ir.from_proto(model)
164+
else:
165+
model_proto = None
166+
167+
assert isinstance(model, ir.Model)
159168
ConvertVersionPass(target_version=target_version, fallback=fallback)(model)
169+
170+
if model_proto is not None:
171+
# Update the model proto in-place
172+
model_proto.graph.Clear()
173+
del model_proto.functions
174+
model_proto.graph.CopyFrom(ir.to_proto(model.graph))

0 commit comments

Comments
 (0)