Skip to content

Avoid using fixed value for max_pos_ids in cos_sin_cache fusion #2167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/_smollm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64),
"position_ids": numpy.ones((1, 30), dtype=numpy.int64),
"position_ids": numpy.arange(30).reshape(1, 30).astype(numpy.int64),
"past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
"past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
}
Expand Down
80 changes: 61 additions & 19 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@
#
# This produces cos/sin values in a form that can be used by ORT's custom ops.

# TODO: To apply the pattern-rewrite, we need to know the maximum position id.
# Need to find a way to get this information from the model or its config.


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
def __init__(
self,
name: str,
max_pos_id: int,
*,
cast: bool = False,
reshape: bool = False,
Expand All @@ -47,13 +43,66 @@ def __init__(
# matched nodes as part of the rewrite-step. We apply a separate final
# pass to remove unused nodes.
super().__init__(name, remove_nodes=False)
self._max_pos_id = max_pos_id
# TODO: Determine what should be the default max_pos_id value
self._max_pos_id = None
# map from inv_freq to (cos, sin) values for transformed graph
self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {}
self._reshape = reshape
self._cast = cast
self._const_freqs = const_freqs

@property
def max_pos_id(self) -> int | None:
return self._max_pos_id

@max_pos_id.setter
def max_pos_id(self, max_pos_id: int):
self._max_pos_id = max_pos_id # type: ignore[assignment]

def _compute_const_freqs(self, op, freqs):
    """Fold constant frequencies into cos/sin ``Constant`` nodes.

    Args:
        op: Builder used to emit the ``Constant`` nodes.
        freqs: Value whose ``const_value`` holds the angles; it is read
            through ``const_value.numpy()``.

    Returns:
        A ``(cos, sin)`` pair of values produced by ``op.Constant``.
    """
    angle_values = freqs.const_value.numpy()
    cos_tensor = ir.tensor(np.cos(angle_values))
    sin_tensor = ir.tensor(np.sin(angle_values))
    # Emit the cosine constant first, then the sine constant.
    return op.Constant(value=cos_tensor), op.Constant(value=sin_tensor)

def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
    """Compute cos/sin tables from ``inv_freq`` and ``position_ids``.

    The largest position id is resolved in this order:

    1. ``self._max_pos_id`` when it has been set;
    2. the maximum of ``position_ids`` when that value is a constant;
    3. otherwise it is computed inside the graph with ``ReduceMax``.

    In cases (1) and (2) the tables are evaluated with NumPy and emitted as
    ``Constant`` nodes. In case (3) the whole computation is emitted as ONNX ops.

    Args:
        op: Builder used to emit ONNX ops.
        inv_freq: Value holding the inverse frequencies. Cases (1) and (2)
            read ``inv_freq.const_value``, so it must be a constant there.
        position_ids: Value holding the position ids.
        dtype: Not used in this method; the caller casts the results itself.
            Kept so the call signature stays unchanged.

    Returns:
        A ``(cos_2d, sin_2d)`` pair of values.
    """
    if self._max_pos_id is not None:
        # Use the max_pos_id that was configured on the rule.
        max_pos_id = self._max_pos_id
    elif position_ids.const_value is not None:
        # Derive max_pos_id from the constant position_ids tensor.
        max_pos_id = int(np.max(position_ids.const_value.numpy()))
    else:
        # Neither source is available: compute max_pos_id with ONNX ops.
        inv_freq = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1]))
        max_pos_id = op.ReduceMax(position_ids, keepdims=0)
        # Range's limit is exclusive, so add one to include max_pos_id itself.
        max_pos_id = op.Add(max_pos_id, op.Constant(value_int=1))
        pos_id_range = op.Range(
            op.Constant(value_int=0),
            max_pos_id,
            op.Constant(value_int=1),
        )
        pos_id_range = op.Reshape(pos_id_range, op.Constant(value_ints=[-1, 1]))
        pos_id_range = op.Cast(pos_id_range, to=ir.DataType.FLOAT)
        # Outer product of positions and inverse frequencies gives the angles.
        angles = op.MatMul(pos_id_range, inv_freq)
        cos_2d = op.Cos(angles)
        sin_2d = op.Sin(angles)
        return cos_2d, sin_2d

    # Static path: position ids are 0-indexed, so the table must cover
    # [0, max_pos_id] inclusive, i.e. max_pos_id + 1 rows.
    inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
    pos_id_range = np.arange(max_pos_id + 1, dtype=np.float32).reshape(-1, 1)
    angles = np.matmul(pos_id_range, inv_freq_values)
    # Build the constants directly from the ndarray. _compute_const_freqs reads
    # `.const_value` from its argument, which a NumPy array does not have, so
    # delegating to it here raised AttributeError.
    cos_2d = op.Constant(value=ir.tensor(np.cos(angles)))
    sin_2d = op.Constant(value=ir.tensor(np.sin(angles)))
    return cos_2d, sin_2d

def cleanup(self):
    """Empty the inv_freq -> (cos, sin) cache in place."""
    cache = self._inv_freq_cos_sin_cache
    cache.clear()

Expand Down Expand Up @@ -128,16 +177,11 @@ def rewrite(
if inv_freq in self._inv_freq_cos_sin_cache:
cos_2d, sin_2d = self._inv_freq_cos_sin_cache[inv_freq]
else:
# Compute cos/sin values based on whether frequencies are constant
if self._const_freqs:
angles = freqs.const_value.numpy()
cos_2d, sin_2d = self._compute_const_freqs(op, freqs)
else:
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
sin_value = np.sin(angles)
cos_2d = op.Constant(value=ir.tensor(cos_value))
sin_2d = op.Constant(value=ir.tensor(sin_value))
cos_2d, sin_2d = self._compute_dynamic_freqs(op, inv_freq, position_ids, dtype)
if self._cast:
cos_2d = op.Cast(cos_2d, to=dtype)
sin_2d = op.Cast(sin_2d, to=dtype)
Expand All @@ -157,13 +201,11 @@ def rewrite(


_cast_const_freqs = CosSinCacheFusion.rule(
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
)
_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False)
_const_freqs = CosSinCacheFusion.rule(
"CosSinCache_const_freqs", 2048, cast=False, const_freqs=True
"CosSinCache_cast_const_freqs", cast=True, const_freqs=True
)
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
_cast = CosSinCacheFusion.rule("CosSinCache_cast", cast=True, const_freqs=False)
_const_freqs = CosSinCacheFusion.rule("CosSinCache_const_freqs", cast=False, const_freqs=True)
_basic = CosSinCacheFusion.rule("CosSinCache", cast=False)

cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])

Expand Down
Loading