Skip to content

Commit d7955f4

Browse files
authored
[pass] Implement checker pass and refactor shape inference (#2199)
- Refactor shape_inference pass to extract logic for handling large models for onnx c-api. - Implement an onnx checker pass leveraging the refactored logic.
1 parent 133f344 commit d7955f4

File tree

5 files changed

+259
-86
lines changed

5 files changed

+259
-86
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Utilities for interfacing with onnx C APIs."""
4+
5+
from __future__ import annotations
6+
7+
import logging
8+
from typing import TYPE_CHECKING, Callable, TypeVar
9+
10+
from onnxscript import ir
11+
12+
if TYPE_CHECKING:
13+
import onnx
14+
15+
16+
logger = logging.getLogger(__name__)
17+
# Temporarily remove initializers larger than this size to keep model size down
18+
# for ONNX C API calls (e.g. onnx.shape_inference) because they need to serialize the model
19+
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
20+
_R = TypeVar("_R")
21+
22+
23+
def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
    """Call an ONNX C API function with large initializers temporarily removed.

    The ONNX C API works on a serialized model and does not support models
    whose initializers hold large tensor values. Every initializer is
    therefore temporarily exposed as a graph input, and values larger than
    ``_BIG_TENSOR_SIZE_LIMIT`` bytes are detached before serialization.

    The input model is restored whether preparation, serialization, or the
    API call itself succeeds or fails.

    Args:
        func: Partially applied function that takes a model proto and returns anything.
        model: The IR model to pass to the API function.

    Returns:
        Whatever ``func`` returns for the serialized model.
    """
    # Store the original initializer values so they can be restored
    initializer_values = tuple(model.graph.initializers.values())
    tensors = {v.name: v.const_value for v in initializer_values}
    original_inputs_len = len(model.graph.inputs)

    try:
        # Turn the initializers into inputs and detach the big ones to limit
        # the size of the serialized model. This happens inside the ``try``
        # so that a failure part-way through (or during serialization) still
        # triggers the restoration below.
        for initializer in initializer_values:
            # Make sure the initializer has its shape/type set
            assert initializer.const_value is not None
            if initializer.shape is None:
                initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
            if initializer.dtype is None:
                initializer.dtype = initializer.const_value.dtype
            if initializer not in model.graph.inputs:
                model.graph.inputs.append(initializer)
            if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
                # Temporarily remove the initializer value to reduce model size
                initializer.const_value = None
                assert initializer.name is not None
                model.graph.initializers.pop(initializer.name)

        proto = ir.serde.serialize_model(model)

        # Call the ONNX C API function
        return func(proto)
    finally:
        # Restore the original initializer values so the model is unchanged
        for initializer in initializer_values:
            initializer.const_value = tensors[initializer.name]
            model.graph.register_initializer(initializer)

        # Restore the original inputs: initializers were only ever appended,
        # so the original inputs are exactly the leading slice.
        inputs = model.graph.inputs[:original_inputs_len]
        model.graph.inputs.clear()
        model.graph.inputs.extend(inputs)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Passes for debugging purposes."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"CheckerPass",
9+
]
10+
11+
import onnx
12+
13+
from onnxscript import ir
14+
from onnxscript.ir.passes.common import _c_api_utils
15+
16+
17+
class CheckerPass(ir.passes.PassBase):
    """Pass that validates a model with ``onnx.checker.check_model``.

    The pass reports itself as in-place and as not changing its input; the
    model is returned with ``modified`` set to False. Any exception raised by
    the checker propagates to the caller.
    """

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        """Store the options forwarded to ``onnx.checker.check_model``.

        Args:
            full_check: Forwarded as ``full_check``.
            skip_opset_compatibility_check: Forwarded as
                ``skip_opset_compatibility_check``.
            check_custom_domain: Forwarded as ``check_custom_domain``.
        """
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> bool:
        return True

    @property
    def changes_input(self) -> bool:
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the onnx checker on the model."""
        checker_options = {
            "full_check": self.full_check,
            "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
            "check_custom_domain": self.check_custom_domain,
        }

        def _run_checker(proto: onnx.ModelProto) -> None:
            """Check the serialized model with the configured options."""
            onnx.checker.check_model(proto, **checker_options)

        _c_api_utils.call_onnx_api(func=_run_checker, model=model)
        # The checker only inspects the model, so report it as unmodified
        return ir.passes.PassResult(model, False)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
from onnxscript import ir
8+
from onnxscript.ir.passes.common import onnx_checker
9+
10+
11+
class TestCheckerPass(unittest.TestCase):
    """Tests for ``onnx_checker.CheckerPass``."""

    @staticmethod
    def _make_add_model(graph_name: str | None = None) -> ir.Model:
        """Build a model with a single Add node over two float (1, 2) inputs.

        Args:
            graph_name: Name given to the graph; None leaves the graph unnamed.
        """
        graph_inputs = [
            ir.Value(
                name=input_name,
                type=ir.TensorType(ir.DataType.FLOAT),
                shape=ir.Shape((1, 2)),
            )
            for input_name in ("input_a", "input_b")
        ]

        tape = ir.tape.Tape()

        add_output = tape.op("Add", inputs=graph_inputs)
        add_output.shape = ir.Shape((1, 2))
        add_output.dtype = ir.DataType.FLOAT

        graph = ir.Graph(
            inputs=graph_inputs,
            outputs=[add_output],
            nodes=tape.nodes,
            opset_imports={"": 20},
            name=graph_name,
        )
        return ir.Model(graph, ir_version=10)

    def test_pass_is_no_op(self):
        checker_pass = onnx_checker.CheckerPass()
        self.assertTrue(checker_pass.in_place)
        self.assertFalse(checker_pass.changes_input)

    def test_check_simple_model(self):
        model = self._make_add_model(graph_name="test_model")
        # No exception should be raised
        onnx_checker.CheckerPass()(model)

    def test_check_invalid_model(self):
        # The graph has no name, which the checker must reject
        model = self._make_add_model()

        with self.assertRaisesRegex(
            Exception, "Field 'name' of 'graph' is required to be non-empty"
        ):
            onnx_checker.CheckerPass()(model)
76+
77+
78+
if __name__ == "__main__":
79+
unittest.main()

onnxscript/ir/passes/common/shape_inference.py

Lines changed: 45 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,52 @@
1414
import onnx
1515

1616
from onnxscript import ir
17+
from onnxscript.ir.passes.common import _c_api_utils
1718

1819
logger = logging.getLogger(__name__)
1920

20-
# Temporarily remove initializers larger than this size to keep model size down
21-
# for the onnx.shape_inference call because it needs to serialize the model
22-
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
2321

22+
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Copy inferred shapes and dtypes from the inferred proto onto the original model.

    Graphs of the two models are paired in iteration order and values are
    matched by name. A shape or dtype is copied only when the inferred one is
    not None and differs from the original.

    Args:
        model: The original IR model; its values are updated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if any shape or dtype on the original model was updated, False otherwise.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    changed = False
    for graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        values_by_name = ir.convenience.create_value_mapping(graph)
        inferred_by_name = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in values_by_name.items():
            if name not in inferred_by_name:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
                continue

            inferred = inferred_by_name[name]
            if value.shape != inferred.shape and inferred.shape is not None:
                value.shape = inferred.shape
                changed = True
            if value.dtype != inferred.dtype and inferred.dtype is not None:
                value.dtype = inferred.dtype
                changed = True
    return changed
51+
52+
53+
class ShapeInferencePass(ir.passes.InPlacePass):
2654
"""This pass performs shape inference on the graph."""
2755

2856
def __init__(
2957
self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
3058
) -> None:
3159
"""Initialize the shape inference pass.
3260
61+
If inference fails, the model is left unchanged.
62+
3363
Args:
3464
check_type: If True, check the types of the inputs and outputs.
3565
strict_mode: If True, use strict mode for shape inference.
@@ -41,75 +71,22 @@ def __init__(
4171
self.data_prop = data_prop
4272

4373
def call(self, model: ir.Model) -> ir.passes.PassResult:
    """Run ONNX shape inference and merge the inferred shapes and types into the model.

    Inference runs through ``_c_api_utils.call_onnx_api``. If it raises, a
    warning is logged and the model is returned unmodified.

    Args:
        model: The IR model to run shape inference on; updated in place.

    Returns:
        A pass result holding the same model object, with ``modified`` set
        from the merge of the inferred shapes and types.
    """

    def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
        return onnx.shape_inference.infer_shapes(
            proto,
            check_type=self.check_type,
            strict_mode=self.strict_mode,
            data_prop=self.data_prop,
        )

    try:
        inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Pass the exception as the argument for ``%s``; with no argument the
        # logging module emits the literal "%s" in the message.
        logger.warning(
            "Shape inference failed: %s. Model is left unchanged", e, exc_info=True
        )
        return ir.passes.PassResult(model, False)

    modified = _merge_func(model, inferred_model_proto)
    return ir.passes.PassResult(model, modified=modified)
11390

11491

11592
def infer_shapes(

0 commit comments

Comments
 (0)