Skip to content

Commit f767c3e

Browse files
fix mha test
1 parent ac9709a commit f767c3e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

onnxscript/rewriter/ort_fusions/_smollm_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def get_ort_inputs(self):
459459
if not hasattr(self, "_ort_inputs"):
460460
inputs = {
461461
"input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64),
462-
"position_ids": numpy.ones((1, 30), dtype=numpy.int64),
462+
"position_ids": numpy.arange(30).reshape(1, 30).astype(numpy.int64),
463463
"past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
464464
"past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
465465
}

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
9797
# If we do not compute max_pos_id using ONNX ops, use inv_freq and position_ids
9898
# to compute angles and cos/sin values
9999
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
100-
pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1)
100+
pos_id_range = np.arange(max_pos_id + 1, dtype=np.float32).reshape(-1, 1)
101101
angles = np.matmul(pos_id_range, inv_freq_values)
102102
return self._compute_const_freqs(op, angles)
103103

0 commit comments

Comments
 (0)