Skip to content

[IR] Create a shape inference pass using onnx shape inference #2117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 17 commits were merged on Mar 24, 2025.
2 changes: 1 addition & 1 deletion onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def save(

# Store the original initializer values so they can be restored if modify_model=False
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in model.graph.initializers.values()]
tensors = [v.const_value for v in initializer_values]

try:
model = _external_data.unload_from_model(
Expand Down
77 changes: 77 additions & 0 deletions onnxscript/ir/passes/common/shape_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Shape inference pass using onnx.shape_inference."""

from __future__ import annotations

import logging

import onnx

from onnxscript import ir

logger = logging.getLogger(__name__)


class ShapeInferencePass(ir.passes.PassBase):
    """This pass performs shape inference on the graph.

    Inference is delegated to ``onnx.shape_inference``: the IR model is
    serialized to a ``ModelProto``, inferred, and deserialized back. To keep
    the serialized proto small, initializer tensor data is temporarily
    detached before serialization and restored afterwards.
    """

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    # The pass round-trips the model through serialize/deserialize, so the
    # caller's model object is not updated in place.
    in_place = False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run onnx shape inference on ``model`` and return the inferred model.

        Args:
            model: The IR model to infer shapes for.

        Returns:
            A PassResult wrapping the (possibly re-deserialized) model.
            ``modified`` is always False because rerunning the pass would not
            change the model further.
        """
        # Store the original initializer values so they can be restored
        initializer_values = tuple(model.graph.initializers.values())
        # Keyed by initializer name so tensors can be matched back to the new
        # Value objects produced by deserialization.
        tensors = {v.name: v.const_value for v in initializer_values}
        original_inputs_len = len(model.graph.inputs)
        initializer_names = {v.name for v in initializer_values}

        # Turn the initializers into inputs and clear the initializers
        # to limit the model size
        for initializer in initializer_values:
            # Keep the value reachable as a graph input once the initializer
            # mapping is cleared below.
            if initializer not in model.graph.inputs:
                model.graph.inputs.append(initializer)
            # Drop the tensor payload so the serialized proto stays small.
            initializer.const_value = None
        model.graph.initializers.clear()

        # Perform shape inference
        try:
            proto = ir.serde.serialize_model(model)
            proto = onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )
            # On success, `model` is rebound to a freshly deserialized model;
            # on failure it stays the stripped original, repaired below.
            model = ir.serde.deserialize_model(proto)
        except Exception:
            logger.warning("Shape inference failed. The model is not modified", exc_info=True)
        finally:
            # Restore the original initializer values so the model is unchanged
            for new_input in model.graph.inputs:
                # Assign the tensors back to the initializers
                if new_input.name in initializer_names:
                    model.graph.register_initializer(new_input)
                    new_input.const_value = tensors[new_input.name]

            # Remove the inputs that were added
            # (added initializer-inputs were appended, so a prefix slice keeps
            # exactly the original inputs).
            inputs = model.graph.inputs[:original_inputs_len]
            model.graph.inputs.clear()
            model.graph.inputs.extend(inputs)

        # Even though modified, we know the pass will not change the model if we ran it again.
        # So set modified to False
        return ir.passes.PassResult(model, modified=False)
48 changes: 48 additions & 0 deletions onnxscript/ir/passes/common/shape_inference_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

from onnxscript import ir
from onnxscript.ir.passes.common import shape_inference


class TestShapeInference(unittest.TestCase):
    def test_shape_inference(self):
        """Run the pass on a tiny Add model and verify inferred shapes/dtypes."""

        def _make_input(name: str) -> ir.Value:
            # Every input in this test is a (1, 2) float tensor.
            return ir.Value(
                name=name, type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
            )

        graph_inputs = [_make_input("input_a"), _make_input("input_b")]
        add_node = ir.Node("", "Add", inputs=graph_inputs)

        model = ir.Model(
            ir.Graph(
                inputs=graph_inputs,
                outputs=add_node.outputs,
                nodes=[add_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Before the pass runs, the Add output carries no shape or type info.
        self.assertIsNone(add_node.outputs[0].shape)
        self.assertIsNone(add_node.outputs[0].dtype)

        # Perform shape inference
        result = shape_inference.ShapeInferencePass()(model)

        # The pass reports modified=False by convention.
        self.assertFalse(result.modified)
        inferred_output = result.model.graph.node(0).outputs[0]
        self.assertEqual(inferred_output.shape, ir.Shape((1, 2)))
        self.assertEqual(inferred_output.dtype, ir.DataType.FLOAT)
        graph_output = result.model.graph.outputs[0]
        self.assertEqual(graph_output.shape, ir.Shape((1, 2)))
        self.assertEqual(graph_output.dtype, ir.DataType.FLOAT)


if __name__ == "__main__":
    unittest.main()

Check warning on line 48 in onnxscript/ir/passes/common/shape_inference_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/shape_inference_test.py#L48

Added line #L48 was not covered by tests
Loading