Skip to content

Commit 1fa3150

Browse files
add cos, sin calc functions
1 parent 8b1f814 commit 1fa3150

File tree

1 file changed

+66
-20
lines changed

1 file changed

+66
-20
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,11 @@
2929
#
3030
# This produces cos/sin values in a form that can be used by ORT's custom ops.
3131

32-
# TODO: To apply the pattern-rewrite, we need to know the maximum position id.
33-
# Need to find a way to get this information from the model or its config.
34-
3532

3633
class CosSinCacheFusion(pattern.RewriteRuleClassBase):
3734
def __init__(
3835
self,
3936
name: str,
40-
max_pos_id: int,
4137
*,
4238
cast: bool = False,
4339
reshape: bool = False,
@@ -47,13 +43,66 @@ def __init__(
4743
# matched nodes as part of the rewrite-step. We apply a separate final
4844
# pass to remove unused nodes.
4945
super().__init__(name, remove_nodes=False)
50-
self._max_pos_id = max_pos_id
46+
# TODO: Determine what should be the default max_pos_id value
47+
self._max_pos_id = 2048 # Set a default max_pos_id value
48+
self._max_pos_id_from_model = False
5149
# map from inv_freq to (cos, sin) values for transformed graph
5250
self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {}
5351
self._reshape = reshape
5452
self._cast = cast
5553
self._const_freqs = const_freqs
5654

55+
def _set_max_position_ids_from_model(self, model: ir.Model):
    """Read the maximum position id from the model's metadata, if present.

    Looks up the ``"max_position_id"`` entry of ``model.metadata_props``.
    When the entry exists and parses as a positive integer, it is stored in
    ``self._max_pos_id`` and ``self._max_pos_id_from_model`` is set to True.
    Otherwise both attributes are left unchanged.

    Args:
        model: The model whose ``metadata_props`` mapping is inspected.
    """
    # TODO: Determine what the correct metadata key is for max_position_ids
    # Use .get(): indexing a missing key raises KeyError instead of yielding None.
    raw_value = model.metadata_props.get("max_position_id")
    if raw_value is None:
        return
    try:
        # The value is later passed to np.arange, so it must be an int,
        # not the raw metadata entry (typically a string).
        max_pos_id = int(raw_value)
    except (TypeError, ValueError):
        # Unparseable metadata: keep the existing value rather than failing.
        return
    if max_pos_id <= 0:
        return
    self._max_pos_id = max_pos_id
    self._max_pos_id_from_model = True
61+
62+
def _compute_const_freqs(self, op, freqs):
    """Build constant cos/sin tables from known angle values.

    Args:
        op: Builder used to emit the ``Constant`` nodes.
        freqs: The angles, given either as a value whose ``const_value``
            holds them or directly as a NumPy array.

    Returns:
        A ``(cos, sin)`` pair of ``Constant`` outputs holding the
        element-wise cosine and sine of the angles.
    """
    # _compute_dynamic_freqs hands over a plain NumPy array of angles, which
    # has no ``const_value`` attribute, so accept both forms here.
    if isinstance(freqs, np.ndarray):
        angles = freqs
    else:
        angles = freqs.const_value.numpy()
    cos_2d = op.Constant(value=ir.tensor(np.cos(angles)))
    sin_2d = op.Constant(value=ir.tensor(np.sin(angles)))
    return cos_2d, sin_2d
70+
71+
def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
    """Compute cos/sin tables from ``inv_freq`` for a range of position ids.

    The number of table rows (positions ``0 .. n-1``) is chosen as follows:

    1. If a value was loaded from the model metadata, use ``self._max_pos_id``.
    2. Else, if ``position_ids`` is a constant, use its maximum plus one and
       remember that row count in ``self._max_pos_id``.
    3. Otherwise emit ONNX ops that derive the row count from
       ``position_ids`` at runtime and compute the tables in-graph.

    Args:
        op: Builder used to emit nodes.
        inv_freq: Inverse frequencies. In cases 1 and 2 its ``const_value``
            is read, so it must be a constant there.
        position_ids: Position ids the tables must cover.
        dtype: Not used here; the caller applies any cast to the result.

    Returns:
        A ``(cos, sin)`` pair of values with one row per position.
    """
    if self._max_pos_id_from_model:
        # Metadata entries may be stored as strings; np.arange needs an int.
        num_positions = int(self._max_pos_id)
    elif position_ids.const_value is not None:
        # Position ids are zero-based, so covering the largest id needs
        # max + 1 rows (the in-graph branch below adds 1 for the same reason).
        num_positions = int(np.max(position_ids.const_value.numpy())) + 1
        self._max_pos_id = num_positions
    else:
        # Dynamically compute the row count from position_ids using ONNX ops.
        inv_freq_row = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1]))
        # keepdims=0 reduces over all axes to a scalar, which Range requires
        # regardless of the rank of position_ids.
        max_pos_id = op.ReduceMax(position_ids, keepdims=0)
        # Range needs all three inputs to share one type; the constants
        # below are int64.
        max_pos_id = op.Cast(max_pos_id, to=ir.DataType.INT64)
        row_count = op.Add(max_pos_id, op.Constant(value_int=1))
        pos_id_range = op.Range(
            op.Constant(value_int=0),
            row_count,
            op.Constant(value_int=1),
        )
        pos_id_range = op.Reshape(pos_id_range, op.Constant(value_ints=[-1, 1]))
        pos_id_range = op.Cast(pos_id_range, to=ir.DataType.FLOAT)
        # (n, 1) @ (1, d) -> one row of angles per position.
        angles = op.MatMul(pos_id_range, inv_freq_row)
        return op.Cos(angles), op.Sin(angles)

    # Row count known at rewrite time: evaluate the tables with NumPy and
    # embed them as constants.
    inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
    positions = np.arange(num_positions, dtype=np.float32).reshape(-1, 1)
    angles = np.matmul(positions, inv_freq_values)
    # Emit the constants directly: ``angles`` is a NumPy array, not a value
    # carrying ``const_value``.
    cos_2d = op.Constant(value=ir.tensor(np.cos(angles)))
    sin_2d = op.Constant(value=ir.tensor(np.sin(angles)))
    return cos_2d, sin_2d
105+
57106
def cleanup(self):
    """Empty the cache that maps inv_freq to its (cos, sin) values."""
    cache = self._inv_freq_cos_sin_cache
    cache.clear()
59108

@@ -128,16 +177,11 @@ def rewrite(
128177
if inv_freq in self._inv_freq_cos_sin_cache:
129178
cos_2d, sin_2d = self._inv_freq_cos_sin_cache[inv_freq]
130179
else:
180+
# Compute cos/sin values based on whether frequencies are constant
131181
if self._const_freqs:
132-
angles = freqs.const_value.numpy()
182+
cos_2d, sin_2d = self._compute_const_freqs(op, freqs)
133183
else:
134-
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
135-
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
136-
angles = np.matmul(pos_id_range, inv_freq_values)
137-
cos_value = np.cos(angles)
138-
sin_value = np.sin(angles)
139-
cos_2d = op.Constant(value=ir.tensor(cos_value))
140-
sin_2d = op.Constant(value=ir.tensor(sin_value))
184+
cos_2d, sin_2d = self._compute_dynamic_freqs(op, inv_freq, position_ids, dtype)
141185
if self._cast:
142186
cos_2d = op.Cast(cos_2d, to=dtype)
143187
sin_2d = op.Cast(sin_2d, to=dtype)
@@ -157,15 +201,17 @@ def rewrite(
157201

158202

159203
_cast_const_freqs = CosSinCacheFusion.rule(
160-
"CosSinCache_cast_const_freqs", 2048, cast=True, const_freqs=True
161-
)
162-
_cast = CosSinCacheFusion.rule("CosSinCache_cast", 2048, cast=True, const_freqs=False)
163-
_const_freqs = CosSinCacheFusion.rule(
164-
"CosSinCache_const_freqs", 2048, cast=False, const_freqs=True
204+
"CosSinCache_cast_const_freqs", cast=True, const_freqs=True
165205
)
166-
_basic = CosSinCacheFusion.rule("CosSinCache", 2048, cast=False)
206+
_cast = CosSinCacheFusion.rule("CosSinCache_cast", cast=True, const_freqs=False)
207+
_const_freqs = CosSinCacheFusion.rule("CosSinCache_const_freqs", cast=False, const_freqs=True)
208+
_basic = CosSinCacheFusion.rule("CosSinCache", cast=False)
167209

168210
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])
169211

170212

171-
fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)
213+
def fuse_cos_sin_cache(model: ir.Model) -> int:
    """Apply the cos/sin cache fusion rules to ``model``.

    Before the rules run, every ``CosSinCacheFusion`` found while iterating
    ``cos_sin_cache_rules`` is given the chance to read the maximum position
    id from the model.

    Returns:
        Whatever the function built by ``_fusion_utils.apply_fusion_rules``
        returns for ``model``.
    """
    # NOTE(review): this only has an effect if iterating the rule set yields
    # CosSinCacheFusion instances; confirm what ``CosSinCacheFusion.rule``
    # actually returns.
    fusions = (rule for rule in cos_sin_cache_rules if isinstance(rule, CosSinCacheFusion))
    for fusion in fusions:
        fusion._set_max_position_ids_from_model(model)
    apply_rules = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)
    return apply_rules(model)

0 commit comments

Comments
 (0)