5
5
from typing import Sequence , Union
6
6
7
7
import onnxscript .ir as ir
8
- from onnxscript .rewriter import _fusion_utils , pattern
8
+ from onnxscript .rewriter import _fusion_utils , _ir_utils , pattern
9
9
10
10
Dim = Union [int , ir .SymbolicDim ]
11
11
@@ -36,6 +36,12 @@ def pattern(
36
36
attention_bias ,
37
37
num_heads ,
38
38
# scale,
39
+ start1 ,
40
+ end1 ,
41
+ start2 ,
42
+ end2 ,
43
+ start3 ,
44
+ end3 ,
39
45
q_mul ,
40
46
k_mul ,
41
47
v_mul ,
@@ -45,28 +51,28 @@ def pattern(
45
51
key_BSD = op .MatMul (input , k_mul )
46
52
value_BSD = op .MatMul (input , v_mul )
47
53
else :
48
- projected = op .MatMul (input , qkv_weight )
54
+ projected = op .MatMul (input , qkv_weight , _outputs = [ "projected" ] )
49
55
50
56
# Slice packed Matmul QKV into Q, K, and V
51
57
# Q, K, and V are of shape (B, S, D)
52
58
query_BSD = op .Slice (
53
59
projected ,
54
- pattern . ANY_VALUE , # starts
55
- pattern . ANY_VALUE , # ends
60
+ start1 , # starts
61
+ end1 , # ends
56
62
[2 ], # axes
57
63
_outputs = ["query_mm_sliced" ],
58
64
)
59
65
key_BSD = op .Slice (
60
66
projected ,
61
- pattern . ANY_VALUE , # starts
62
- pattern . ANY_VALUE , # ends
67
+ start2 , # starts
68
+ end2 , # ends
63
69
[2 ], # axes
64
70
_outputs = ["key_mm_sliced" ],
65
71
)
66
72
value_BSD = op .Slice (
67
73
projected ,
68
- pattern . ANY_VALUE , # starts
69
- pattern . ANY_VALUE , # ends
74
+ start3 , # starts
75
+ end3 , # ends
70
76
[2 ], # axes
71
77
_outputs = ["value_mm_sliced" ],
72
78
)
@@ -135,9 +141,16 @@ def check(
135
141
op ,
136
142
input ,
137
143
qkv_weight ,
144
+ projected = None ,
138
145
query_mm_sliced = None ,
139
146
key_mm_sliced = None ,
140
147
value_mm_sliced = None ,
148
+ start1 = None ,
149
+ end1 = None ,
150
+ start2 = None ,
151
+ end2 = None ,
152
+ start3 = None ,
153
+ end3 = None ,
141
154
q_mul = None ,
142
155
k_mul = None ,
143
156
v_mul = None ,
@@ -155,6 +168,23 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
155
168
input ,
156
169
)
157
170
if not self ._no_slice :
171
+ # Ensure slicing is done correctly
172
+ if projected is None or projected .shape is None or len (projected .shape ) != 3 :
173
+ return check_result .fail ("Input projection is not a 3D tensor." , projected )
174
+ hidden_size = projected .shape [2 ]
175
+ if not isinstance (hidden_size , int ):
176
+ return check_result .fail ("Hidden size is not an integer." , projected )
177
+ if not (
178
+ _ir_utils .is_singleton_value (start1 , 0 )
179
+ and _ir_utils .get_singleton_value (end1 ) == _ir_utils .get_singleton_value (start2 )
180
+ and _ir_utils .get_singleton_value (end2 ) == _ir_utils .get_singleton_value (start3 )
181
+ and _ir_utils .is_singleton_value (end3 , lambda x : x >= hidden_size )
182
+ ):
183
+ return check_result .fail (
184
+ "Projected input is not being split into q, k, v correctly based on hidden sizes." ,
185
+ projected ,
186
+ )
187
+
158
188
if no_match (qkv_weight , ["D" , "Dh" ]):
159
189
return check_result .fail (
160
190
f"Shape mismatch: { qkv_weight } does not match expected dimensions ['D', 'Dh']" ,
0 commit comments