diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 89696d6986..579dd37220 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -11,6 +11,7 @@ import onnx +import onnxscript.ir.passes import onnxscript.ir.passes.common from onnxscript import ir from onnxscript.ir.passes.common import _c_api_utils diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 5ab06a1ca5..447b9412b0 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -20,6 +20,25 @@ SUPPORTED_MIN_ONNX_OPSET = 18 +def _get_onnx_opset_version(model: ir.Model) -> int | None: + """Get the ONNX opset version imported by the model.""" + model_version1 = model.opset_imports.get("") + model_version2 = model.opset_imports.get("ai.onnx") + if model_version1 is not None and model_version2 is not None: + if model_version1 != model_version2: + raise ValueError( + f"Model imports multiple onnx opsets: {model_version1} and {model_version2}." + ) + return model_version1 or model_version2 + + +def _set_onnx_opset_version(model: ir.Model, version: int) -> None: + """Set the ONNX opset version imported by the model.""" + if "ai.onnx" in model.opset_imports: + del model.opset_imports["ai.onnx"] + model.opset_imports[""] = version + + class VersionConverterError(RuntimeError): """Raised when an node's version cannot be upgraded/downgraded successfully.""" @@ -215,25 +234,15 @@ def groupnormalization_20_21(node: ir.Node, op): class _VersionConverter: - opset_imports: dict[str, int] - model_version: int - def __init__(self, target_version: int): - self.target_version = target_version - - def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None: - if up_conversion is True: - node.version = opset_version + 1 - else: - node.version = opset_version - 1 + self._target_version = target_version def process_node( - self, node: ir.Node, opset_version: int, up_conversion: bool = True + self, node: ir.Node, from_version: int, up_conversion: bool = True ) -> Replacement | None: - if node.domain != "": - return None + assert node.domain == "" adapter = registry.lookup_adapters( - node.domain, node.op_type, opset_version, up_conversion + node.domain, node.op_type, from_version, up_conversion ) if adapter is None: return None @@ -264,67 +273,65 @@ def visit_node( self, node: ir.Node, root: ir.Graph | ir.Function, - opset_version: int, + from_version: int, up_conversion: bool = True, ) -> None: - replacement = self.process_node(node, opset_version, up_conversion) + if up_conversion: + to_version = from_version + 1 + else: + to_version = from_version - 1 + replacement = self.process_node(node, from_version, up_conversion) if replacement is None: # No change. Process attributes. for attr in node.attributes.values(): self.visit_attribute(attr) - return None + node.version = to_version else: + for new_node in replacement.new_nodes: + # TODO: control-flow + new_node.version = to_version self.replace_node(node, replacement, root) - return None def visit_graph(self, graph: ir.Graph) -> None: - if self.target_version > SUPPORTED_MAX_ONNX_OPSET: - logger.warning( - "Conversion to target opset: %s not currently supported.", - self.target_version, - ) - return None for node in graph: - up_conversion = True - if node.version is None: - node.version = self.model_version + if node.domain != "": + continue + node_version = node.version or self._default_onnx_opset + if node_version is None: + raise VersionConverterError(f"Node {node} has no version.") # Iterate each node from current node version -> target version # and updating node based on the correct adapter # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1] # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted - if self.target_version < node.version: - up_conversion = False - logger.warning( - "Target opset: %s less than %s, downstream version conversion not currently handled.", - self.target_version, - self.model_version, + if self._target_version < node_version: + raise VersionConverterError( + f"Target opset: {self._target_version} less than node version: {node.version}, " + "downstream version conversion not currently handled." ) - return None - for opset_version in range(node.version, self.target_version): + for from_version in range(node_version, self._target_version): try: - self.visit_node(node, graph, opset_version, up_conversion) - self._upgrade_version(node, opset_version, up_conversion) + self.visit_node(node, graph, from_version, up_conversion=True) except VersionConverterError as e: logger.warning( "Skipping version conversion for node %s due to exception: %s", node.op_type, e, ) - return None def visit_model(self, model: ir.Model) -> None: - self.opset_imports = model.opset_imports - model_version = self.opset_imports.get("") - if model_version is None: - model_version = model.opset_imports.get("ai.onnx") - if model_version is None: - return None - self.model_version = model_version + self._default_onnx_opset = _get_onnx_opset_version(model) self.visit_graph(model.graph) - return None + _set_onnx_opset_version(model, self._target_version) def convert_version(model: ir.Model, target_version: int) -> None: """Convert the model to the specified ONNX opset version.""" + if (target_version > SUPPORTED_MAX_ONNX_OPSET) or ( + target_version < SUPPORTED_MIN_ONNX_OPSET + ): + raise ValueError( + f"Target opset version {target_version} is not supported. " + f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}." + ) version_converter = _VersionConverter(target_version=target_version) version_converter.visit_model(model) diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index 2726dc1a4e..cf6507196b 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -5,6 +5,7 @@ import unittest import onnx.defs +import pytest from onnxscript import ir, version_converter @@ -41,18 +42,19 @@ def test_upstream_coverage(self): self.assertEqual(domain, "") self.assertIn((name, upgrade_version), op_upgrades) - def test_version_convert_non_standard_onnx_domain(self): + @pytest.mark.xfail(reason="TODO: Cleanup error status API.") + def test_version_convert_no_source_version(self): model = ir.from_onnx_text( """ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) { - shape_a = Constant() + shape_a = Constant() reshape_x = Reshape (input_x, shape_a) - shape_b = Constant() + shape_b = Constant() reshape_y = Reshape (input_x, shape_b) gridsample = GridSample (reshape_x, reshape_y) - shape_c = Constant() + shape_c = Constant() output = Reshape (gridsample, shape_c) } """ @@ -63,16 +65,9 @@ def test_version_convert_non_standard_onnx_domain(self): target_version = 20 version_converter.convert_version(model, target_version=target_version) - self.assertEqual(model.graph.node(0).op_type, "Constant") - self.assertEqual(model.graph.node(0).version, None) - self.assertEqual(model.graph.node(1).op_type, "Reshape") - self.assertEqual(model.graph.node(1).version, None) - self.assertEqual(model.graph.node(4).op_type, "GridSample") - self.assertEqual(model.graph.node(4).version, None) - self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear") - class VersionConverter18to17Test(unittest.TestCase): + @pytest.mark.xfail(strict=True, reason="Version downgrade not yet supported.") def test_version_convert_compatible(self): model = ir.from_onnx_text( """ @@ -112,6 +107,7 @@ def test_version_convert_compatible(self): ) target_version = 19 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "Constant") self.assertEqual(model.graph.node(0).version, 19) @@ -138,6 +134,7 @@ def test_version_convert_compatible(self): ) target_version = 20 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "Constant") self.assertEqual(model.graph.node(0).version, 20) @@ -170,6 +167,7 @@ def test_version_convert_gridsample_linear(self): target_version = 20 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "Constant") self.assertEqual(model.graph.node(0).version, 20) @@ -200,6 +198,7 @@ def test_version_convert_gridsample_cubic(self): target_version = 20 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "Constant") self.assertEqual(model.graph.node(0).version, 20) @@ -231,6 +230,7 @@ def test_version_convert_inline(self): ) target_version = 20 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "Constant") self.assertEqual(model.graph.node(0).version, 20) @@ -259,6 +259,7 @@ def test_version_groupnorm(self): ) target_version = 21 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(3).op_type, "Reshape") self.assertEqual(model.graph.node(3).version, 21) @@ -289,12 +290,14 @@ def test_version_groupnorm_no_bias(self): ) target_version = 21 version_converter.convert_version(model, target_version=target_version) + self.assertEqual(model.opset_imports[""], target_version) self.assertEqual(model.graph.node(0).op_type, "GroupNormalization") self.assertEqual(model.graph.node(0).version, 20) class VersionConverter23to24Test(unittest.TestCase): + @pytest.mark.xfail(strict=True, reason="Version upgrade beyond 23 not yet supported.") def test_version_convert_compatible(self): model = ir.from_onnx_text( """