Skip to content

[IR] Create a shape inference pass using onnx shape inference #2117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 24, 2025
Merged
2 changes: 1 addition & 1 deletion onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def save(

# Store the original initializer values so they can be restored if modify_model=False
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in model.graph.initializers.values()]
tensors = [v.const_value for v in initializer_values]

try:
model = _external_data.unload_from_model(
Expand Down
98 changes: 98 additions & 0 deletions onnxscript/ir/passes/common/shape_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Shape inference pass using onnx.shape_inference."""

from __future__ import annotations

__all__ = ["ShapeInferencePass"]

import logging

import onnx

from onnxscript import ir

logger = logging.getLogger(__name__)


class ShapeInferencePass(ir.passes.PassBase):
    """This pass performs shape inference on the graph.

    The pass round-trips the IR model through ``onnx.shape_inference``: it
    serializes the model to proto, runs ``infer_shapes``, and deserializes the
    result back into a new ``ir.Model``. To keep the serialized proto small,
    initializer tensor data is detached before serialization and re-attached
    afterwards (to both the original model and the inferred one).
    """

    # This pass does not modify the model in place; it returns a newly
    # deserialized model on success (see the note on PassResult below).
    in_place = False

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run ONNX shape inference on ``model``.

        Returns:
            A ``PassResult`` holding the inferred model on success, or the
            original (unchanged) model if shape inference raised.
        """
        # Store the original initializer values so they can be restored.
        # ``tensors`` maps initializer name -> its tensor data; initializer
        # names are unique within a graph, so the dict is a faithful snapshot.
        initializer_values = tuple(model.graph.initializers.values())
        tensors = {v.name: v.const_value for v in initializer_values}
        original_inputs_len = len(model.graph.inputs)
        initializer_names = {v.name for v in initializer_values}

        # Turn the initializers into inputs and clear the initializers
        # to limit the model size: the serialized proto then carries only
        # shape/type info for them, not the (possibly large) tensor data.
        for initializer in initializer_values:
            # Make sure the initializer has its shape/type set, copying them
            # from the tensor data if the Value itself lacks them.
            assert initializer.const_value is not None
            if initializer.shape is None:
                initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
            if initializer.dtype is None:
                initializer.dtype = initializer.const_value.dtype
            # An initializer may already be a graph input; avoid duplicating it.
            if initializer not in model.graph.inputs:
                model.graph.inputs.append(initializer)
            initializer.const_value = None
        model.graph.initializers.clear()

        # Perform shape inference
        try:
            proto = ir.serde.serialize_model(model)
            proto = onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )
            inferred_model = ir.serde.deserialize_model(proto)
        except Exception:
            # Best-effort pass: on any inference failure, log and return the
            # original model unmodified (it is restored in ``finally`` below).
            logger.warning("Shape inference failed. The model is not modified", exc_info=True)
            return ir.passes.PassResult(model, modified=False)
        finally:
            # Restore the original initializer values so the input model is
            # left unchanged whether or not inference succeeded.
            for initializer in initializer_values:
                if initializer.name in initializer_names:
                    initializer.const_value = tensors[initializer.name]
                    model.graph.register_initializer(initializer)

            # Restore the original inputs (drop the initializers we appended).
            inputs = model.graph.inputs[:original_inputs_len]
            model.graph.inputs.clear()
            model.graph.inputs.extend(inputs)

        # Add the original initializer tensors to the new (inferred) model,
        # whose deserialized input Values carry no tensor data yet.
        for new_input in inferred_model.graph.inputs:
            # Assign the tensors back to the initializers
            if new_input.name in initializer_names:
                new_input.const_value = tensors[new_input.name]
                inferred_model.graph.register_initializer(new_input)

        # Remove the inputs that were added (mirror the trim done on the
        # original model above).
        new_inputs = inferred_model.graph.inputs[:original_inputs_len]
        inferred_model.graph.inputs.clear()
        inferred_model.graph.inputs.extend(new_inputs)
        # Even though modified, we know the pass will not change the model if we ran it again.
        # So set modified to False
        return ir.passes.PassResult(inferred_model, modified=False)
48 changes: 48 additions & 0 deletions onnxscript/ir/passes/common/shape_inference_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

from onnxscript import ir
from onnxscript.ir.passes.common import shape_inference


class TestShapeInference(unittest.TestCase):
    """Tests for :class:`shape_inference.ShapeInferencePass`."""

    def test_shape_inference(self):
        """The pass fills in shape/dtype for the output of a simple Add model."""

        def _float_1x2(name: str) -> ir.Value:
            # Factory for a FLOAT tensor value of shape (1, 2).
            return ir.Value(
                name=name,
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((1, 2)),
            )

        graph_inputs = [_float_1x2("input_a"), _float_1x2("input_b")]
        node = ir.Node("", "Add", inputs=graph_inputs)
        graph = ir.Graph(
            inputs=graph_inputs,
            outputs=node.outputs,
            nodes=[node],
            opset_imports={"": 20},
        )
        model = ir.Model(graph, ir_version=10)

        # Before the pass, the Add output carries no shape/type information.
        self.assertIsNone(node.outputs[0].shape)
        self.assertIsNone(node.outputs[0].dtype)

        # Perform shape inference
        result = shape_inference.ShapeInferencePass()(model)

        self.assertFalse(result.modified)
        inferred_output = result.model.graph.node(0).outputs[0]
        self.assertEqual(inferred_output.shape, ir.Shape((1, 2)))
        self.assertEqual(inferred_output.dtype, ir.DataType.FLOAT)
        self.assertEqual(result.model.graph.outputs[0].shape, ir.Shape((1, 2)))
        self.assertEqual(result.model.graph.outputs[0].dtype, ir.DataType.FLOAT)


if __name__ == "__main__":
    # Allow running this test module directly: `python shape_inference_test.py`.
    unittest.main()
Loading