From 067b36171028a41a65b09674eb1a679044d03aff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 18:56:56 -0700 Subject: [PATCH 01/18] [pass] Create version converter pass --- .../ir/passes/common/version_converter.py | 111 ++++++++++++++++++ onnxscript/version_converter/__init__.py | 16 ++- .../version_converter/_version_converter.py | 15 ++- 3 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 onnxscript/ir/passes/common/version_converter.py diff --git a/onnxscript/ir/passes/common/version_converter.py b/onnxscript/ir/passes/common/version_converter.py new file mode 100644 index 0000000000..3949017914 --- /dev/null +++ b/onnxscript/ir/passes/common/version_converter.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Version conversion passes.""" + +from __future__ import annotations + +__all__ = [ + "ConvertVersionPass", +] + +import logging + +import onnx + +from onnxscript import ir +from onnxscript.ir.passes.common import _c_api_utils +from onnxscript.ir.passes.common import inliner as _inliner +from onnxscript.version_converter import _version_converter + +logger = logging.getLogger(__name__) + + +class ConvertVersionPass(ir.passes.InPlacePass): + """Convert the model to the specified ONNX opset version. + + This pass leverages the onnxscript version converter to convert the model. If + the conversion is not supported, it falls back to the onnx C API to convert + the model. This pass is in-place. + + The pass is an no-op if the c-api fails. + + Attributes: + target_version: The target ONNX opset version to convert the model to. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is True. + """ + + def __init__(self, target_version: int, fallback: bool = False) -> None: + super().__init__() + self.target_version = target_version + self.fallback = fallback + self.inliner = _inliner.InlinePass() + + def call(self, model: ir.Model) -> ir.passes.PassResult: + # Normalize the opset import + if "ai.onnx" in model.graph.opset_imports: + model.graph.opset_imports[""] = model.graph.opset_imports["ai.onnx"] + del model.graph.opset_imports["ai.onnx"] + + model_opset_version = model.graph.opset_imports[""] + if model_opset_version == self.target_version: + # No need to convert the version + return ir.passes.PassResult(model, False) + + # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. + # Hence, we inline all the functions. + self.inliner(model) + + if _version_converter.version_supported(model_opset_version, self.target_version): + _version_converter.convert_version( + model, + target_version=self.target_version, + ) + return ir.passes.PassResult(model, True) + + if not self.fallback: + logger.info( + "The model version conversion is not supported by the onnxscript version converter " + "and fallback is disabled. The model was not modified" + " (current version: %d, target version: %d). " + "Set fallback=True to enable fallback to the onnx c-api version converter.", + model_opset_version, + self.target_version, + ) + return ir.passes.PassResult(model, False) + + # If the onnxscript version converter does not support the conversion, + # we can use the onnx C API to convert the model + def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: + """Partial function to check the model.""" + return onnx.version_converter.convert_version( + proto, target_version=self.target_version + ) + + try: + converted_model = _c_api_utils.call_onnx_api( + func=_partial_convert_version, model=model + ) + except Exception as e: + logger.warning( + "Failed to convert the model to the target version %d using the ONNX C API. " + "The model was not modified", + self.target_version, + exc_info=e, + ) + return ir.passes.PassResult(model, False) + + converted_model = ir.from_proto(converted_model) + + # Recover the initializers in the converted model + for input in converted_model.graph.inputs: + if input.name in model.graph.initializers: + input.const_value = model.graph.initializers[input.name].const_value + converted_model.graph.register_initializer(input) + user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)] + converted_model.graph.inputs.clear() + converted_model.graph.inputs.extend(user_inputs) + + # Return the converted graph to the original model to keep the pass in-place + model.graph = converted_model.graph + return ir.passes.PassResult(model, True) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 20b7d9c24b..a28926079b 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -7,15 +7,19 @@ "convert_version", ] -import onnxscript.optimizer from onnxscript import ir -from onnxscript.version_converter import _version_converter +from onnxscript.ir.passes.common import version_converter as _version_converter_pass -def convert_version(model: ir.Model, target_version: int) -> None: - """Convert the model to the specified ONNX opset version.""" +def convert_version(model: ir.Model, target_version: int, fallback=False) -> None: + """Convert the model to the specified ONNX opset version. + Args: + model: The model to convert. + target_version: The target ONNX opset version. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is True. + """ # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. # Hence, we inline all the functions. - onnxscript.optimizer.inline(model) - _version_converter.convert_version(model, target_version) + _version_converter_pass.ConvertVersionPass(target_version=target_version, fallback=fallback)(model) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 28a590bb27..293b119ad6 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -16,7 +16,8 @@ logger = logging.getLogger(__name__) -CURRENT_MAX_ONNX_OPSET = 23 +SUPPORTED_MAX_ONNX_OPSET = 23 +SUPPORTED_MIN_ONNX_OPSET = 18 class VersionConverterError(RuntimeError): @@ -38,6 +39,16 @@ class Replacement: AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue] +def version_supported(current_version: int, target_version: int) -> bool: + """Check if the target version is supported by the current version.""" + return ( + SUPPORTED_MIN_ONNX_OPSET + <= current_version + <= target_version + <= SUPPORTED_MIN_ONNX_OPSET + ) + + class AdapterRegistry: """A class that maintains a registry of adapters for ops.""" @@ -262,7 +273,7 @@ def visit_node( return None def visit_graph(self, graph: ir.Graph) -> None: - if self.target_version > CURRENT_MAX_ONNX_OPSET: + if self.target_version > SUPPORTED_MAX_ONNX_OPSET: logger.warning( "Conversion to target opset: %s not currently supported.", self.target_version, From e62f78787b0c4db1247bde689a5ca261a62e6a36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 18:57:08 -0700 Subject: [PATCH 02/18] format --- onnxscript/version_converter/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index a28926079b..8a190fb0d8 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -22,4 +22,6 @@ def convert_version(model: ir.Model, target_version: int, fallback=False) -> Non """ # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. # Hence, we inline all the functions. - _version_converter_pass.ConvertVersionPass(target_version=target_version, fallback=fallback)(model) + _version_converter_pass.ConvertVersionPass( + target_version=target_version, fallback=fallback + )(model) From 0a297753b17dee4bd2b7258140c75b816abfb77c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 18:59:30 -0700 Subject: [PATCH 03/18] Update onnxscript/version_converter/_version_converter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/version_converter/_version_converter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 293b119ad6..95c3c0b533 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -45,9 +45,7 @@ def version_supported(current_version: int, target_version: int) -> bool: SUPPORTED_MIN_ONNX_OPSET <= current_version <= target_version - <= SUPPORTED_MIN_ONNX_OPSET - ) - + <= SUPPORTED_MAX_ONNX_OPSET class AdapterRegistry: """A class that maintains a registry of adapters for ops.""" From f3c163840d8d647c597525225f475cc9cc86eaae Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 18:59:56 -0700 Subject: [PATCH 04/18] doc --- onnxscript/version_converter/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 8a190fb0d8..544fb91ace 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -18,7 +18,7 @@ def convert_version(model: ir.Model, target_version: int, fallback=False) -> Non model: The model to convert. target_version: The target ONNX opset version. fallback: Whether to fallback to the onnx version converter if the - target version is not supported. Default is True. + target version is not supported. """ # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. # Hence, we inline all the functions. From 9d6b1c8d49afdd40c013553f057920e44d00c463 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 19:01:16 -0700 Subject: [PATCH 05/18] Update onnxscript/version_converter/_version_converter.py --- onnxscript/version_converter/_version_converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 95c3c0b533..96276a0661 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -46,6 +46,7 @@ def version_supported(current_version: int, target_version: int) -> bool: <= current_version <= target_version <= SUPPORTED_MAX_ONNX_OPSET + ) class AdapterRegistry: """A class that maintains a registry of adapters for ops.""" From 48bfc6502facabcbbc5988cdab838cecb73eecde Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 19:11:11 -0700 Subject: [PATCH 06/18] Fix case --- .../ir/passes/common/version_converter.py | 19 +++++++------------ .../version_converter/_version_converter.py | 7 ++++++- .../_version_converter_test.py | 4 +--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/onnxscript/ir/passes/common/version_converter.py b/onnxscript/ir/passes/common/version_converter.py index 3949017914..46c030c14a 100644 --- a/onnxscript/ir/passes/common/version_converter.py +++ b/onnxscript/ir/passes/common/version_converter.py @@ -42,21 +42,17 @@ def __init__(self, target_version: int, fallback: bool = False) -> None: self.inliner = _inliner.InlinePass() def call(self, model: ir.Model) -> ir.passes.PassResult: - # Normalize the opset import - if "ai.onnx" in model.graph.opset_imports: - model.graph.opset_imports[""] = model.graph.opset_imports["ai.onnx"] - del model.graph.opset_imports["ai.onnx"] - - model_opset_version = model.graph.opset_imports[""] - if model_opset_version == self.target_version: - # No need to convert the version - return ir.passes.PassResult(model, False) + if "" in model.graph.opset_imports: + onnx_opset_version = model.graph.opset_imports[""] + if onnx_opset_version == self.target_version: + # No need to convert the version + return ir.passes.PassResult(model, False) # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. # Hence, we inline all the functions. self.inliner(model) - if _version_converter.version_supported(model_opset_version, self.target_version): + if _version_converter.version_supported(model, self.target_version): _version_converter.convert_version( model, target_version=self.target_version, @@ -67,9 +63,8 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: logger.info( "The model version conversion is not supported by the onnxscript version converter " "and fallback is disabled. The model was not modified" - " (current version: %d, target version: %d). " + " (target version: %d). " "Set fallback=True to enable fallback to the onnx c-api version converter.", - model_opset_version, self.target_version, ) return ir.passes.PassResult(model, False) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 96276a0661..46b4596fb5 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -39,8 +39,12 @@ class Replacement: AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue] -def version_supported(current_version: int, target_version: int) -> bool: +def version_supported(model: ir.Model, target_version: int) -> bool: """Check if the target version is supported by the current version.""" + if "" in model.graph.opset_imports: + current_version = model.graph.opset_imports[""] + else: + return True return ( SUPPORTED_MIN_ONNX_OPSET <= current_version @@ -48,6 +52,7 @@ def version_supported(current_version: int, target_version: int) -> bool: <= SUPPORTED_MAX_ONNX_OPSET ) + class AdapterRegistry: """A class that maintains a registry of adapters for ops.""" diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 472ffe2e50..3c73498230 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -4,15 +4,13 @@ import unittest -import onnx.checker import onnx.defs import onnx.parser -import onnx.shape_inference from onnxscript import ir, version_converter -class ApapterCoverageTest(unittest.TestCase): +class AdapterCoverageTest(unittest.TestCase): def get_all_unique_schema_versions(self) -> dict[str, list]: """Collect all unique versions of ONNX standard domain ops""" op_version_dict = {} From ec661aaad71c5d13818440f4f5e37733d355c5d8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 19:11:42 -0700 Subject: [PATCH 07/18] converted_proto --- onnxscript/ir/passes/common/version_converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/version_converter.py b/onnxscript/ir/passes/common/version_converter.py index 46c030c14a..6da5801ff5 100644 --- a/onnxscript/ir/passes/common/version_converter.py +++ b/onnxscript/ir/passes/common/version_converter.py @@ -78,7 +78,7 @@ def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: ) try: - converted_model = _c_api_utils.call_onnx_api( + converted_proto = _c_api_utils.call_onnx_api( func=_partial_convert_version, model=model ) except Exception as e: @@ -90,7 +90,7 @@ def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: ) return ir.passes.PassResult(model, False) - converted_model = ir.from_proto(converted_model) + converted_model = ir.from_proto(converted_proto) # Recover the initializers in the converted model for input in converted_model.graph.inputs: From a57fc8681d35f9a78f52eb794cbcdc8f37afd50d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Apr 2025 19:12:09 -0700 Subject: [PATCH 08/18] lint --- onnxscript/ir/passes/common/version_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/version_converter.py b/onnxscript/ir/passes/common/version_converter.py index 6da5801ff5..01ecae7eba 100644 --- a/onnxscript/ir/passes/common/version_converter.py +++ b/onnxscript/ir/passes/common/version_converter.py @@ -81,7 +81,7 @@ def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: converted_proto = _c_api_utils.call_onnx_api( func=_partial_convert_version, model=model ) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.warning( "Failed to convert the model to the target version %d using the ONNX C API. " "The model was not modified", From 3d1f4286215d2aba6f6ae062537ff14735539024 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 19 Apr 2025 09:10:58 -0700 Subject: [PATCH 09/18] Update ConvertVersionPass --- onnxscript/version_converter/__init__.py | 140 +++++++++++++++++++++-- 1 file changed, 133 insertions(+), 7 deletions(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 544fb91ace..76aaabce6c 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -3,12 +3,142 @@ from __future__ import annotations __all__ = [ - # Functions + "ConvertVersionPass", "convert_version", ] +import logging + +import onnx + from onnxscript import ir -from onnxscript.ir.passes.common import version_converter as _version_converter_pass +from onnxscript.ir.passes.common import _c_api_utils +from onnxscript.ir.passes.common import inliner as _inliner +from onnxscript.ir.passes.common import unused_removal as _unused_removal +from onnxscript.version_converter import _version_converter + +logger = logging.getLogger(__name__) + + +class ConvertVersionPass(ir.passes.InPlacePass): + """Convert the model to the specified ONNX opset version. + + This pass leverages the onnxscript version converter to convert the model. If + the conversion is not supported, it falls back to the onnx C API to convert + the model. This pass is in-place. + + The pass is an no-op if the c-api fails. + + Attributes: + target_version: The target ONNX opset version to convert the model to. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is True. + """ + + def __init__(self, target_version: int, fallback: bool = False) -> None: + super().__init__() + self.target_version = target_version + self.fallback = fallback + self.convert_pass = ir.passes.Sequential( + _inliner.InlinePass(), + _ConvertVersionPassRequiresInline( + target_version=target_version, + fallback=fallback, + ), + _unused_removal.RemoveUnusedNodesPass(), + _unused_removal.RemoveUnusedFunctionsPass(), + _unused_removal.RemoveUnusedOpsetsPass(), + ) + + def call(self, model: ir.Model) -> ir.passes.PassResult: + return self.convert_pass(model) + + +class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): + """Convert the model to the specified ONNX opset version. + + This pass leverages the onnxscript version converter to convert the model. If + the conversion is not supported, it falls back to the onnx C API to convert + the model. This pass is in-place. + + The pass is an no-op if the c-api fails. + + Attributes: + target_version: The target ONNX opset version to convert the model to. + fallback: Whether to fallback to the onnx version converter if the + target version is not supported. Default is True. + """ + + def __init__(self, target_version: int, fallback: bool = False) -> None: + super().__init__() + self.target_version = target_version + self.fallback = fallback + + def call(self, model: ir.Model) -> ir.passes.PassResult: + if model.functions: + raise ValueError( + "The model contains functions. The version conversion pass does not support " + "functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the " + f"functions before applying this pass ({self.__class__.__name__})." + ) + if "" in model.graph.opset_imports: + onnx_opset_version = model.graph.opset_imports[""] + if onnx_opset_version == self.target_version: + # No need to convert the version + return ir.passes.PassResult(model, False) + + if _version_converter.version_supported(model, self.target_version): + _version_converter.convert_version( + model, + target_version=self.target_version, + ) + return ir.passes.PassResult(model, True) + + if not self.fallback: + logger.info( + "The model version conversion is not supported by the onnxscript version converter " + "and fallback is disabled. The model was not modified" + " (target version: %d). " + "Set fallback=True to enable fallback to the onnx c-api version converter.", + self.target_version, + ) + return ir.passes.PassResult(model, False) + + # If the onnxscript version converter does not support the conversion, + # we can use the onnx C API to convert the model + def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: + """Partial function to check the model.""" + return onnx.version_converter.convert_version( + proto, target_version=self.target_version + ) + + try: + converted_proto = _c_api_utils.call_onnx_api( + func=_partial_convert_version, model=model + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to convert the model to the target version %d using the ONNX C API. " + "The model was not modified", + self.target_version, + exc_info=e, + ) + return ir.passes.PassResult(model, False) + + converted_model = ir.from_proto(converted_proto) + + # Recover the initializers in the converted model + for input in converted_model.graph.inputs: + if input.name in model.graph.initializers: + input.const_value = model.graph.initializers[input.name].const_value + converted_model.graph.register_initializer(input) + user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)] + converted_model.graph.inputs.clear() + converted_model.graph.inputs.extend(user_inputs) + + # Return the converted graph to the original model to keep the pass in-place + model.graph = converted_model.graph + return ir.passes.PassResult(model, True) def convert_version(model: ir.Model, target_version: int, fallback=False) -> None: @@ -20,8 +150,4 @@ def convert_version(model: ir.Model, target_version: int, fallback=False) -> Non fallback: Whether to fallback to the onnx version converter if the target version is not supported. """ - # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. - # Hence, we inline all the functions. - _version_converter_pass.ConvertVersionPass( - target_version=target_version, fallback=fallback - )(model) + ConvertVersionPass(target_version=target_version, fallback=fallback)(model) From 01ef856ea9da8182f9f585afd9af7aa6d838b924 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 19 Apr 2025 09:17:09 -0700 Subject: [PATCH 10/18] comments --- onnxscript/version_converter/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 76aaabce6c..b0acd09eee 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -87,7 +87,13 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: # No need to convert the version return ir.passes.PassResult(model, False) - if _version_converter.version_supported(model, self.target_version): + # When fallback is disabled, always use the onnxscript version converter; + # When fallback is enabled, use the onnxscript version converter + # if the target version is supported. Otherwise, use the onnx C API + # to convert the model. + if not self.fallback or _version_converter.version_supported( + model, self.target_version + ): _version_converter.convert_version( model, target_version=self.target_version, @@ -95,7 +101,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: return ir.passes.PassResult(model, True) if not self.fallback: - logger.info( + logger.warning( "The model version conversion is not supported by the onnxscript version converter " "and fallback is disabled. The model was not modified" " (target version: %d). " From 09d44ade67c286f1264cb7f8120ccd17bbdefe16 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 19 Apr 2025 12:05:42 -0700 Subject: [PATCH 11/18] remove --- .../ir/passes/common/version_converter.py | 106 ------------------ 1 file changed, 106 deletions(-) delete mode 100644 onnxscript/ir/passes/common/version_converter.py diff --git a/onnxscript/ir/passes/common/version_converter.py b/onnxscript/ir/passes/common/version_converter.py deleted file mode 100644 index 01ecae7eba..0000000000 --- a/onnxscript/ir/passes/common/version_converter.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Version conversion passes.""" - -from __future__ import annotations - -__all__ = [ - "ConvertVersionPass", -] - -import logging - -import onnx - -from onnxscript import ir -from onnxscript.ir.passes.common import _c_api_utils -from onnxscript.ir.passes.common import inliner as _inliner -from onnxscript.version_converter import _version_converter - -logger = logging.getLogger(__name__) - - -class ConvertVersionPass(ir.passes.InPlacePass): - """Convert the model to the specified ONNX opset version. - - This pass leverages the onnxscript version converter to convert the model. If - the conversion is not supported, it falls back to the onnx C API to convert - the model. This pass is in-place. - - The pass is an no-op if the c-api fails. - - Attributes: - target_version: The target ONNX opset version to convert the model to. - fallback: Whether to fallback to the onnx version converter if the - target version is not supported. Default is True. - """ - - def __init__(self, target_version: int, fallback: bool = False) -> None: - super().__init__() - self.target_version = target_version - self.fallback = fallback - self.inliner = _inliner.InlinePass() - - def call(self, model: ir.Model) -> ir.passes.PassResult: - if "" in model.graph.opset_imports: - onnx_opset_version = model.graph.opset_imports[""] - if onnx_opset_version == self.target_version: - # No need to convert the version - return ir.passes.PassResult(model, False) - - # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. - # Hence, we inline all the functions. - self.inliner(model) - - if _version_converter.version_supported(model, self.target_version): - _version_converter.convert_version( - model, - target_version=self.target_version, - ) - return ir.passes.PassResult(model, True) - - if not self.fallback: - logger.info( - "The model version conversion is not supported by the onnxscript version converter " - "and fallback is disabled. The model was not modified" - " (target version: %d). " - "Set fallback=True to enable fallback to the onnx c-api version converter.", - self.target_version, - ) - return ir.passes.PassResult(model, False) - - # If the onnxscript version converter does not support the conversion, - # we can use the onnx C API to convert the model - def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: - """Partial function to check the model.""" - return onnx.version_converter.convert_version( - proto, target_version=self.target_version - ) - - try: - converted_proto = _c_api_utils.call_onnx_api( - func=_partial_convert_version, model=model - ) - except Exception as e: # pylint: disable=broad-exception-caught - logger.warning( - "Failed to convert the model to the target version %d using the ONNX C API. " - "The model was not modified", - self.target_version, - exc_info=e, - ) - return ir.passes.PassResult(model, False) - - converted_model = ir.from_proto(converted_proto) - - # Recover the initializers in the converted model - for input in converted_model.graph.inputs: - if input.name in model.graph.initializers: - input.const_value = model.graph.initializers[input.name].const_value - converted_model.graph.register_initializer(input) - user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)] - converted_model.graph.inputs.clear() - converted_model.graph.inputs.extend(user_inputs) - - # Return the converted graph to the original model to keep the pass in-place - model.graph = converted_model.graph - return ir.passes.PassResult(model, True) From 46143339539dc41d2aebf0b06bc8e22d2d7c8d3a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 21 Apr 2025 17:25:19 -0700 Subject: [PATCH 12/18] test fallback --- tests/version_converter/test_models.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/version_converter/test_models.py diff --git a/tests/version_converter/test_models.py b/tests/version_converter/test_models.py new file mode 100644 index 0000000000..c012007d12 --- /dev/null +++ b/tests/version_converter/test_models.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import pathlib +import unittest + +from onnxscript import ir, version_converter + +model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" + + +class ModelTest(unittest.TestCase): + def test_model_runs_and_matches_accuracy_after_conversion_fallback_true(self): + model_path = model_folder_path / "e2e_models/torchscript_model/torchscript_model.onnx" + model = ir.load(model_path) + + # Down convert the model with the onnx version converter + version_converter.convert_version(model, target_version=16, fallback=True) + self.assertEqual(model.opset_imports[""], 16) + + +if __name__ == "__main__": + unittest.main() From 188b6cb2cd82b529f4b03053c70d59a2139619ec Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 21 Apr 2025 20:36:11 -0700 Subject: [PATCH 13/18] rename --- .../{test_models.py => version_conversion_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/version_converter/{test_models.py => version_conversion_test.py} (100%) diff --git a/tests/version_converter/test_models.py b/tests/version_converter/version_conversion_test.py similarity index 100% rename from tests/version_converter/test_models.py rename to tests/version_converter/version_conversion_test.py From 529ae52ef946a6d30608b794a14e5f9ceeb7d24b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Apr 2025 10:14:37 -0700 Subject: [PATCH 14/18] Update onnxscript/version_converter/__init__.py Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> --- onnxscript/version_converter/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index b0acd09eee..27e1ba3034 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -66,7 +66,7 @@ class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): Attributes: target_version: The target ONNX opset version to convert the model to. fallback: Whether to fallback to the onnx version converter if the - target version is not supported. Default is True. + target version is not supported. Default is False. """ def __init__(self, target_version: int, fallback: bool = False) -> None: From b73398eccaad9c56cae510499a2c53f03ebd42a6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Apr 2025 10:16:04 -0700 Subject: [PATCH 15/18] Apply suggestions from code review --- onnxscript/version_converter/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 27e1ba3034..e7432a1021 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -32,7 +32,7 @@ class ConvertVersionPass(ir.passes.InPlacePass): Attributes: target_version: The target ONNX opset version to convert the model to. fallback: Whether to fallback to the onnx version converter if the - target version is not supported. Default is True. + target version is not supported. Default is False. """ def __init__(self, target_version: int, fallback: bool = False) -> None: @@ -154,6 +154,6 @@ def convert_version(model: ir.Model, target_version: int, fallback=False) -> Non model: The model to convert. target_version: The target ONNX opset version. fallback: Whether to fallback to the onnx version converter if the - target version is not supported. + target version is not supported. Default is False. """ ConvertVersionPass(target_version=target_version, fallback=fallback)(model) From 7e9ca80b5af406be5ec31811d44019fa20fce456 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Apr 2025 10:28:36 -0700 Subject: [PATCH 16/18] Allow fallback --- onnxscript/_framework_apis/torch_2_6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py index 2cfe51cea0..58084e8406 100644 --- a/onnxscript/_framework_apis/torch_2_6.py +++ b/onnxscript/_framework_apis/torch_2_6.py @@ -35,7 +35,7 @@ def convert_version(model: ir.Model, target_version: int) -> ir.Model: """Convert the model to the specified ONNX opset version.""" if target_version < 18: return model - version_converter.convert_version(model, target_version) + version_converter.convert_version(model, target_version, fallback=True) return model From 107fb6aa421e4a19f9097c66897f4d60f2237a36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Apr 2025 10:29:35 -0700 Subject: [PATCH 17/18] log --- onnxscript/_framework_apis/torch_2_6.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py index 58084e8406..2d166cb967 100644 --- a/onnxscript/_framework_apis/torch_2_6.py +++ b/onnxscript/_framework_apis/torch_2_6.py @@ -12,6 +12,7 @@ "save_model_with_external_data", "torchlib_opset", ] +import logging from typing import TYPE_CHECKING from onnxscript import ir, optimizer, version_converter @@ -25,6 +26,9 @@ from onnxscript.onnx_opset._impl.opset18 import Opset18 +logger = logging.getLogger(__name__) + + def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" optimizer.optimize_ir(model) @@ -34,6 +38,7 @@ def optimize(model: ir.Model) -> ir.Model: def convert_version(model: ir.Model, target_version: int) -> ir.Model: """Convert the model to the specified ONNX opset version.""" if target_version < 18: + logger.warning("Conversion to opset < 18 is not supported.") return model version_converter.convert_version(model, target_version, fallback=True) return model From 9ce36985ff29e64e720a7a30c134dcee5a3cab29 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Apr 2025 10:36:39 -0700 Subject: [PATCH 18/18] arg --- onnxscript/version_converter/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index e7432a1021..23d7bf23b0 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -66,10 +66,10 @@ class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): Attributes: target_version: The target ONNX opset version to convert the model to. fallback: Whether to fallback to the onnx version converter if the - target version is not supported. Default is False. + target version is not supported. """ - def __init__(self, target_version: int, fallback: bool = False) -> None: + def __init__(self, target_version: int, fallback: bool) -> None: super().__init__() self.target_version = target_version self.fallback = fallback