5
5
from typing import Sequence , Union
6
6
7
7
import onnxscript .ir as ir
8
- from onnxscript .rewriter import _fusion_utils , pattern
8
+ from onnxscript .rewriter import _fusion_utils , _ir_utils , pattern
9
9
10
10
"""
11
11
The MultiHeadAttention pattern: generate an instance
@@ -37,13 +37,15 @@ def __init__(
37
37
name ,
38
38
* ,
39
39
transpose_4d : bool ,
40
+ pre_scale_q : bool ,
40
41
is_rotary : bool ,
41
42
use_mask : bool ,
42
43
has_past_present : bool ,
43
44
is_cross_attention : bool ,
44
45
):
45
46
super ().__init__ (name )
46
47
self ._transpose_4d = transpose_4d
48
+ self ._pre_scale_q = pre_scale_q
47
49
self ._is_rotary = is_rotary
48
50
self ._use_mask = use_mask
49
51
self ._has_past_present = has_past_present
@@ -64,9 +66,12 @@ def pattern(
64
66
position_ids ,
65
67
cos ,
66
68
sin ,
69
+ q_scale ,
67
70
):
68
71
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
69
72
73
+ if self ._pre_scale_q :
74
+ query_BSD = op .Mul (query_BSD , q_scale )
70
75
# Reshape from (B, S, D) to (B, S, H, D/H)
71
76
query_BSHDh = op .Reshape (
72
77
query_BSD ,
@@ -202,6 +207,8 @@ def check(
202
207
past_key ,
203
208
past_value ,
204
209
query_BSHDh ,
210
+ key_BSHDh = None ,
211
+ value_BSHDh = None ,
205
212
** _ ,
206
213
) -> pattern .MatchResult : # type: ignore[name-defined]
207
214
check_result = pattern .MatchResult ()
@@ -239,25 +246,24 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
239
246
f"Shape mismatch: { past_value } does not match expected dimensions ['B', 'H', 'Spast', 'Dv']" ,
240
247
past_value ,
241
248
)
242
- """
249
+
243
250
if no_match (query_BSHDh , ["B" , "S" , "H" , "Dh" ]):
244
251
return check_result .fail (
245
252
f"Shape mismatch: { query_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
246
253
query_BSHDh ,
247
254
)
248
-
249
- if not self.is_cross_attention :
250
- if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
255
+
256
+ if not self ._is_cross_attention :
257
+ if key_BSHDh and no_match (key_BSHDh , ["B" , "S" , "H" , "Dh" ]):
251
258
return check_result .fail (
252
259
f"Shape mismatch: { key_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
253
260
query_BSHDh ,
254
261
)
255
- if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
262
+ if value_BSHDh and no_match (value_BSHDh , ["B" , "S" , "H" , "Dh" ]):
256
263
return check_result .fail (
257
264
f"Shape mismatch: { value_BSHDh } does not match expected dimensions ['B', 'S', 'H', 'Dh']" ,
258
265
query_BSHDh ,
259
266
)
260
- """
261
267
262
268
# TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
263
269
# But this also, unforunately, depends on ORT version.
@@ -283,9 +289,7 @@ def rewrite(
283
289
sin ,
284
290
** _ ,
285
291
):
286
- num_heads = 64
287
- # TODO: (fix) Error caused by incorrect SDPA fusion for pre-scaling case
288
- # num_heads = _ir_utils.get_dim(query_BSHDh, 2)
292
+ num_heads = _ir_utils .get_dim (query_BSHDh , 2 )
289
293
if not isinstance (num_heads , int ):
290
294
return None
291
295
@@ -341,12 +345,14 @@ def rewrite(
341
345
parameter_combinations = [
342
346
{
343
347
"transpose_4d" : transpose_4d ,
348
+ "pre_scale_q" : pre_scale_q ,
344
349
"is_rotary" : is_rotary ,
345
350
"use_mask" : use_mask ,
346
351
"has_past_present" : has_past_present ,
347
352
"is_cross_attention" : is_cross_attention ,
348
353
}
349
354
for transpose_4d in [False , True ]
355
+ for pre_scale_q in [True , False ]
350
356
for is_rotary in [False , True ]
351
357
for use_mask in [False , True ]
352
358
for has_past_present in [False , True ]
@@ -358,6 +364,7 @@ def rewrite(
358
364
[
359
365
MultiHeadAttention .rule (
360
366
f"MHA_{ '4D' if params ['transpose_4d' ] else '3D' } _Transpose"
367
+ f"{ '_PreScaleQ' if params ['pre_scale_q' ] else '' } "
361
368
f"{ '_Rotary' if params ['is_rotary' ] else '' } "
362
369
f"{ '_Masked' if params ['use_mask' ] else '' } "
363
370
f"{ '_Past' if params ['has_past_present' ] else '' } "
0 commit comments