
Commit 20cdd63

modify qkv attention axis
1 parent fc0523d commit 20cdd63

1 file changed: +1 -1 lines changed

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def rewrite(
         # Dh_v = self.bindings.get("Dh_v")
         # qkv_hidden_sizes = [Dh_q, Dh_k, Dh_v]
         if self._no_slice:
-            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=0)
+            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1)
 
         if self._has_past:
             attention, present = op.Attention(
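
The commit message does not say why the axis changes, so the following is only a sketch of one plausible reading. It assumes that `q_mul`, `k_mul`, and `v_mul` are 2-D projection weights of shape (input_hidden_size, hidden_size) and that the fused `Attention` op expects a packed weight of shape (input_hidden_size, 3 * hidden_size). The `concat`, `shape`, and `matmul` helpers are hypothetical pure-Python stand-ins for illustration, not onnxscript APIs.

```python
# Hypothetical helpers operating on 2-D nested lists (not the onnxscript API).
def concat(mats, axis):
    """Concatenate 2-D matrices along rows (axis=0) or columns (axis=1)."""
    if axis == 0:
        return [row[:] for m in mats for row in m]
    if axis == 1:
        return [sum((m[i] for m in mats), []) for i in range(len(mats[0]))]
    raise ValueError("axis must be 0 or 1")


def shape(m):
    return (len(m), len(m[0]))


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Assumed sizes: input_hidden_size D = 2, per-projection hidden size H = 3.
H = 3
q_w = [[1, 2, 3], [4, 5, 6]]
k_w = [[7, 8, 9], [10, 11, 12]]
v_w = [[13, 14, 15], [16, 17, 18]]

stacked = concat([q_w, k_w, v_w], axis=0)  # rows stacked: shape (3*D, H)
packed = concat([q_w, k_w, v_w], axis=1)   # columns packed: shape (D, 3*H)

x = [[1, 1]]              # one token with D features
qkv = matmul(x, packed)   # a single matmul yields Q, K, V side by side

# Slicing the packed result recovers each individual projection.
assert [r[0:H] for r in qkv] == matmul(x, q_w)
assert [r[H:2 * H] for r in qkv] == matmul(x, k_w)
assert [r[2 * H:3 * H] for r in qkv] == matmul(x, v_w)
```

Under these assumptions, only the `axis=1` result can be multiplied by an input of width D; the `axis=0` result has 3*D rows and would not match the input's feature dimension.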
