Skip to content

Commit da5cc79

Browse files
Add SDPA rules and tests for unmasked patterns (#2123)
- Add SDPA rules and tests for unmasked patterns
- Fix parameterized import
1 parent aa62570 commit da5cc79

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,21 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
7373
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
7474

7575

76+
# Rules for SDPA without mask
77+
unmasked_pre_div_sdpa_rule = SDPA.rule(
78+
"unmasked_pre_div_sdpa", use_mask=False, pre_scale=True, use_mul=False
79+
)
80+
unmasked_pre_mul_sdpa_rule = SDPA.rule(
81+
"unmasked_pre_mul_sdpa", use_mask=False, pre_scale=True, use_mul=True
82+
)
83+
unmasked_post_div_sdpa_rule = SDPA.rule(
84+
"unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False
85+
)
86+
unmasked_post_mul_sdpa_rule = SDPA.rule(
87+
"unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True
88+
)
89+
90+
# Rules for SDPA with mask
7691
masked_pre_div_sdpa_rule = SDPA.rule(
7792
"masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False
7893
)
@@ -83,11 +98,15 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
8398
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
8499
)
85100
masked_post_mul_sdpa_rule = SDPA.rule(
86-
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True
101+
"masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True
87102
)
88103

89104
sdpa_rules = pattern.RewriteRuleSet(
90105
[
106+
unmasked_pre_mul_sdpa_rule,
107+
unmasked_post_div_sdpa_rule,
108+
unmasked_post_mul_sdpa_rule,
109+
unmasked_pre_div_sdpa_rule,
91110
masked_pre_mul_sdpa_rule,
92111
masked_post_div_sdpa_rule,
93112
masked_post_mul_sdpa_rule,

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import unittest
1010

1111
import numpy
12-
from parameterized import parameterized
12+
import parameterized
1313

1414
import onnxscript.ir as ir
1515
import onnxscript.optimizer
@@ -28,6 +28,52 @@
2828
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
2929

3030

31+
@script()
32+
def _unmasked_pre_div_sdpa_script(query, key, value, mask):
33+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
34+
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
35+
scaled_query = op.Div(query, divisor)
36+
scaled_key = op.Div(key_transposed, divisor)
37+
attn_score = op.MatMul(scaled_query, scaled_key)
38+
attn_weight = op.Softmax(attn_score, axis=-1)
39+
attn_output = op.MatMul(attn_weight, value)
40+
return attn_output
41+
42+
43+
@script()
44+
def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
45+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
46+
multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
47+
scaled_query = op.Mul(query, multiplier)
48+
scaled_key = op.Mul(key_transposed, multiplier)
49+
attn_score = op.MatMul(scaled_query, scaled_key)
50+
attn_weight = op.Softmax(attn_score, axis=-1)
51+
attn_output = op.MatMul(attn_weight, value)
52+
return attn_output
53+
54+
55+
@script()
56+
def _unmasked_post_div_sdpa_script(query, key, value, mask):
57+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
58+
divisor = op.Constant(value_float=SCALE_FACTOR)
59+
attn_score = op.MatMul(query, key_transposed)
60+
scaled_attn_score = op.Div(attn_score, divisor)
61+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
62+
attn_output = op.MatMul(attn_weight, value)
63+
return attn_output
64+
65+
66+
@script()
67+
def _unmasked_post_mul_sdpa_script(query, key, value, mask):
68+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
69+
multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
70+
attn_score = op.MatMul(query, key_transposed)
71+
scaled_attn_score = op.Mul(attn_score, multiplier)
72+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
73+
attn_output = op.MatMul(attn_weight, value)
74+
return attn_output
75+
76+
3177
@script()
3278
def _masked_pre_div_sdpa_script(query, key, value, mask):
3379
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -105,8 +151,12 @@ def get_ort_inputs(self):
105151

106152

107153
class TestSDPAFusion(unittest.TestCase):
108-
@parameterized.expand(
154+
@parameterized.parameterized.expand(
109155
[
156+
("unmasked_pre_div", _unmasked_pre_div_sdpa_script),
157+
("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script),
158+
("unmasked_post_div", _unmasked_post_div_sdpa_script),
159+
("unmasked_post_mul", _unmasked_post_mul_sdpa_script),
110160
("pre_div", _masked_pre_div_sdpa_script),
111161
("pre_mul", _masked_pre_mul_sdpa_script),
112162
("post_div", _masked_post_div_sdpa_script),

0 commit comments

Comments
 (0)