@@ -47,17 +47,18 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
    # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
    # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
    # incorporated in our optimizer.
-    model = shape_inference.infer_shapes(model)
+    shape_inference.infer_shapes(model)
    optimize(model)
    return model


-def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
+def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
    """
    Apply transformer-specific fusions to the given model.

    Args:
        model: The input ONNX model represented as an `ir.Model`.
+        debug: If debug is True, enable pattern matching tracer for debugging.

    Returns:
        A tuple containing:
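
For context, a minimal sketch of how a caller might exercise the new debug flag. The import path and the ir.load helper are assumptions for illustration, not part of this diff, and the model path is made up:

from onnxscript import ir
from onnxscript.rewriter.ort_fusions._core import fuse_xformers  # assumed import path

model = ir.load("model.onnx")  # illustrative path; assumes ir.load is available
# debug=True turns on the pattern-matching tracer so failed matches can be inspected.
model, fusion_count = fuse_xformers(model, debug=True)
# fusion_count maps each fusion name (e.g. "sdpa", "mha", "gqa") to how many times it fired.
print(fusion_count)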
@@ -67,35 +68,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
    fusion_count = dict()

    model = _pre_optimize(model)
-    fusion_count["erf_gelu"] = fuse_erfgelu(model)
-    fusion_count["rms_normalization"] = fuse_rms_normalization(model)
-    fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
-    fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
-    fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
-    fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
-    fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
-    fusion_count["sdpa"] = fuse_sdpa(model)
+
+    def fuse(func, apply_shape_inference: bool = False):
+        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+
+    fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
+    fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
+    fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
+    fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
+    fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
+    fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
+    fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
    # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse_mha(model)
+    fusion_count["mha"] = fuse(fuse_mha)
    if fusion_count["mha"] == 0:
        # If no MHA fusion was applied, we can try the GQA fusion.
        # and avoid trying the attention fusion.
-        fusion_count["gqa"] = fuse_gqa(model)
-        fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
+        fusion_count["gqa"] = fuse(fuse_gqa)
+        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
        fusion_count["attention"] = 0
    else:
-        fusion_count["attention"] = fuse_attention(model)
+        fusion_count["attention"] = fuse(fuse_attention)
        fusion_count["gqa"] = 0
-    fusion_count["gelu"] = fuse_gelu(model)
-    fusion_count["bias_gelu"] = fuse_bias_gelu(model)
+    fusion_count["gelu"] = fuse(fuse_gelu)
+    fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
    # Finally: inline any intermediate fusion functions introduced that were not
    # consumed by other fusions, and eliminate any remaining unused nodes.
    optimize(model)
    return model, fusion_count


def optimize_for_ort(
-    model: ir.Model, config_name: str | None = None
+    model: ir.Model,
+    config_name: str | None = None,
+    *,
+    debug: bool = False,
) -> tuple[ir.Model, dict[str, int]]:
    """
    Optimize the model for ORT backend.
@@ -108,13 +116,18 @@ def optimize_for_ort(
        config_name: The name of the configuration to use for optimization.
            Typically it identifies the Execution Provider (EP) to optimize for.
            If None, the default configuration will be used.
+        debug: If debug is True, enable pattern matching tracer for debugging.

    Returns:
        A tuple containing:
        - The optimized `ir.Model` after applying transformer-specific fusions.
        - A dictionary with a count of each of the fusions applied.
    """

-    model, fusion_count = fuse_xformers(model)
+    model, fusion_count = fuse_xformers(
+        model,
+        debug=debug,
+    )
+    # Apply the ORT pattern rewrite rules.
    rewrite(model, ORT_PATTERN_REWRITE_RULES)
    return model, fusion_count
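
Taken end to end, a minimal sketch of driving the public entry point with the new keyword. The module path and the ir.load/ir.save helpers are assumptions for illustration, and the file names are made up:

from onnxscript import ir
from onnxscript.rewriter.ort_fusions import optimize_for_ort  # assumed import path

model = ir.load("model.onnx")  # illustrative input path
# Runs fuse_xformers (with the tracer enabled) and then the ORT pattern rewrite rules.
model, fusion_count = optimize_for_ort(model, debug=True)
ir.save(model, "model_ort.onnx")  # illustrative output path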