Skip to content

Commit 04ed2b8

Browse files
Avoid using fixed value for max_pos_ids in cos_sin_cache fusion (#2167)
To apply the cos sin fusion pattern-rewrite, we need to know the maximum position id. - If the model/config has this information, use it to calculate max_pos_id - If not, calculate max_pos_id from the position ids using ONNX ops - Removes the dependence on pre-setting max_pos_id for each rewrite rule
1 parent 3536960 commit 04ed2b8

File tree

2 files changed

+62
-20
lines changed

2 files changed

+62
-20
lines changed

onnxscript/rewriter/ort_fusions/_smollm_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def get_ort_inputs(self):
459459
if not hasattr(self, "_ort_inputs"):
460460
inputs = {
461461
"input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64),
462-
"position_ids": numpy.ones((1, 30), dtype=numpy.int64),
462+
"position_ids": numpy.arange(30).reshape(1, 30).astype(numpy.int64),
463463
"past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
464464
"past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
465465
}

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,11 @@
2929
#
3030
# This produces cos/sin values in a form that can be used by ORT's custom ops.
3131

32-
# TODO: To apply the pattern-rewrite, we need to know the maximum position id.
33-
# Need to find a way to get this information from the model or its config.
34-
3532

3633
class CosSinCacheFusion(pattern.RewriteRuleClassBase):
3734
def __init__(
3835
self,
3936
name: str,
40-
max_pos_id: int,
4137
*,
4238
cast: bool = False,
4339
reshape: bool = False,
@@ -47,13 +43,66 @@ def __init__(
4743
# matched nodes as part of the rewrite-step. We apply a separate final
4844
# pass to remove unused nodes.
4945
super().__init__(name, remove_nodes=False)
50-
self._max_pos_id = max_pos_id
46+
# TODO: Determine what should be the default max_pos_id value
47+
self._max_pos_id = None
5148
# map from inv_freq to (cos, sin) values for transformed graph
5249
self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {}
5350
self._reshape = reshape
5451
self._cast = cast
5552
self._const_freqs = const_freqs
5653

54+
@property
55+
def max_pos_id(self) -> int | None:
56+
return self._max_pos_id
57+
58+
@max_pos_id.setter
59+
def max_pos_id(self, max_pos_id: int):
60+
self._max_pos_id = max_pos_id # type: ignore[assignment]
61+
62+
def _compute_const_freqs(self, op, freqs):
63+
"""Compute cos/sin values when frequencies are constant."""
64+
angles = freqs.const_value.numpy()
65+
cos_value = np.cos(angles)
66+
sin_value = np.sin(angles)
67+
cos_2d = op.Constant(value=ir.tensor(cos_value))
68+
sin_2d = op.Constant(value=ir.tensor(sin_value))
69+
return cos_2d, sin_2d
70+
71+
def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
72+
"""Compute cos/sin values dynamically based on inv_freq and position_ids."""
73+
if self._max_pos_id is not None:
74+
# Use max_pos_id from the model metadata
75+
max_pos_id = self._max_pos_id
76+
elif position_ids.const_value is not None:
77+
# Calculate max_pos_id from the position_ids tensor
78+
max_pos_id = int(np.max(position_ids.const_value.numpy()))
79+
else:
80+
# Dynamically compute max_pos_id from position_ids using ONNX ops
81+
inv_freq = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1]))
82+
max_pos_id = op.ReduceMax(position_ids, keepdims=0)
83+
max_pos_id = op.Add(max_pos_id, op.Constant(value_int=1))
84+
pos_id_range = op.Range(
85+
op.Constant(value_int=0),
86+
max_pos_id,
87+
op.Constant(value_int=1),
88+
)
89+
pos_id_range = op.Reshape(pos_id_range, op.Constant(value_ints=[-1, 1]))
90+
pos_id_range = op.Cast(pos_id_range, to=ir.DataType.FLOAT)
91+
# Compute angles and cos/sin values
92+
angles = op.MatMul(pos_id_range, inv_freq)
93+
cos_2d = op.Cos(angles)
94+
sin_2d = op.Sin(angles)
95+
return cos_2d, sin_2d
96+
97+
# If we do not compute max_pos_id using ONNX ops, use inv_freq and position_ids
98+
# to compute angles and cos/sin values
99+
# Note: The one is added to max_pos_id as position_ids are 0-indexed
100+
# and the range of position ids should be [0, max_pos_id], max_pos_id inclusive.
101+
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
102+
pos_id_range = np.arange(max_pos_id + 1, dtype=np.float32).reshape(-1, 1)
103+
angles = np.matmul(pos_id_range, inv_freq_values)
104+
return self._compute_const_freqs(op, angles)
105+
57106
def cleanup(self):
58107
self._inv_freq_cos_sin_cache.clear()
59108

@@ -128,16 +177,11 @@ def rewrite(
128177
if inv_freq in self._inv_freq_cos_sin_cache:
129178
cos_2d, sin_2d = self._inv_freq_cos_sin_cache[inv_freq]
130179
else:
180+
# Compute cos/sin values based on whether frequencies are constant
131181
if self._const_freqs:
132-
angles = freqs.const_value.numpy()
182+
cos_2d, sin_2d = self._compute_const_freqs(op, freqs)
133183
else:
134-
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
135-
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
136-
angles = np.matmul(pos_id_range, inv_freq_values)
137-
cos_value = np.cos(angles)
138-
sin_value = np.sin(angles)
139-
cos_2d = op.Constant(value=ir.tensor(cos_value))
140-
sin_2d = op.Constant(value=ir.tensor(sin_value))
184+
cos_2d, sin_2d = self._compute_dynamic_freqs(op, inv_freq, position_ids, dtype)
141185
if self._cast:
142186
cos_2d = op.Cast(cos_2d, to=dtype)
143187
sin_2d = op.Cast(sin_2d, to=dtype)
@@ -157,13 +201,11 @@ def rewrite(
157201

158202

159203
_cast_const_freqs = CosSinCacheFusion.rule(
160-
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
161-
)
162-
_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False)
163-
_const_freqs = CosSinCacheFusion.rule(
164-
"CosSinCache_const_freqs", 2048, cast=False, const_freqs=True
204+
"CosSinCache_cast_const_freqs", cast=True, const_freqs=True
165205
)
166-
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
206+
_cast = CosSinCacheFusion.rule("CosSinCache_cast", cast=True, const_freqs=False)
207+
_const_freqs = CosSinCacheFusion.rule("CosSinCache_const_freqs", cast=False, const_freqs=True)
208+
_basic = CosSinCacheFusion.rule("CosSinCache", cast=False)
167209

168210
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])
169211

0 commit comments

Comments
 (0)