|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | __all__ = [
|
6 |
| - # Functions |
| 6 | + "ConvertVersionPass", |
7 | 7 | "convert_version",
|
8 | 8 | ]
|
9 | 9 |
|
10 |
| -import onnxscript.optimizer |
| 10 | +import logging |
| 11 | + |
| 12 | +import onnx |
| 13 | + |
11 | 14 | from onnxscript import ir
|
| 15 | +from onnxscript.ir.passes.common import _c_api_utils |
| 16 | +from onnxscript.ir.passes.common import inliner as _inliner |
| 17 | +from onnxscript.ir.passes.common import unused_removal as _unused_removal |
12 | 18 | from onnxscript.version_converter import _version_converter
|
13 | 19 |
|
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +class ConvertVersionPass(ir.passes.InPlacePass): |
| 24 | + """Convert the model to the specified ONNX opset version. |
| 25 | +
|
| 26 | + This pass leverages the onnxscript version converter to convert the model. If |
| 27 | + the conversion is not supported, it falls back to the onnx C API to convert |
| 28 | + the model. This pass is in-place. |
| 29 | +
|
| 30 | + The pass is an no-op if the c-api fails. |
| 31 | +
|
| 32 | + Attributes: |
| 33 | + target_version: The target ONNX opset version to convert the model to. |
| 34 | + fallback: Whether to fallback to the onnx version converter if the |
| 35 | + target version is not supported. Default is False. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__(self, target_version: int, fallback: bool = False) -> None: |
| 39 | + super().__init__() |
| 40 | + self.target_version = target_version |
| 41 | + self.fallback = fallback |
| 42 | + self.convert_pass = ir.passes.Sequential( |
| 43 | + _inliner.InlinePass(), |
| 44 | + _ConvertVersionPassRequiresInline( |
| 45 | + target_version=target_version, |
| 46 | + fallback=fallback, |
| 47 | + ), |
| 48 | + _unused_removal.RemoveUnusedNodesPass(), |
| 49 | + _unused_removal.RemoveUnusedFunctionsPass(), |
| 50 | + _unused_removal.RemoveUnusedOpsetsPass(), |
| 51 | + ) |
| 52 | + |
| 53 | + def call(self, model: ir.Model) -> ir.passes.PassResult: |
| 54 | + return self.convert_pass(model) |
| 55 | + |
| 56 | + |
| 57 | +class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): |
| 58 | + """Convert the model to the specified ONNX opset version. |
| 59 | +
|
| 60 | + This pass leverages the onnxscript version converter to convert the model. If |
| 61 | + the conversion is not supported, it falls back to the onnx C API to convert |
| 62 | + the model. This pass is in-place. |
| 63 | +
|
| 64 | + The pass is an no-op if the c-api fails. |
| 65 | +
|
| 66 | + Attributes: |
| 67 | + target_version: The target ONNX opset version to convert the model to. |
| 68 | + fallback: Whether to fallback to the onnx version converter if the |
| 69 | + target version is not supported. |
| 70 | + """ |
| 71 | + |
| 72 | + def __init__(self, target_version: int, fallback: bool) -> None: |
| 73 | + super().__init__() |
| 74 | + self.target_version = target_version |
| 75 | + self.fallback = fallback |
| 76 | + |
| 77 | + def call(self, model: ir.Model) -> ir.passes.PassResult: |
| 78 | + if model.functions: |
| 79 | + raise ValueError( |
| 80 | + "The model contains functions. The version conversion pass does not support " |
| 81 | + "functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the " |
| 82 | + f"functions before applying this pass ({self.__class__.__name__})." |
| 83 | + ) |
| 84 | + if "" in model.graph.opset_imports: |
| 85 | + onnx_opset_version = model.graph.opset_imports[""] |
| 86 | + if onnx_opset_version == self.target_version: |
| 87 | + # No need to convert the version |
| 88 | + return ir.passes.PassResult(model, False) |
| 89 | + |
| 90 | + # When fallback is disabled, always use the onnxscript version converter; |
| 91 | + # When fallback is enabled, use the onnxscript version converter |
| 92 | + # if the target version is supported. Otherwise, use the onnx C API |
| 93 | + # to convert the model. |
| 94 | + if not self.fallback or _version_converter.version_supported( |
| 95 | + model, self.target_version |
| 96 | + ): |
| 97 | + _version_converter.convert_version( |
| 98 | + model, |
| 99 | + target_version=self.target_version, |
| 100 | + ) |
| 101 | + return ir.passes.PassResult(model, True) |
| 102 | + |
| 103 | + if not self.fallback: |
| 104 | + logger.warning( |
| 105 | + "The model version conversion is not supported by the onnxscript version converter " |
| 106 | + "and fallback is disabled. The model was not modified" |
| 107 | + " (target version: %d). " |
| 108 | + "Set fallback=True to enable fallback to the onnx c-api version converter.", |
| 109 | + self.target_version, |
| 110 | + ) |
| 111 | + return ir.passes.PassResult(model, False) |
| 112 | + |
| 113 | + # If the onnxscript version converter does not support the conversion, |
| 114 | + # we can use the onnx C API to convert the model |
| 115 | + def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto: |
| 116 | + """Partial function to check the model.""" |
| 117 | + return onnx.version_converter.convert_version( |
| 118 | + proto, target_version=self.target_version |
| 119 | + ) |
| 120 | + |
| 121 | + try: |
| 122 | + converted_proto = _c_api_utils.call_onnx_api( |
| 123 | + func=_partial_convert_version, model=model |
| 124 | + ) |
| 125 | + except Exception as e: # pylint: disable=broad-exception-caught |
| 126 | + logger.warning( |
| 127 | + "Failed to convert the model to the target version %d using the ONNX C API. " |
| 128 | + "The model was not modified", |
| 129 | + self.target_version, |
| 130 | + exc_info=e, |
| 131 | + ) |
| 132 | + return ir.passes.PassResult(model, False) |
| 133 | + |
| 134 | + converted_model = ir.from_proto(converted_proto) |
| 135 | + |
| 136 | + # Recover the initializers in the converted model |
| 137 | + for input in converted_model.graph.inputs: |
| 138 | + if input.name in model.graph.initializers: |
| 139 | + input.const_value = model.graph.initializers[input.name].const_value |
| 140 | + converted_model.graph.register_initializer(input) |
| 141 | + user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)] |
| 142 | + converted_model.graph.inputs.clear() |
| 143 | + converted_model.graph.inputs.extend(user_inputs) |
| 144 | + |
| 145 | + # Return the converted graph to the original model to keep the pass in-place |
| 146 | + model.graph = converted_model.graph |
| 147 | + return ir.passes.PassResult(model, True) |
| 148 | + |
14 | 149 |
|
15 |
| -def convert_version(model: ir.Model, target_version: int) -> None: |
16 |
| - """Convert the model to the specified ONNX opset version.""" |
| 150 | +def convert_version(model: ir.Model, target_version: int, fallback=False) -> None: |
| 151 | + """Convert the model to the specified ONNX opset version. |
17 | 152 |
|
18 |
| - # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. |
19 |
| - # Hence, we inline all the functions. |
20 |
| - onnxscript.optimizer.inline(model) |
21 |
| - _version_converter.convert_version(model, target_version) |
| 153 | + Args: |
| 154 | + model: The model to convert. |
| 155 | + target_version: The target ONNX opset version. |
| 156 | + fallback: Whether to fallback to the onnx version converter if the |
| 157 | + target version is not supported. Default is False. |
| 158 | + """ |
| 159 | + ConvertVersionPass(target_version=target_version, fallback=fallback)(model) |
0 commit comments