Skip to content

Commit e805383

Browse files
shubhambhokare1 authored and bmehta001 committed
Add SDPA rules and tests for unmasked patterns (microsoft#2123)
- Add SDPA rules and tests for unmasked patterns
- Fix parameterized import
1 parent da083ee commit e805383

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
6363
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
6464

6565

66+
# Rules for SDPA without mask
67+
unmasked_pre_div_sdpa_rule = SDPA.rule(
68+
"unmasked_pre_div_sdpa", use_mask=False, pre_scale=True, use_mul=False
69+
)
70+
unmasked_pre_mul_sdpa_rule = SDPA.rule(
71+
"unmasked_pre_mul_sdpa", use_mask=False, pre_scale=True, use_mul=True
72+
)
73+
unmasked_post_div_sdpa_rule = SDPA.rule(
74+
"unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False
75+
)
76+
unmasked_post_mul_sdpa_rule = SDPA.rule(
77+
"unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True
78+
)
79+
80+
# Rules for SDPA with mask
6681
masked_pre_div_sdpa_rule = SDPA.rule(
6782
"masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False
6883
)
@@ -73,10 +88,21 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
7388
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
7489
)
7590
masked_post_mul_sdpa_rule = SDPA.rule(
76-
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True
91+
"masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True
7792
)
7893

79-
sdpa_rules = pattern.RewriteRuleSet([masked_pre_mul_sdpa_rule, masked_post_div_sdpa_rule])
94+
sdpa_rules = pattern.RewriteRuleSet(
95+
[
96+
unmasked_pre_mul_sdpa_rule,
97+
unmasked_post_div_sdpa_rule,
98+
unmasked_post_mul_sdpa_rule,
99+
unmasked_pre_div_sdpa_rule,
100+
masked_pre_mul_sdpa_rule,
101+
masked_post_div_sdpa_rule,
102+
masked_post_mul_sdpa_rule,
103+
masked_pre_div_sdpa_rule,
104+
]
105+
)
80106

81107

82108
def fuse_sdpa(model: ir.Model) -> int:

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import unittest
1010

1111
import numpy
12-
from parameterized import parameterized
12+
import parameterized
1313

1414
import onnxscript.ir as ir
1515
import onnxscript.optimizer
@@ -28,6 +28,52 @@
2828
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
2929

3030

31+
@script()
32+
def _unmasked_pre_div_sdpa_script(query, key, value, mask):
33+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
34+
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
35+
scaled_query = op.Div(query, divisor)
36+
scaled_key = op.Div(key_transposed, divisor)
37+
attn_score = op.MatMul(scaled_query, scaled_key)
38+
attn_weight = op.Softmax(attn_score, axis=-1)
39+
attn_output = op.MatMul(attn_weight, value)
40+
return attn_output
41+
42+
43+
@script()
44+
def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
45+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
46+
multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
47+
scaled_query = op.Mul(query, multiplier)
48+
scaled_key = op.Mul(key_transposed, multiplier)
49+
attn_score = op.MatMul(scaled_query, scaled_key)
50+
attn_weight = op.Softmax(attn_score, axis=-1)
51+
attn_output = op.MatMul(attn_weight, value)
52+
return attn_output
53+
54+
55+
@script()
56+
def _unmasked_post_div_sdpa_script(query, key, value, mask):
57+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
58+
divisor = op.Constant(value_float=SCALE_FACTOR)
59+
attn_score = op.MatMul(query, key_transposed)
60+
scaled_attn_score = op.Div(attn_score, divisor)
61+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
62+
attn_output = op.MatMul(attn_weight, value)
63+
return attn_output
64+
65+
66+
@script()
67+
def _unmasked_post_mul_sdpa_script(query, key, value, mask):
68+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
69+
multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
70+
attn_score = op.MatMul(query, key_transposed)
71+
scaled_attn_score = op.Mul(attn_score, multiplier)
72+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
73+
attn_output = op.MatMul(attn_weight, value)
74+
return attn_output
75+
76+
3177
@script()
3278
def _masked_pre_div_sdpa_script(query, key, value, mask):
3379
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -105,8 +151,12 @@ def get_ort_inputs(self):
105151

106152

107153
class TestSDPAFusion(unittest.TestCase):
108-
@parameterized.expand(
154+
@parameterized.parameterized.expand(
109155
[
156+
("unmasked_pre_div", _unmasked_pre_div_sdpa_script),
157+
("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script),
158+
("unmasked_post_div", _unmasked_post_div_sdpa_script),
159+
("unmasked_post_mul", _unmasked_post_mul_sdpa_script),
110160
("pre_div", _masked_pre_div_sdpa_script),
111161
("pre_mul", _masked_pre_mul_sdpa_script),
112162
("post_div", _masked_post_div_sdpa_script),

0 commit comments

Comments
 (0)