Skip to content

Commit 2dd6b2d

Browse files
authored
Fix bug in handling constants in cos sin fusion (#2319)
Fix bug causing "'numpy.ndarray' object has no attribute 'const_value'" error in benchmark. Of the two calls to `_compute_const_freqs`, one was passing in an ir.Value, and the other a numpy array. Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 644e30c commit 2dd6b2d

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,8 @@ def max_pos_id(self) -> int | None:
5959
def max_pos_id(self, max_pos_id: int):
6060
self._max_pos_id = max_pos_id # type: ignore[assignment]
6161

62-
def _compute_const_freqs(self, op, freqs):
62+
def _compute_const_freqs(self, op, angles: np.ndarray):
6363
"""Compute cos/sin values when frequencies are constant."""
64-
angles = freqs.const_value.numpy()
6564
cos_value = np.cos(angles)
6665
sin_value = np.sin(angles)
6766
cos_2d = op.Constant(value=ir.tensor(cos_value))
@@ -179,7 +178,7 @@ def rewrite(
179178
else:
180179
# Compute cos/sin values based on whether frequencies are constant
181180
if self._const_freqs:
182-
cos_2d, sin_2d = self._compute_const_freqs(op, freqs)
181+
cos_2d, sin_2d = self._compute_const_freqs(op, freqs.const_value.numpy())
183182
else:
184183
cos_2d, sin_2d = self._compute_dynamic_freqs(op, inv_freq, position_ids, dtype)
185184
if self._cast:

0 commit comments

Comments
 (0)