Skip to content

Commit 7aba165

Browse files
authored
Cleanup mha-bias rules using disjunction (#2326)
The MHA-Bias rules can be simplified using pattern-disjunction. (This _may_ help with Whisper ... that was my original motivation, but not sure, after I fixed another issue in PR #2325, which may be the primary issue ). But the cleanup is useful anyway, and it makes fusion more efficient.) Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 0288a66 commit 7aba165

File tree

1 file changed

+26
-53
lines changed

1 file changed

+26
-53
lines changed

onnxscript/rewriter/ort_fusions/fuse_mha_bias.py

Lines changed: 26 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,6 @@
1515

1616

1717
class FuseBiasMHA(pattern.RewriteRuleClassBase):
18-
def __init__(
19-
self,
20-
name,
21-
*,
22-
q_no_bias: bool,
23-
k_no_bias: bool,
24-
v_no_bias: bool,
25-
):
26-
super().__init__(name)
27-
self._q_no_bias = q_no_bias
28-
self._k_no_bias = k_no_bias
29-
self._v_no_bias = v_no_bias
30-
3118
def pattern(
3219
self,
3320
op,
@@ -43,18 +30,21 @@ def pattern(
4330
num_heads,
4431
# scale,
4532
):
46-
if not self._q_no_bias:
47-
query_BSD = op.Add(query_matmul, q_bias)
48-
else:
49-
query_BSD = query_matmul
50-
if not self._k_no_bias:
51-
key_BSD = op.Add(key_matmul, k_bias)
52-
else:
53-
key_BSD = key_matmul
54-
if not self._v_no_bias:
55-
value_BSD = op.Add(value_matmul, v_bias)
56-
else:
57-
value_BSD = value_matmul
33+
query_BSD = pattern.OrValue(
34+
[op.Add(query_matmul, q_bias), query_matmul],
35+
tag_var="has_q_bias",
36+
tag_values=[True, False],
37+
)
38+
key_BSD = pattern.OrValue(
39+
[op.Add(key_matmul, k_bias), key_matmul],
40+
tag_var="has_k_bias",
41+
tag_values=[True, False],
42+
)
43+
value_BSD = pattern.OrValue(
44+
[op.Add(value_matmul, v_bias), value_matmul],
45+
tag_var="has_v_bias",
46+
tag_values=[True, False],
47+
)
5848

5949
return op.MultiHeadAttention(
6050
query_BSD,
@@ -72,14 +62,20 @@ def pattern(
7262

7363
def check(
7464
self,
75-
op,
65+
context,
7666
query_matmul,
7767
key_matmul,
7868
value_matmul,
69+
has_q_bias,
70+
has_k_bias,
71+
has_v_bias,
7972
**_,
8073
) -> pattern.MatchResult: # type: ignore[name-defined]
8174
check_result = pattern.MatchResult()
8275

76+
if not (has_q_bias or has_k_bias or has_v_bias):
77+
return check_result.fail("None of query, key, or value have a bias.")
78+
8379
self.bindings: dict[str, Dim] = {}
8480

8581
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -139,15 +135,15 @@ def rewrite(
139135
# scale,
140136
**_,
141137
):
142-
if self._q_no_bias:
138+
if q_bias is None:
143139
q_bias = op.Constant(
144140
value=ir.tensor(numpy.zeros((self.Dh_q,), dtype=query_matmul.dtype.numpy()))
145141
)
146-
if self._k_no_bias:
142+
if k_bias is None:
147143
k_bias = op.Constant(
148144
value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=key_matmul.dtype.numpy()))
149145
)
150-
if self._v_no_bias:
146+
if v_bias is None:
151147
v_bias = op.Constant(
152148
value=ir.tensor(numpy.zeros((self.Dh_v,), dtype=value_matmul.dtype.numpy()))
153149
)
@@ -167,30 +163,7 @@ def rewrite(
167163
)
168164

169165

170-
parameter_combinations = [
171-
{
172-
"q_no_bias": q_no_bias,
173-
"k_no_bias": k_no_bias,
174-
"v_no_bias": v_no_bias,
175-
}
176-
for q_no_bias in [False, True]
177-
for k_no_bias in [False, True]
178-
for v_no_bias in [False, True]
179-
]
180-
181-
# Dynamically create the rules
182-
fuse_mha_bias_rules = pattern.RewriteRuleSet(
183-
[
184-
FuseBiasMHA.rule(
185-
f"MHABias{'_NoQBias' if params['q_no_bias'] else ''}"
186-
f"{'_NoKBias' if params['k_no_bias'] else ''}"
187-
f"{'_NoVBias' if params['v_no_bias'] else ''}",
188-
**params,
189-
)
190-
# Exclude (True, True, True) as it is an unnecessary case
191-
for params in parameter_combinations[:-1]
192-
]
193-
)
166+
fuse_mha_bias_rules = pattern.RewriteRuleSet([FuseBiasMHA.rule()])
194167

195168

196169
fuse_mha_bias = _fusion_utils.apply_fusion_rules(fuse_mha_bias_rules)

0 commit comments

Comments
 (0)