
Commit fc0523d

fix skip_normalization fusion
1 parent c5c96af commit fc0523d

File tree

4 files changed: +27 -12 lines changed

  onnxscript/rewriter/ort_fusions/_core.py
  onnxscript/rewriter/ort_fusions/_whisper_tiny.py
  onnxscript/rewriter/ort_fusions/attention.py
  onnxscript/rewriter/ort_fusions/skip_normalization.py

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 10 additions & 0 deletions
@@ -122,3 +122,13 @@ def optimize_for_ort(
     model, fusion_count = fuse_xformers(model)
     rewrite(model, ORT_PATTERN_REWRITE_RULES)
     return model, fusion_count
+
+'''
+from onnxscript import ir, rewriter
+import onnxscript.rewriter.ort_fusions as ort_fusions
+model_ir = ir.serde.deserialize_model(model)
+model_ir, count = ort_fusions.optimize_for_ort(model_ir)
+print("Applied fusions", count)
+print("\n\n\n\n\n\n\n\n\n\n\n")
+model = ir.serde.serialize_model(model_ir)
+'''
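
The block added above is a commented-out scratch snippet showing how the fusion entry point is driven. For context, a minimal end-to-end sketch (not part of the commit; the file paths are placeholders) that loads an ONNX model, applies optimize_for_ort, and saves the result:

import onnx

from onnxscript import ir
import onnxscript.rewriter.ort_fusions as ort_fusions

# Load the ONNX protobuf and convert it into the onnxscript IR.
model_proto = onnx.load("model.onnx")  # placeholder input path
model_ir = ir.serde.deserialize_model(model_proto)

# Apply the ORT fusions; a dict of per-fusion counts is returned with the model.
model_ir, fusion_count = ort_fusions.optimize_for_ort(model_ir)
print("Applied fusions:", fusion_count)

# Convert back to protobuf and save the optimized model.
onnx.save(ir.serde.serialize_model(model_ir), "model_fused.onnx")  # placeholder output path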

onnxscript/rewriter/ort_fusions/_whisper_tiny.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def test_whisper_tiny(self):
         encoder_model.opset_imports["ai.onnxruntime.fusion"] = 1

         print(f"Fused {fusion_count_e} ops")
-        self.assertEqual(fusion_count_e["skip_layer_normalization"], 17)
+        self.assertEqual(fusion_count_e["skip_layer_normalization"], 8)
         self.assertEqual(fusion_count_e["sdpa"], 4)
         self.assertEqual(fusion_count_e["mha"], 4)
         self.assertEqual(fusion_count_e["attention"], 4)
@@ -67,7 +67,7 @@ def test_whisper_tiny(self):
         decoder_model.opset_imports["ai.onnxruntime.fusion"] = 1

         print(f"Fused {fusion_count_d} ops")
-        self.assertEqual(fusion_count_d["skip_layer_normalization"], 25)
+        self.assertEqual(fusion_count_d["skip_layer_normalization"], 12)
         self.assertEqual(fusion_count_d["sdpa"], 8)
         # 4 self-attention + 4 cross-attention
         self.assertEqual(fusion_count_d["mha"], 8)

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 5 additions & 0 deletions
@@ -198,6 +198,9 @@ def rewrite(
         # attention_bias,
         num_heads,
         # scale,
+        q_mul=None,
+        k_mul=None,
+        v_mul=None,
         **_,
     ):
         # Use bindings to get the values of Dh_q, Dh_k, and Dh_v
@@ -206,6 +209,8 @@ def rewrite(
         # Dh_k = self.bindings.get("Dh_k")
         # Dh_v = self.bindings.get("Dh_v")
         # qkv_hidden_sizes = [Dh_q, Dh_k, Dh_v]
+        if self._no_slice:
+            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=0)

         if self._has_past:
             attention, present = op.Attention(
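
The new _no_slice branch packs the separate Q, K, and V projection weights into a single tensor for the fused Attention op with one Concat. As an illustrative aside (not the rewriter's code; the packing axis depends on how the captured weights are laid out, and this sketch assumes a plain x @ W convention with packing along the output axis), a numpy check that a single matmul against packed weights reproduces the three separate projections:

import numpy as np

rng = np.random.default_rng(0)
hidden = 64  # illustrative hidden size

x = rng.standard_normal((2, 5, hidden))   # (batch, seq, hidden)
w_q = rng.standard_normal((hidden, hidden))
w_k = rng.standard_normal((hidden, hidden))
w_v = rng.standard_normal((hidden, hidden))

# Separate projections, as produced by the un-fused Q/K/V MatMuls.
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Packed projection: concatenate the weights along the output axis,
# run one matmul, then split the result back into Q, K, V.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
q2, k2, v2 = np.split(x @ w_qkv, 3, axis=-1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)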

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 10 additions & 10 deletions
@@ -47,35 +47,35 @@ def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
         epsilon=epsilon,
         stash_type=stash_type,
     )
-    return normalized, skip_sum
+    return normalized


 def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type):
     if stash_type.value != 1: # FLOAT type
         return None
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+    normalized, _mean, _inv_std_var = op.SkipLayerNormalization(
         input,
         skip,
         gamma,
         beta,
         epsilon=epsilon,
-        _outputs=4,
+        _outputs=3,
         _domain="com.microsoft",
     )
-    return normalized, skip_sum
+    return normalized


 # Fusion rule for Add + SkipLayerNormalization
 def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type):
-    bias_sum = op.Add(input, bias)
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        bias_sum,
-        skip,
+    input_with_bias = op.Add(input, bias)
+    skip_sum = op.Add(skip, input_with_bias)
+    normalized = op.LayerNormalization(
+        skip_sum,
         gamma,
         beta,
+        axis=-1,
         epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
+        stash_type=stash_type,
     )
     return normalized, skip_sum
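
With this change the add-bias pattern matches two plain Adds feeding a standard LayerNormalization and yields both the normalized output and the intermediate skip sum, rather than relying on the optional extra output of an already-fused SkipLayerNormalization. As an illustrative aside (not the rewriter's code), a small numpy sketch of the arithmetic this pattern describes:

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # LayerNormalization over the last axis, matching axis=-1 in the pattern above.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
hidden = 8
x = rng.standard_normal((2, 4, hidden))      # layer input
skip = rng.standard_normal((2, 4, hidden))   # residual / skip connection
bias = rng.standard_normal(hidden)
gamma = rng.standard_normal(hidden)
beta = rng.standard_normal(hidden)

# What the pattern matches: Add (bias), Add (skip), then LayerNormalization.
input_with_bias = x + bias
skip_sum = skip + input_with_bias
normalized = layer_norm(skip_sum, gamma, beta)

# The skip sum is simply input + bias + skip, regardless of the Add order;
# the fused com.microsoft SkipLayerNormalization computes the same quantities.
assert np.allclose(skip_sum, x + bias + skip)
print(normalized.shape, skip_sum.shape)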
