15
15
16
16
17
17
class FuseBiasMHA (pattern .RewriteRuleClassBase ):
18
- def __init__ (
19
- self ,
20
- name ,
21
- * ,
22
- q_no_bias : bool ,
23
- k_no_bias : bool ,
24
- v_no_bias : bool ,
25
- ):
26
- super ().__init__ (name )
27
- self ._q_no_bias = q_no_bias
28
- self ._k_no_bias = k_no_bias
29
- self ._v_no_bias = v_no_bias
30
-
31
18
def pattern (
32
19
self ,
33
20
op ,
@@ -43,18 +30,21 @@ def pattern(
43
30
num_heads ,
44
31
# scale,
45
32
):
46
- if not self ._q_no_bias :
47
- query_BSD = op .Add (query_matmul , q_bias )
48
- else :
49
- query_BSD = query_matmul
50
- if not self ._k_no_bias :
51
- key_BSD = op .Add (key_matmul , k_bias )
52
- else :
53
- key_BSD = key_matmul
54
- if not self ._v_no_bias :
55
- value_BSD = op .Add (value_matmul , v_bias )
56
- else :
57
- value_BSD = value_matmul
33
+ query_BSD = pattern .OrValue (
34
+ [op .Add (query_matmul , q_bias ), query_matmul ],
35
+ tag_var = "has_q_bias" ,
36
+ tag_values = [True , False ],
37
+ )
38
+ key_BSD = pattern .OrValue (
39
+ [op .Add (key_matmul , k_bias ), key_matmul ],
40
+ tag_var = "has_k_bias" ,
41
+ tag_values = [True , False ],
42
+ )
43
+ value_BSD = pattern .OrValue (
44
+ [op .Add (value_matmul , v_bias ), value_matmul ],
45
+ tag_var = "has_v_bias" ,
46
+ tag_values = [True , False ],
47
+ )
58
48
59
49
return op .MultiHeadAttention (
60
50
query_BSD ,
@@ -72,14 +62,20 @@ def pattern(
72
62
73
63
def check (
74
64
self ,
75
- op ,
65
+ context ,
76
66
query_matmul ,
77
67
key_matmul ,
78
68
value_matmul ,
69
+ has_q_bias ,
70
+ has_k_bias ,
71
+ has_v_bias ,
79
72
** _ ,
80
73
) -> pattern .MatchResult : # type: ignore[name-defined]
81
74
check_result = pattern .MatchResult ()
82
75
76
+ if not (has_q_bias or has_k_bias or has_v_bias ):
77
+ return check_result .fail ("None of query, key, or value have a bias." )
78
+
83
79
self .bindings : dict [str , Dim ] = {}
84
80
85
81
def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
@@ -139,15 +135,15 @@ def rewrite(
139
135
# scale,
140
136
** _ ,
141
137
):
142
- if self . _q_no_bias :
138
+ if q_bias is None :
143
139
q_bias = op .Constant (
144
140
value = ir .tensor (numpy .zeros ((self .Dh_q ,), dtype = query_matmul .dtype .numpy ()))
145
141
)
146
- if self . _k_no_bias :
142
+ if k_bias is None :
147
143
k_bias = op .Constant (
148
144
value = ir .tensor (numpy .zeros ((self .Dh_k ,), dtype = key_matmul .dtype .numpy ()))
149
145
)
150
- if self . _v_no_bias :
146
+ if v_bias is None :
151
147
v_bias = op .Constant (
152
148
value = ir .tensor (numpy .zeros ((self .Dh_v ,), dtype = value_matmul .dtype .numpy ()))
153
149
)
@@ -167,30 +163,7 @@ def rewrite(
167
163
)
168
164
169
165
170
- parameter_combinations = [
171
- {
172
- "q_no_bias" : q_no_bias ,
173
- "k_no_bias" : k_no_bias ,
174
- "v_no_bias" : v_no_bias ,
175
- }
176
- for q_no_bias in [False , True ]
177
- for k_no_bias in [False , True ]
178
- for v_no_bias in [False , True ]
179
- ]
180
-
181
- # Dynamically create the rules
182
- fuse_mha_bias_rules = pattern .RewriteRuleSet (
183
- [
184
- FuseBiasMHA .rule (
185
- f"MHABias{ '_NoQBias' if params ['q_no_bias' ] else '' } "
186
- f"{ '_NoKBias' if params ['k_no_bias' ] else '' } "
187
- f"{ '_NoVBias' if params ['v_no_bias' ] else '' } " ,
188
- ** params ,
189
- )
190
- # Exclude (True, True, True) as it is an unnecessary case
191
- for params in parameter_combinations [:- 1 ]
192
- ]
193
- )
166
+ fuse_mha_bias_rules = pattern .RewriteRuleSet ([FuseBiasMHA .rule ()])
194
167
195
168
196
169
fuse_mha_bias = _fusion_utils .apply_fusion_rules (fuse_mha_bias_rules )
0 commit comments