Commit 5a8b9e6
Fix cross attention in MHA (#2337)
Fix an apparent bug in the handling of cross-attention in MHA fusion (to be verified). MHA fusion starts from an input graph in which attention is applied to 4D query/key/value and rewrites it into an MHA op that takes 3D query/key/value. In the cross-attention case (with no rotary embedding), the fusion converts only the query to 3D and leaves key and value as 4D, which appears to be wrong. This PR adds the necessary 4D => 3D conversion for key/value before the MHA op.

Note: this is a quick fix for the case that actually shows up; other combinations may be worth checking separately.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
1 parent 77fba51 commit 5a8b9e6

File tree

  • onnxscript/rewriter/ort_fusions/mha.py

1 file changed: +7 -0 lines


onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 7 additions & 0 deletions
@@ -349,6 +349,13 @@ def rewrite(
                 )
             else:
                 key_BSD_emb = key
+        elif self._is_cross_attention:
+            query_BSD_emb = query_BSD
+            # Must convert key/value from 4D to 3D for use in MHA
+            key = op.Transpose(key, perm=[0, 2, 1, 3])
+            key_BSD_emb = op.Reshape(key, op.Constant(value_ints=[0, 0, -1]))
+            value = op.Transpose(value, perm=[0, 2, 1, 3])
+            value = op.Reshape(value, op.Constant(value_ints=[0, 0, -1]))
         else:
             query_BSD_emb = query_BSD
             key_BSD_emb = key
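
For context, the following numpy sketch illustrates the layout change the added lines perform; the tensor names and sizes are illustrative assumptions, not taken from the repository. Transpose with perm=[0, 2, 1, 3] turns a 4D [batch, num_heads, seq_len, head_dim] key/value into [batch, seq_len, num_heads, head_dim], and Reshape with shape [0, 0, -1] (where 0 keeps a dimension and -1 is inferred) flattens the last two axes into the 3D [batch, seq_len, hidden] layout that the MHA op consumes.

import numpy as np

# Illustrative sizes (assumed, not from the commit)
batch, num_heads, seq_len, head_dim = 2, 4, 8, 16

key_4d = np.random.rand(batch, num_heads, seq_len, head_dim).astype(np.float32)

# Equivalent of op.Transpose(key, perm=[0, 2, 1, 3]):
# [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads, head_dim]
key_bshd = key_4d.transpose(0, 2, 1, 3)

# Equivalent of op.Reshape(key, [0, 0, -1]): keep batch and seq_len, merge the rest
# -> [batch, seq_len, num_heads * head_dim]
key_3d = key_bshd.reshape(batch, seq_len, -1)

assert key_3d.shape == (batch, seq_len, num_heads * head_dim)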
