Skip to content

Commit f90c2d8

Browse files
justinchuby authored and bmehta001 committed
[IR] Create a shape inference pass using onnx shape inference (microsoft#2117)
It handles large models by removing the initializers before sending the model to onnx shape inference.
1 parent 60bfa68 commit f90c2d8

File tree

4 files changed

+288
-2
lines changed

4 files changed

+288
-2
lines changed

onnxscript/ir/_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def save(
7878

7979
# Store the original initializer values so they can be restored if modify_model=False
8080
initializer_values = tuple(model.graph.initializers.values())
81-
tensors = [v.const_value for v in model.graph.initializers.values()]
81+
tensors = [v.const_value for v in initializer_values]
8282

8383
try:
8484
model = _external_data.unload_from_model(

onnxscript/ir/passes/_pass_infra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class PassResult:
5858
5959
Attributes:
6060
model: The transformed model.
61-
modified: Whether the model was modified.
61+
modified: Whether the resulting model is different from the input model.
6262
"""
6363

6464
model: ir.Model
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Shape inference pass using onnx.shape_inference."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"ShapeInferencePass",
9+
"infer_shapes",
10+
]
11+
12+
import logging
13+
14+
import onnx
15+
16+
from onnxscript import ir
17+
18+
logger = logging.getLogger(__name__)
19+
20+
# Initializers whose tensor is larger than this many bytes are temporarily removed
21+
# before the onnx.shape_inference call, because that call needs to serialize the model
22+
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB; compared against the tensor's nbytes
23+
24+
25+
class ShapeInferencePass(ir.passes.PassBase):
    """This pass performs shape inference on the graph.

    The model is serialized, handed to ``onnx.shape_inference.infer_shapes`` and
    the inferred proto is deserialized into a new model, which is returned in
    the pass result. Initializers are temporarily exposed as graph inputs, and
    tensors larger than ``_BIG_TENSOR_SIZE_LIMIT`` bytes are temporarily detached
    so that the serialized model stays small.
    """

    # This pass does not modify the model in place.
    in_place = False

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        All three options are forwarded unchanged to
        ``onnx.shape_inference.infer_shapes``.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Infer shapes and return the result as a new model.

        Args:
            model: The model to run shape inference on. Its initializers and
                inputs are modified temporarily and restored before returning
                or raising. Initializers that lack a shape or dtype get them
                filled in from their tensor; that part is not reverted.

        Returns:
            A result holding the inferred model, with ``modified`` set when the
            ``value_info`` entries of the inferred proto differ from those of
            the serialized input. If serialization, inference or
            deserialization raises, a warning is logged and the input model is
            returned with ``modified=False``.
        """
        graph = model.graph
        # Snapshot the initializers and their tensors so they can be restored
        initializer_values = tuple(graph.initializers.values())
        tensors = {v.name: v.const_value for v in initializer_values}
        original_inputs_len = len(graph.inputs)
        # Track by identity which values are already inputs, so that the check
        # below does not rescan the (growing) input list for every initializer
        seen_input_ids = {id(value) for value in graph.inputs}

        # The preparation below mutates the caller's model, so it runs inside
        # the try/finally: the model is restored even if preparation fails.
        try:
            # Turn the initializers into inputs and detach the big tensors
            # to limit the model size
            for initializer in initializer_values:
                # Make sure the initializer has its shape/type set
                assert initializer.const_value is not None
                if initializer.shape is None:
                    initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
                if initializer.dtype is None:
                    initializer.dtype = initializer.const_value.dtype
                if id(initializer) not in seen_input_ids:
                    graph.inputs.append(initializer)
                    seen_input_ids.add(id(initializer))
                if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
                    # Temporarily remove the initializer value to reduce model size
                    # for onnx.shape_inference
                    initializer.const_value = None
                    assert initializer.name is not None
                    graph.initializers.pop(initializer.name)

            # Perform shape inference
            try:
                proto = ir.serde.serialize_model(model)
                value_infos = {info.name: info for info in proto.graph.value_info}
                inferred_proto = onnx.shape_inference.infer_shapes(
                    proto,
                    check_type=self.check_type,
                    strict_mode=self.strict_mode,
                    data_prop=self.data_prop,
                )
                inferred_value_infos = {
                    info.name: info for info in inferred_proto.graph.value_info
                }
                inferred_model = ir.serde.deserialize_model(inferred_proto)
            except Exception:  # pylint: disable=broad-exception-caught
                logger.warning("Shape inference failed. The model is not modified", exc_info=True)
                return ir.passes.PassResult(model, modified=False)
        finally:
            # Restore the original initializer values so the model is unchanged
            for initializer in initializer_values:
                tensor = tensors[initializer.name]
                initializer.const_value = tensor
                # A value without a tensor can only be seen here when the
                # preparation loop rejected it; it was never removed from the
                # initializers, so there is nothing to register again.
                if tensor is not None:
                    graph.register_initializer(initializer)

            # Restore the original inputs
            inputs = graph.inputs[:original_inputs_len]
            graph.inputs.clear()
            graph.inputs.extend(inputs)

        # Add the original initializer tensors to the new (inferred) model
        inferred_graph = inferred_model.graph
        for new_input in inferred_graph.inputs:
            # Assign the tensors back to the initializers
            if new_input.name in tensors:
                new_input.const_value = tensors[new_input.name]
                inferred_graph.register_initializer(new_input)

        # Remove the inputs that were added
        new_inputs = inferred_graph.inputs[:original_inputs_len]
        inferred_graph.inputs.clear()
        inferred_graph.inputs.extend(new_inputs)

        return ir.passes.PassResult(
            inferred_model, modified=value_infos != inferred_value_infos
        )
116+
117+
118+
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run :class:`ShapeInferencePass` on a model and return the resulting model.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The ``model`` attribute of the pass result.
    """
    shape_inference_pass = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    pass_result = shape_inference_pass(model)
    return pass_result.model
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import numpy as np
8+
9+
from onnxscript import ir
10+
from onnxscript.ir.passes.common import shape_inference
11+
12+
13+
class TestShapeInferencePass(unittest.TestCase):
    """Tests for ``shape_inference.ShapeInferencePass``."""

    @staticmethod
    def _float_value(name, dims, **extra):
        """Build a FLOAT tensor value with the given name and static shape."""
        return ir.Value(
            name=name,
            type=ir.TensorType(ir.DataType.FLOAT),
            shape=ir.Shape(dims),
            **extra,
        )

    def test_pass(self):
        # A single Add over two (1, 2) float inputs; its output starts untyped
        graph_inputs = [
            self._float_value("input_a", (1, 2)),
            self._float_value("input_b", (1, 2)),
        ]
        add_node = ir.Node("", "Add", inputs=graph_inputs)
        graph = ir.Graph(
            inputs=graph_inputs,
            outputs=add_node.outputs,
            nodes=[add_node],
            opset_imports={"": 20},
        )
        model = ir.Model(graph, ir_version=10)

        add_output = add_node.outputs[0]
        self.assertIsNone(add_output.shape)
        self.assertIsNone(add_output.dtype)

        # Perform shape inference
        result = shape_inference.ShapeInferencePass()(model)
        self.assertTrue(result.modified)

        inferred_graph = result.model.graph
        self.assertEqual(inferred_graph.node(0).outputs[0].shape, ir.Shape((1, 2)))
        self.assertEqual(inferred_graph.node(0).outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(inferred_graph.outputs[0].shape, ir.Shape((1, 2)))
        self.assertEqual(inferred_graph.outputs[0].dtype, ir.DataType.FLOAT)

    def test_pass_with_initializers(self):
        # _BIG_TENSOR_SIZE_LIMIT is a byte count while big_dim is used as an
        # element count. That is fine here: the tensor only has to be bigger
        # than _BIG_TENSOR_SIZE_LIMIT.
        big_dim = shape_inference._BIG_TENSOR_SIZE_LIMIT * 2  # pylint: disable=protected-access
        input_a = self._float_value("input_a", (1, 2))
        input_b = self._float_value(
            "input_b",
            (big_dim, 1),
            const_value=ir.tensor([[42]] * big_dim, dtype=ir.DataType.FLOAT),
        )
        inputs = [input_a, input_b]

        # Shape and type are not explicitly set for the initializer but it should still work
        initializer = ir.Value(
            name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT)
        )

        add_node = ir.Node("", "Add", inputs=[*inputs])
        mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], initializer])

        graph = ir.Graph(
            inputs=inputs,
            outputs=mul_node.outputs,
            nodes=[add_node, mul_node],
            opset_imports={"": 20},
        )
        model = ir.Model(graph, ir_version=10)
        graph.register_initializer(input_b)
        graph.register_initializer(initializer)

        # Nothing is known about the intermediate values before the pass runs
        self.assertIsNone(add_node.outputs[0].shape)
        self.assertIsNone(add_node.outputs[0].dtype)
        self.assertIsNone(mul_node.outputs[0].shape)
        self.assertIsNone(mul_node.outputs[0].dtype)
        self.assertIsNone(initializer.shape)
        self.assertIsNone(initializer.dtype)

        # Perform shape inference
        result = shape_inference.ShapeInferencePass()(model)
        self.assertTrue(result.modified)

        inferred_graph = result.model.graph
        expected_shape = ir.Shape((big_dim, 2))
        self.assertEqual(inferred_graph.node(0).outputs[0].shape, expected_shape)
        self.assertEqual(inferred_graph.node(0).outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(inferred_graph.node(1).outputs[0].shape, expected_shape)
        self.assertEqual(inferred_graph.node(1).outputs[0].dtype, ir.DataType.FLOAT)
        self.assertEqual(inferred_graph.initializers["initializer"].shape, ir.Shape((1, 2)))
        self.assertEqual(inferred_graph.initializers["initializer"].dtype, ir.DataType.FLOAT)
        self.assertEqual(inferred_graph.outputs[0].shape, expected_shape)
        self.assertEqual(inferred_graph.outputs[0].dtype, ir.DataType.FLOAT)

        # Check that the initializer correctly appears in the result
        self.assertEqual(len(inferred_graph.inputs), 2)
        self.assertEqual(len(inferred_graph.initializers), 2)

        inferred_input_b = inferred_graph.initializers["input_b"].const_value
        np.testing.assert_array_equal(
            inferred_input_b.numpy(),
            np.array([[42]] * big_dim, dtype=np.float32),
            strict=True,
        )
        self.assertEqual(
            inferred_graph.initializers["input_b"].const_value.dtype,
            ir.DataType.FLOAT,
        )

        inferred_initializer = inferred_graph.initializers["initializer"].const_value
        np.testing.assert_array_equal(
            inferred_initializer.numpy(),
            np.array([[2.0, 3.0]], dtype=np.float32),
            strict=True,
        )
        self.assertEqual(
            inferred_graph.initializers["initializer"].const_value.dtype,
            ir.DataType.FLOAT,
        )

        # Check that the original model is not modified
        self.assertIsNone(add_node.outputs[0].shape)
        self.assertIsNone(add_node.outputs[0].dtype)
        self.assertIsNone(mul_node.outputs[0].shape)
        self.assertIsNone(mul_node.outputs[0].dtype)
        self.assertEqual(len(model.graph.inputs), 2)
        self.assertEqual(len(model.graph.initializers), 2)
        self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value)
        self.assertEqual(len(model.graph.outputs), 1)
        self.assertEqual(model.graph.outputs[0].shape, None)
        self.assertEqual(model.graph.outputs[0].dtype, None)
        # Check that the initializer is not modified
        self.assertIs(
            model.graph.initializers["initializer"].const_value, initializer.const_value
        )
146+
147+
# Run the unittest runner when this module is executed directly.
if __name__ == "__main__":
148+
unittest.main()

0 commit comments

Comments
 (0)