
Commit 9ca3c1a
add slice checks
1 parent: 0bf38e3

2 files changed: +42, -11 lines

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 38 additions & 8 deletions
@@ -5,7 +5,7 @@
 from typing import Sequence, Union

 import onnxscript.ir as ir
-from onnxscript.rewriter import _fusion_utils, pattern
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

 Dim = Union[int, ir.SymbolicDim]

@@ -36,6 +36,12 @@ def pattern(
         attention_bias,
         num_heads,
         # scale,
+        start1,
+        end1,
+        start2,
+        end2,
+        start3,
+        end3,
         q_mul,
         k_mul,
         v_mul,
@@ -45,28 +51,28 @@ def pattern(
             key_BSD = op.MatMul(input, k_mul)
             value_BSD = op.MatMul(input, v_mul)
         else:
-            projected = op.MatMul(input, qkv_weight)
+            projected = op.MatMul(input, qkv_weight, _outputs=["projected"])

             # Slice packed Matmul QKV into Q, K, and V
             # Q, K, and V are of shape (B, S, D)
             query_BSD = op.Slice(
                 projected,
-                pattern.ANY_VALUE,  # starts
-                pattern.ANY_VALUE,  # ends
+                start1,  # starts
+                end1,  # ends
                 [2],  # axes
                 _outputs=["query_mm_sliced"],
             )
             key_BSD = op.Slice(
                 projected,
-                pattern.ANY_VALUE,  # starts
-                pattern.ANY_VALUE,  # ends
+                start2,  # starts
+                end2,  # ends
                 [2],  # axes
                 _outputs=["key_mm_sliced"],
             )
             value_BSD = op.Slice(
                 projected,
-                pattern.ANY_VALUE,  # starts
-                pattern.ANY_VALUE,  # ends
+                start3,  # starts
+                end3,  # ends
                 [2],  # axes
                 _outputs=["value_mm_sliced"],
             )
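For intuition, here is what the three Slice ops now match on a packed QKV projection, sketched with NumPy in place of ONNX ops. The split sizes (64/64/64 over a hidden size of 192) are invented for illustration; start1/end1 etc. are the pattern variables bound above:

import numpy as np

# projected = MatMul(input, qkv_weight): Q, K, V concatenated along axis 2.
B, S, D = 2, 4, 192                  # hypothetical batch, sequence, hidden size
projected = np.random.rand(B, S, D).astype(np.float32)

# The three Slice ops cut axis 2 into contiguous ranges:
query = projected[:, :, 0:64]        # Slice(..., start1=0,   end1=64,  axes=[2])
key = projected[:, :, 64:128]        # Slice(..., start2=64,  end2=128, axes=[2])
value = projected[:, :, 128:192]     # Slice(..., start3=128, end3=192, axes=[2])

Binding the starts/ends as pattern variables (instead of pattern.ANY_VALUE) is what lets the check below inspect the actual slice boundaries.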
@@ -135,9 +141,16 @@ def check(
         op,
         input,
         qkv_weight,
+        projected=None,
         query_mm_sliced=None,
         key_mm_sliced=None,
         value_mm_sliced=None,
+        start1=None,
+        end1=None,
+        start2=None,
+        end2=None,
+        start3=None,
+        end3=None,
         q_mul=None,
         k_mul=None,
         v_mul=None,
@@ -155,6 +168,23 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 input,
             )
         if not self._no_slice:
+            # Ensure slicing is done correctly
+            if projected is None or projected.shape is None or len(projected.shape) != 3:
+                return check_result.fail("Input projection is not a 3D tensor.", projected)
+            hidden_size = projected.shape[2]
+            if not isinstance(hidden_size, int):
+                return check_result.fail("Hidden size is not an integer.", projected)
+            if not (
+                _ir_utils.is_singleton_value(start1, 0)
+                and _ir_utils.get_singleton_value(end1) == _ir_utils.get_singleton_value(start2)
+                and _ir_utils.get_singleton_value(end2) == _ir_utils.get_singleton_value(start3)
+                and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size)
+            ):
+                return check_result.fail(
+                    "Projected input is not being split into q, k, v correctly based on hidden sizes.",
+                    projected,
+                )
+
         if no_match(qkv_weight, ["D", "Dh"]):
             return check_result.fail(
                 f"Shape mismatch: {qkv_weight} does not match expected dimensions ['D', 'Dh']",

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 4 additions & 3 deletions
@@ -219,7 +219,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
                 query_BSHDh,
             )
-        # If cross-attention key/value shapes are 4D
+        # If cross-attention, key/value shapes are 4D
         if self._is_cross_attention:
             if no_match(key, ["B", "H", "Skv", "Dh"]):
                 return check_result.fail(
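The comment fix is small, but the distinction it states is the crux of the branch: in the cross-attention variant the matched key/value arrive in per-head layout. A shape sketch with invented sizes (the self-attention 3D layout follows the (B, S, D) convention used in attention.py):

B, H, Skv, Dh = 2, 8, 16, 64            # hypothetical dimensions
self_attention_kv = (B, Skv, H * Dh)    # 3D: ["B", "Skv", "D"]
cross_attention_kv = (B, H, Skv, Dh)    # 4D: ["B", "H", "Skv", "Dh"], as checked above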
@@ -349,9 +349,10 @@ def rewrite(
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
-    # Enforce has_past_present to be True first, to avoid missing the pattern
-    for has_past_present in [True, False]
     for is_cross_attention in [False, True]
+    for has_past_present in ([False] if is_cross_attention else [True, False])
+    # Skip if both has_past_present and is_cross_attention are True
+    if not (has_past_present and is_cross_attention)
 ]

 # Dynamically create the rules
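The reordered comprehension plus the conditional inner list means no rule variant is generated with both cross-attention and a KV cache. A quick standalone check of which (is_cross_attention, has_past_present) combinations survive:

combos = [
    (is_cross_attention, has_past_present)
    for is_cross_attention in [False, True]
    for has_past_present in ([False] if is_cross_attention else [True, False])
    if not (has_past_present and is_cross_attention)
]
print(combos)   # [(False, True), (False, False), (True, False)]

The trailing if guard is redundant given the conditional inner list, but it documents the excluded combination explicitly.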
