# Licensed under the MIT License.
from __future__ import annotations

-from typing import Sequence
+from typing import Sequence, Union

import onnxscript.ir as ir
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _ir_utils, pattern

"""
-The MultiHeadAttention pattern:
+The MultiHeadAttention pattern: generate an instance of
+   MHA(query, key, value, None, None, mask, past_key, past_value)
+where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
+The next two inputs, bias and key_padding_mask, are None in this pattern. The mask (attention_bias)
+must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).

+We use the following abbreviations for the dimensions:
B: Batch size
S: Sequence length
D: input embedding dimension
+Dv: value hidden size (usually, Dv = D)
H: number of heads
-d_h: head size (usually, D = H * d_h)
+Dh: head size or embedding dimension per head (usually, D = H * Dh)
+Skv: key/value sequence length
+St: total sequence length

-thus, weights are usually of shape (D, D) and (D, D) and (D, D)
-
-for each of Q, K, and V, we have the following pattern:
-   MatMul(Input, W), producing output of shape (B, S, D)
-   Reshape to produce a matrix of shape (B, S, H, d_h)
-   Transpose middle two axes to produce a matrix of shape (B, H, S, d_h)
-
-This is followed by a RotaryEmbedding pattern for Q and K
-
-The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)
-
-The dot-product attention is then computed using SDPA.
-Finally, the output is transposed and reshaped back to (B, S, D) shape
+In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
+The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
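+For example, query_BHSDh below denotes the query tensor after it has been reshaped and
+transposed to the shape (B, H, S, Dh).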
"""


+Dim = Union[int, ir.SymbolicDim]
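+# A bound dimension is either a concrete int or a symbolic dimension
+# (ir.SymbolicDim) when the model uses named dynamic axes.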

-def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:
+
+def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool:
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):
@@ -46,131 +45,170 @@ def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str])


class MultiHeadAttention(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, use_2d_matmul: bool):
-        super().__init__(name)
-        self._use_2d_matmul = use_2d_matmul
-
-    def _compute_QKV(self, op, input, weight, reshape_var: str):
-        """Applied to generate each of Q, K, and V from input."""
-        if self._use_2d_matmul:
-            # Convert batched input of shape (B, S, D) to 2D input (B*S, D)
-            input = op.Reshape(input, _allow_other_inputs=True)
-        projected = op.MatMul(input, weight)
-        if self._use_2d_matmul:
-            # Convert 2D output back to batched output of shape (B, S, D)
-            projected = op.Reshape(projected, _allow_other_inputs=True)
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        reshaped = op.Reshape(
-            projected,
-            _allow_other_inputs=True,
-            _allow_other_attributes=True,
-            _outputs=[reshape_var],
-        )
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
-        return transposed
+    def __init__(self):
+        super().__init__("MHA")

    def pattern(
        self,
        op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
-        qkv_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
        position_ids,
+        cos,
+        sin,
    ):
-        query = self._compute_QKV(op, input, query_weight, "query_mm_reshaped")
-        key = self._compute_QKV(op, input, key_weight, "key_mm_reshaped")
-        value = self._compute_QKV(op, input, value_weight, "value_mm_reshaped")
+        # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        query_BSHDh = op.Reshape(
+            query_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["query_BSHDh"],
+        )
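+        # (_allow_other_inputs / _allow_other_attributes let the match succeed
+        # regardless of the Reshape's shape operand and any extra attributes.)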
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        key_BSHDh = op.Reshape(
+            key_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["key_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape from (B, S, D) to (B, S, H, D/H)
+        value_BSHDh = op.Reshape(
+            value_BSD,
+            _allow_other_inputs=True,
+            _allow_other_attributes=True,
+            _outputs=["value_BSHDh"],
+        )
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+
+        query_BHSDh_rope = op.RotaryEmbedding(
+            query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BHSDh_rope = op.RotaryEmbedding(
+            key_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
+        )
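+        # Rotary position embeddings are applied to query and key only; this matches
+        # the com.microsoft RotaryEmbedding contrib op on 4D (B, H, S, Dh) inputs.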

-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
-
-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.Concat(past_key, key_rope, axis=-2)
-        # Transpose last two axes of key_rope to compute dot-product via matmul.
-        key_reshaped = op.Reshape(
-            key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]
+        # Concatenate past_key cache and current key, and transpose to enable
+        # dot-product attention computation.
+        key_seq = op.Concat(past_key, key_BHSDh_rope, axis=-2)
+        # Transpose last two axes of key_seq to compute dot-product via matmul.
+        key_seq_BH_Skv_Dh = op.Reshape(
+            key_seq, _allow_other_inputs=True, _outputs=["key_seq_BH_Skv_Dh"]
        )
-        key_reshaped_transposed = op.Transpose(key_reshaped, perm=[0, 2, 1])
-        key_transposed = op.Reshape(
-            key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]
+        key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
+        key_seq_B_H_Dh_Skv = op.Reshape(
+            key_seq_BH_Dh_Skv, _allow_other_inputs=True, _outputs=["key_seq_B_H_Dh_Skv"]
        )

-        value = op.Concat(past_value, value, axis=-2)
+        # Concatenate past_value cache and current value
+        value_seq = op.Concat(past_value, value_BHSDh, axis=-2)

        attention = op.SDPA(
-            query_rope, key_transposed, value, mask, _domain="ai.onnxruntime.fusion"
+            query_BHSDh_rope,
+            key_seq_B_H_Dh_Skv,
+            value_seq,
+            mask,
+            _domain="ai.onnxruntime.fusion",
        )
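+        # (SDPA here is assumed to be the intermediate fused scaled-dot-product-attention
+        # op produced by an earlier SDPA fusion rule, not a standard ONNX op.)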
-        # Transpose back to (B, S, H, D/H)
+
+        # Transpose attention back to (B, S, H, D/H)
        attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
        # Reshape back to (B, S, D)
        attention_reshaped = op.Reshape(
            attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]
        )
-        return attention_reshaped, key_rope, value
+        return attention_reshaped, key_seq, value_seq
117
134
118
135
def check (
119
136
self ,
120
137
op ,
121
- query_mm_reshaped ,
122
- key_mm_reshaped ,
123
- value_mm_reshaped ,
124
- key_reshaped ,
125
- key_transposed ,
126
- attention_reshaped ,
138
+ query_BSD ,
139
+ key_BSD ,
140
+ value_BSD ,
141
+ mask ,
142
+ past_key ,
143
+ past_value ,
144
+ query_BSHDh ,
145
+ key_BSHDh ,
146
+ value_BSHDh ,
127
147
** _ ,
128
148
):
129
- bindings : dict [str , int ] = {}
130
- status = (
131
- _check_shape ( bindings , query_mm_reshaped , [ "B" , "S" , "H" , "d_h" ])
132
- and _check_shape (bindings , key_mm_reshaped , [ "B" , "S" , "H" , "d_h" ] )
133
- and _check_shape ( bindings , value_mm_reshaped , [ "B" , "S" , "H" , "d_h" ])
134
- and _check_shape ( bindings , key_reshaped , ["B*H " , "KVS " , "d_h " ])
135
- and _check_shape ( bindings , key_transposed , [ "B" , "H" , "d_h" , "KVS" ])
136
- and _check_shape ( bindings , attention_reshaped , ["B" , "S " , "H*d_h " ])
137
- )
138
- if not status :
149
+ bindings : dict [str , Dim ] = {}
150
+
151
+ def no_match ( val : ir . Value , dims : Sequence [ str ]) -> bool :
152
+ return not _check_shape (bindings , val , dims )
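+        # _check_shape binds each dimension name on first use and requires later
+        # occurrences to match, so e.g. "B" must be the same across all inputs.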
+
+        if no_match(query_BSD, ["B", "S", "D"]):
+            return False
+        if no_match(key_BSD, ["B", "Skv", "D"]):
+            return False
+        if no_match(value_BSD, ["B", "Skv", "D"]):
            return False
-        # if bindings["B"] * bindings["H"] != bindings["B*H"]:
-        #     return False
-        # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
-        #     return False
+
+        if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+            return False
+        if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            return False
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            return False
+        # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St).
+        # Unfortunately, this also depends on the ORT version.
+
+        # TODO: verify the Reshapes:
+        # e.g., verify bindings["B"] * bindings["H"] == bindings["B*H"]
+        # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"],
+        # or check the Reshape's shape-input value.
        return True

    def rewrite(
        self,
        op,
-        input,
-        query_weight,
-        key_weight,
-        value_weight,
+        query_BSD,
+        key_BSD,
+        value_BSD,
        mask,
-        cos,
-        sin,
        past_key,
        past_value,
+        key_BSHDh,
        position_ids,
-        query_mm_reshaped,
+        cos,
+        sin,
        **_,
    ):
-        num_heads = query_mm_reshaped.shape[2]
-        query = op.MatMul(input, query_weight)
-        key = op.MatMul(input, key_weight)
-        value = op.MatMul(input, value_weight)
-
-        query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft")
-        key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft")
+        num_heads = _ir_utils.get_dim(key_BSHDh, 2)
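+        # MultiHeadAttention needs num_heads as a static attribute, so bail out
+        # (returning None means no rewrite) if the head dimension is symbolic.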
+        if not isinstance(num_heads, int):
+            return None
+
+        # Switch to 3D RotaryEmbedding
+        # TODO: forward other attributes
+        query_BSD_rope = op.RotaryEmbedding(
+            query_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )
+        key_BSD_rope = op.RotaryEmbedding(
+            key_BSD, position_ids, cos, sin, _domain="com.microsoft"
+        )

        return op.MultiHeadAttention(
-            query_rope,
-            key_rope,
-            value,
+            query_BSD_rope,
+            key_BSD_rope,
+            value_BSD,
            None,  # bias
            None,  # key padding mask
            mask,  # attention mask/bias
@@ -182,11 +220,15 @@ def rewrite(
        )


-_rule1 = MultiHeadAttention.rule("MHA_2dmm", use_2d_matmul=False)
+_rule1 = MultiHeadAttention.rule()

mha_rules = pattern.RewriteRuleSet([_rule1])


-def fuse_mha(model: ir.Model) -> int:
+def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:
    count = mha_rules.apply_to_model(model)
+    if debug and count == 0:
+        tracer = pattern.MatchingTracer()
+        mha_rules.apply_to_model(model, tracer=tracer)
+        tracer.report()
    return count
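+
+
+# A minimal usage sketch (illustrative only; "model.onnx" is a placeholder path):
+#
+#     model = ir.load("model.onnx")
+#     count = fuse_mha(model, debug=True)  # with debug, report near-misses if nothing fused
+#     print(f"Fused {count} MHA subgraph(s)")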