
Commit c5c96af

fix pre_mul_q placement
1 parent 9c87a4c commit c5c96af

3 files changed: +45, -30 lines


onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
     fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
     fusion_count["sdpa"] = fuse_sdpa(model)
+    model = _pre_optimize(model)
     # Optimize to avoid trying multiple attention-based fusions
     fusion_count["mha"] = fuse_mha(model)
     if fusion_count["mha"] == 0:
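
Note: the added `model = _pre_optimize(model)` runs an optimization pass between SDPA fusion and the attention-based fusions, so `fuse_mha` matches against a cleaned-up graph. For orientation, a hedged usage sketch; `fuse_xformers` and its return type come from the hunk header above, while `ir.load` as the entry point is an assumption:

# Illustrative sketch, not from this commit: run the fusion pipeline and
# inspect the per-fusion counts it returns.
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions._core import fuse_xformers  # private module

model = ir.load("model.onnx")               # assumed way to obtain an ir.Model
model, fusion_count = fuse_xformers(model)  # returns (model, dict of fusion counts)
print(fusion_count["sdpa"], fusion_count["mha"])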

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 17 additions & 10 deletions
@@ -5,7 +5,7 @@
 from typing import Sequence, Union
 
 import onnxscript.ir as ir
-from onnxscript.rewriter import _fusion_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
 
 """
 The MultiHeadAttention pattern: generate an instance
@@ -37,13 +37,15 @@ def __init__(
         name,
         *,
         transpose_4d: bool,
+        pre_scale_q: bool,
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
         is_cross_attention: bool,
     ):
         super().__init__(name)
         self._transpose_4d = transpose_4d
+        self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
@@ -64,9 +66,12 @@ def pattern(
         position_ids,
         cos,
         sin,
+        q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
 
+        if self._pre_scale_q:
+            query_BSD = op.Mul(query_BSD, q_scale)
         # Reshape from (B, S, D) to (B, S, H, D/H)
         query_BSHDh = op.Reshape(
             query_BSD,
@@ -202,6 +207,8 @@ def check(
         past_key,
         past_value,
         query_BSHDh,
+        key_BSHDh=None,
+        value_BSHDh=None,
         **_,
     ) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
@@ -239,25 +246,24 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
                 past_value,
             )
-        """
+
         if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
             return check_result.fail(
                 f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
                 query_BSHDh,
             )
-
-        if not self.is_cross_attention:
-            if no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+
+        if not self._is_cross_attention:
+            if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
                 return check_result.fail(
                     f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
                     query_BSHDh,
                 )
-            if no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            if value_BSHDh and no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
                 return check_result.fail(
                     f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
                     query_BSHDh,
                 )
-        """
 
         # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
         # But this also, unforunately, depends on ORT version.
@@ -283,9 +289,7 @@ def rewrite(
         sin,
         **_,
     ):
-        num_heads = 64
-        # TODO: (fix) Error caused by incorrect SDPA fusion for pre-scaling case
-        # num_heads = _ir_utils.get_dim(query_BSHDh, 2)
+        num_heads = _ir_utils.get_dim(query_BSHDh, 2)
         if not isinstance(num_heads, int):
             return None
 
@@ -341,12 +345,14 @@ def rewrite(
 parameter_combinations = [
     {
         "transpose_4d": transpose_4d,
+        "pre_scale_q": pre_scale_q,
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
         "is_cross_attention": is_cross_attention,
     }
     for transpose_4d in [False, True]
+    for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
     for has_past_present in [False, True]
@@ -358,6 +364,7 @@ def rewrite(
     [
         MultiHeadAttention.rule(
             f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+            f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 27 additions & 20 deletions
@@ -102,22 +102,22 @@ def check(
         if self._use_mul:
             expected_scaling_factor = 1.0 / expected_scaling_factor
 
-        if self._pre_scale:
+        if self._pre_scale and not self._pre_scale_q:
             # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
             # If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
             sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
-                # Calculate the scaling factor for query
+            # Calculate the scaling factor for query
             if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
-                    return check_result.fail(
-                        "Query scale is not a scalar.",
-                        query_scale,
-                    )
-                # Ensure the scaling factor for key is the same as for query
+                return check_result.fail(
+                    "Query scale is not a scalar.",
+                    query_scale,
+                )
+            # Ensure the scaling factor for key is the same as for query
             if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
-                    return check_result.fail(
-                        "Key scale is not a scalar.",
-                        key_scale,
-                    )
+                return check_result.fail(
+                    "Key scale is not a scalar.",
+                    key_scale,
+                )
             if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
                 return check_result.fail(
                     "Query and key scales are not equal.",
@@ -129,13 +129,13 @@
                 # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
                 self._scale = None
         else:
-                # Check if qk_scale is a scalar == expected_scaling_factor)
-                # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
+            # Check if qk_scale is a scalar == expected_scaling_factor)
+            # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
             if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
-                    return check_result.fail(
-                        "QK scale is not a scalar.",
-                        qk_scale,
-                    )
+                return check_result.fail(
+                    "QK scale is not a scalar.",
+                    qk_scale,
+                )
             if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
                 self._scale = qk_scale_value
             else:
@@ -153,13 +153,20 @@ def rewrite(
         key_transposed,
         value,
         mask,
-        query_reshape,
+        query_scale,
+        key_scale,
+        qk_scale,
+        query_reshape=None,
         **_,
     ):
-        if self._has_3d_inputs and self._pre_scale_q:
+        if self._has_3d_inputs and self._pre_scale and self._pre_scale_q:
+            if self._use_mul:
+                query_mul = op.Mul(query, qk_scale)
+            else:
+                query_mul = op.Div(query, qk_scale)
             # Reshape and transpose 3D input of shape (B, S, D)
             # to 4D input of shape (B, N, S, H)
-            queryBNSH = op.Reshape(query, query_reshape)
+            queryBNSH = op.Reshape(query_mul, query_reshape)
             query = op.Transpose(queryBNSH, perm=[0, 2, 1, 3])
 
         sdpa_args = [query, key_transposed, value]
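
In the pre-scale-q case the rewrite now re-applies the captured `qk_scale` to the 3D query before reshaping, choosing `op.Mul` or `op.Div` to mirror how the source graph expressed the scaling. Both branches compute the same value, since dividing by a nonzero scalar s equals multiplying by 1/s; a minimal numpy check (illustrative, not the onnxscript API):

# For a nonzero scalar s, q / s == q * (1/s), so the Mul and Div branches agree.
import numpy as np

q = np.arange(12, dtype=np.float32).reshape(3, 4)
s = np.float32(8.0)  # e.g. sqrt(head_dim) when the graph divides

assert np.allclose(q / s, q * (np.float32(1.0) / s))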
