14
14
import onnx
15
15
16
16
from onnxscript import ir
17
+ from onnxscript .ir .passes .common import _c_api_utils
17
18
18
19
logger = logging .getLogger (__name__ )
19
20
20
- # Temporarily remove initializers larger than this size to keep model size down
21
- # for the onnx.shape_inference call because it needs to serialize the model
22
- _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
23
21
22
+ def _merge_func (model : ir .Model , inferred_proto : onnx .ModelProto ) -> bool :
23
+ """Merge the shape inferred model with the original model.
24
24
25
- class ShapeInferencePass (ir .passes .FunctionalPass ):
25
+ Args:
26
+ model: The original IR model.
27
+ inferred_proto: The ONNX model with shapes and types inferred.
28
+
29
+ Returns:
30
+ A tuple containing the modified model and a boolean indicating whether the model was modified.
31
+ """
32
+ inferred_model = ir .serde .deserialize_model (inferred_proto )
33
+ modified = False
34
+ for original_graph , inferred_graph in zip (model .graphs (), inferred_model .graphs ()):
35
+ original_values = ir .convenience .create_value_mapping (original_graph )
36
+ inferred_values = ir .convenience .create_value_mapping (inferred_graph )
37
+ for name , value in original_values .items ():
38
+ if name in inferred_values :
39
+ inferred_value = inferred_values [name ]
40
+ if value .shape != inferred_value .shape and inferred_value .shape is not None :
41
+ value .shape = inferred_value .shape
42
+ modified = True
43
+ if value .dtype != inferred_value .dtype and inferred_value .dtype is not None :
44
+ value .dtype = inferred_value .dtype
45
+ modified = True
46
+ else :
47
+ logger .warning (
48
+ "Value %s not found in inferred graph %s" , name , inferred_graph .name
49
+ )
50
+ return modified
51
+
52
+
53
+ class ShapeInferencePass (ir .passes .InPlacePass ):
26
54
"""This pass performs shape inference on the graph."""
27
55
28
56
def __init__ (
29
57
self , check_type : bool = True , strict_mode : bool = True , data_prop : bool = True
30
58
) -> None :
31
59
"""Initialize the shape inference pass.
32
60
61
+ If inference fails, the model is left unchanged.
62
+
33
63
Args:
34
64
check_type: If True, check the types of the inputs and outputs.
35
65
strict_mode: If True, use strict mode for shape inference.
@@ -41,75 +71,22 @@ def __init__(
41
71
self .data_prop = data_prop
42
72
43
73
def call (self , model : ir .Model ) -> ir .passes .PassResult :
44
- # Store the original initializer values so they can be restored
45
- initializer_values = tuple (model .graph .initializers .values ())
46
- tensors = {v .name : v .const_value for v in initializer_values }
47
- original_inputs_len = len (model .graph .inputs )
48
- initializer_names = {v .name for v in initializer_values }
49
-
50
- # Turn the initializers into inputs and clear the initializers
51
- # to limit the model size
52
- for initializer in initializer_values :
53
- # Make sure the initializer has its shape/type set
54
- assert initializer .const_value is not None
55
- if initializer .shape is None :
56
- initializer .shape = initializer .const_value .shape # type: ignore[assignment]
57
- if initializer .dtype is None :
58
- initializer .dtype = initializer .const_value .dtype
59
- if initializer not in model .graph .inputs :
60
- model .graph .inputs .append (initializer )
61
- if initializer .const_value .nbytes > _BIG_TENSOR_SIZE_LIMIT :
62
- # Temporarily remove the initializer value to reduce model size
63
- # for onnx.shape_inference
64
- initializer .const_value = None
65
- assert initializer .name is not None
66
- model .graph .initializers .pop (initializer .name )
67
-
68
- # Perform shape inference
69
- try :
70
- proto = ir .serde .serialize_model (model )
71
- value_infos = {info .name : info for info in proto .graph .value_info }
72
- inferred_proto = onnx .shape_inference .infer_shapes (
74
+ def partial_infer_shapes (proto : onnx .ModelProto ) -> onnx .ModelProto :
75
+ return onnx .shape_inference .infer_shapes (
73
76
proto ,
74
77
check_type = self .check_type ,
75
78
strict_mode = self .strict_mode ,
76
79
data_prop = self .data_prop ,
77
80
)
78
- inferred_value_infos = {
79
- info .name : info for info in inferred_proto .graph .value_info
80
- }
81
- inferred_model = ir .serde .deserialize_model (inferred_proto )
82
-
83
- except Exception : # pylint: disable=broad-exception-caught
84
- logger .warning ("Shape inference failed. The model is not modified" , exc_info = True )
85
- return ir .passes .PassResult (model , modified = False )
86
- finally :
87
- # Restore the original initializer values so the model is unchanged
88
- for initializer in initializer_values :
89
- if initializer .name in initializer_names :
90
- initializer .const_value = tensors [initializer .name ]
91
- model .graph .register_initializer (initializer )
92
-
93
- # Restore the original inputs
94
- inputs = model .graph .inputs [:original_inputs_len ]
95
- model .graph .inputs .clear ()
96
- model .graph .inputs .extend (inputs )
97
-
98
- # Add the original initializer tensors to the new (inferred) model
99
- for new_input in inferred_model .graph .inputs :
100
- # Assign the tensors back to the initializers
101
- if new_input .name in initializer_names :
102
- new_input .const_value = tensors [new_input .name ]
103
- inferred_model .graph .register_initializer (new_input )
104
-
105
- # Remove the inputs that were added
106
- new_inputs = inferred_model .graph .inputs [:original_inputs_len ]
107
- inferred_model .graph .inputs .clear ()
108
- inferred_model .graph .inputs .extend (new_inputs )
109
-
110
- return ir .passes .PassResult (
111
- inferred_model , modified = value_infos != inferred_value_infos
112
- )
81
+
82
+ try :
83
+ inferred_model_proto = _c_api_utils .call_onnx_api (partial_infer_shapes , model )
84
+ except Exception as e : # pylint: disable=broad-exception-caught
85
+ logger .warning ("Shape inference failed: %s. Model is left unchanged" , exc_info = e )
86
+ return ir .passes .PassResult (model , False )
87
+
88
+ modified = _merge_func (model , inferred_model_proto )
89
+ return ir .passes .PassResult (model , modified = modified )
113
90
114
91
115
92
def infer_shapes (
0 commit comments