Skip to content

Commit bc7671c

Browse files
justinchubyCopilotshubhambhokare1
authored
[pass] Create version converter pass (#2214)
Use both the onnxscript version converter and optionally fall back to the onnx version converter if the target version is unsupported. Created `version_supported` helper function for users to check if a target version is supported by the onnxscript version converter. Use the converter in pytorch apis. --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Shubham Bhokare <[email protected]>
1 parent feb20f1 commit bc7671c

File tree

5 files changed

+194
-14
lines changed

5 files changed

+194
-14
lines changed

onnxscript/_framework_apis/torch_2_6.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"save_model_with_external_data",
1313
"torchlib_opset",
1414
]
15+
import logging
1516
from typing import TYPE_CHECKING
1617

1718
from onnxscript import ir, optimizer, version_converter
@@ -25,6 +26,9 @@
2526
from onnxscript.onnx_opset._impl.opset18 import Opset18
2627

2728

29+
logger = logging.getLogger(__name__)
30+
31+
2832
def optimize(model: ir.Model) -> ir.Model:
2933
"""Optimize the model."""
3034
optimizer.optimize_ir(model)
@@ -34,8 +38,9 @@ def optimize(model: ir.Model) -> ir.Model:
3438
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
3539
"""Convert the model to the specified ONNX opset version."""
3640
if target_version < 18:
41+
logger.warning("Conversion to opset < 18 is not supported.")
3742
return model
38-
version_converter.convert_version(model, target_version)
43+
version_converter.convert_version(model, target_version, fallback=True)
3944
return model
4045

4146

onnxscript/version_converter/__init__.py

Lines changed: 146 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,157 @@
33
from __future__ import annotations
44

55
__all__ = [
6-
# Functions
6+
"ConvertVersionPass",
77
"convert_version",
88
]
99

10-
import onnxscript.optimizer
10+
import logging
11+
12+
import onnx
13+
1114
from onnxscript import ir
15+
from onnxscript.ir.passes.common import _c_api_utils
16+
from onnxscript.ir.passes.common import inliner as _inliner
17+
from onnxscript.ir.passes.common import unused_removal as _unused_removal
1218
from onnxscript.version_converter import _version_converter
1319

20+
logger = logging.getLogger(__name__)
21+
22+
23+
class ConvertVersionPass(ir.passes.InPlacePass):
24+
"""Convert the model to the specified ONNX opset version.
25+
26+
This pass leverages the onnxscript version converter to convert the model. If
27+
the conversion is not supported, it falls back to the onnx C API to convert
28+
the model. This pass is in-place.
29+
30+
The pass is an no-op if the c-api fails.
31+
32+
Attributes:
33+
target_version: The target ONNX opset version to convert the model to.
34+
fallback: Whether to fallback to the onnx version converter if the
35+
target version is not supported. Default is False.
36+
"""
37+
38+
def __init__(self, target_version: int, fallback: bool = False) -> None:
39+
super().__init__()
40+
self.target_version = target_version
41+
self.fallback = fallback
42+
self.convert_pass = ir.passes.Sequential(
43+
_inliner.InlinePass(),
44+
_ConvertVersionPassRequiresInline(
45+
target_version=target_version,
46+
fallback=fallback,
47+
),
48+
_unused_removal.RemoveUnusedNodesPass(),
49+
_unused_removal.RemoveUnusedFunctionsPass(),
50+
_unused_removal.RemoveUnusedOpsetsPass(),
51+
)
52+
53+
def call(self, model: ir.Model) -> ir.passes.PassResult:
54+
return self.convert_pass(model)
55+
56+
57+
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
58+
"""Convert the model to the specified ONNX opset version.
59+
60+
This pass leverages the onnxscript version converter to convert the model. If
61+
the conversion is not supported, it falls back to the onnx C API to convert
62+
the model. This pass is in-place.
63+
64+
The pass is an no-op if the c-api fails.
65+
66+
Attributes:
67+
target_version: The target ONNX opset version to convert the model to.
68+
fallback: Whether to fallback to the onnx version converter if the
69+
target version is not supported.
70+
"""
71+
72+
def __init__(self, target_version: int, fallback: bool) -> None:
73+
super().__init__()
74+
self.target_version = target_version
75+
self.fallback = fallback
76+
77+
def call(self, model: ir.Model) -> ir.passes.PassResult:
78+
if model.functions:
79+
raise ValueError(
80+
"The model contains functions. The version conversion pass does not support "
81+
"functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the "
82+
f"functions before applying this pass ({self.__class__.__name__})."
83+
)
84+
if "" in model.graph.opset_imports:
85+
onnx_opset_version = model.graph.opset_imports[""]
86+
if onnx_opset_version == self.target_version:
87+
# No need to convert the version
88+
return ir.passes.PassResult(model, False)
89+
90+
# When fallback is disabled, always use the onnxscript version converter;
91+
# When fallback is enabled, use the onnxscript version converter
92+
# if the target version is supported. Otherwise, use the onnx C API
93+
# to convert the model.
94+
if not self.fallback or _version_converter.version_supported(
95+
model, self.target_version
96+
):
97+
_version_converter.convert_version(
98+
model,
99+
target_version=self.target_version,
100+
)
101+
return ir.passes.PassResult(model, True)
102+
103+
if not self.fallback:
104+
logger.warning(
105+
"The model version conversion is not supported by the onnxscript version converter "
106+
"and fallback is disabled. The model was not modified"
107+
" (target version: %d). "
108+
"Set fallback=True to enable fallback to the onnx c-api version converter.",
109+
self.target_version,
110+
)
111+
return ir.passes.PassResult(model, False)
112+
113+
# If the onnxscript version converter does not support the conversion,
114+
# we can use the onnx C API to convert the model
115+
def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
116+
"""Partial function to check the model."""
117+
return onnx.version_converter.convert_version(
118+
proto, target_version=self.target_version
119+
)
120+
121+
try:
122+
converted_proto = _c_api_utils.call_onnx_api(
123+
func=_partial_convert_version, model=model
124+
)
125+
except Exception as e: # pylint: disable=broad-exception-caught
126+
logger.warning(
127+
"Failed to convert the model to the target version %d using the ONNX C API. "
128+
"The model was not modified",
129+
self.target_version,
130+
exc_info=e,
131+
)
132+
return ir.passes.PassResult(model, False)
133+
134+
converted_model = ir.from_proto(converted_proto)
135+
136+
# Recover the initializers in the converted model
137+
for input in converted_model.graph.inputs:
138+
if input.name in model.graph.initializers:
139+
input.const_value = model.graph.initializers[input.name].const_value
140+
converted_model.graph.register_initializer(input)
141+
user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
142+
converted_model.graph.inputs.clear()
143+
converted_model.graph.inputs.extend(user_inputs)
144+
145+
# Return the converted graph to the original model to keep the pass in-place
146+
model.graph = converted_model.graph
147+
return ir.passes.PassResult(model, True)
148+
14149

15-
def convert_version(model: ir.Model, target_version: int) -> None:
16-
"""Convert the model to the specified ONNX opset version."""
150+
def convert_version(model: ir.Model, target_version: int, fallback=False) -> None:
151+
"""Convert the model to the specified ONNX opset version.
17152
18-
# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
19-
# Hence, we inline all the functions.
20-
onnxscript.optimizer.inline(model)
21-
_version_converter.convert_version(model, target_version)
153+
Args:
154+
model: The model to convert.
155+
target_version: The target ONNX opset version.
156+
fallback: Whether to fallback to the onnx version converter if the
157+
target version is not supported. Default is False.
158+
"""
159+
ConvertVersionPass(target_version=target_version, fallback=fallback)(model)

onnxscript/version_converter/_version_converter.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
logger = logging.getLogger(__name__)
1717

1818

19-
CURRENT_MAX_ONNX_OPSET = 23
19+
SUPPORTED_MAX_ONNX_OPSET = 23
20+
SUPPORTED_MIN_ONNX_OPSET = 18
2021

2122

2223
class VersionConverterError(RuntimeError):
@@ -38,6 +39,20 @@ class Replacement:
3839
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]
3940

4041

42+
def version_supported(model: ir.Model, target_version: int) -> bool:
43+
"""Check if the target version is supported by the current version."""
44+
if "" in model.graph.opset_imports:
45+
current_version = model.graph.opset_imports[""]
46+
else:
47+
return True
48+
return (
49+
SUPPORTED_MIN_ONNX_OPSET
50+
<= current_version
51+
<= target_version
52+
<= SUPPORTED_MAX_ONNX_OPSET
53+
)
54+
55+
4156
class AdapterRegistry:
4257
"""A class that maintains a registry of adapters for ops."""
4358

@@ -262,7 +277,7 @@ def visit_node(
262277
return None
263278

264279
def visit_graph(self, graph: ir.Graph) -> None:
265-
if self.target_version > CURRENT_MAX_ONNX_OPSET:
280+
if self.target_version > SUPPORTED_MAX_ONNX_OPSET:
266281
logger.warning(
267282
"Conversion to target opset: %s not currently supported.",
268283
self.target_version,

onnxscript/version_converter/_version_converter_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44

55
import unittest
66

7-
import onnx.checker
87
import onnx.defs
98
import onnx.parser
10-
import onnx.shape_inference
119

1210
from onnxscript import ir, version_converter
1311

1412

15-
class ApapterCoverageTest(unittest.TestCase):
13+
class AdapterCoverageTest(unittest.TestCase):
1614
def get_all_unique_schema_versions(self) -> dict[str, list]:
1715
"""Collect all unique versions of ONNX standard domain ops"""
1816
op_version_dict = {}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import pathlib
6+
import unittest
7+
8+
from onnxscript import ir, version_converter
9+
10+
model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata"
11+
12+
13+
class ModelTest(unittest.TestCase):
14+
def test_model_runs_and_matches_accuracy_after_conversion_fallback_true(self):
15+
model_path = model_folder_path / "e2e_models/torchscript_model/torchscript_model.onnx"
16+
model = ir.load(model_path)
17+
18+
# Down convert the model with the onnx version converter
19+
version_converter.convert_version(model, target_version=16, fallback=True)
20+
self.assertEqual(model.opset_imports[""], 16)
21+
22+
23+
if __name__ == "__main__":
24+
unittest.main()

0 commit comments

Comments
 (0)