Skip to content

Commit 83cd139

Browse files
minor fixes
1 parent 83abbe4 commit 83cd139

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ def __init__(
4444
# pass to remove unused nodes.
4545
super().__init__(name, remove_nodes=False)
4646
# TODO: Determine what should be the default max_pos_id value
47-
self._max_pos_id = 2048 # Set a default max_pos_id value
48-
self._max_pos_id_from_model = False
47+
self._max_pos_id = None
4948
# map from inv_freq to (cos, sin) values for transformed graph
5049
self._inv_freq_cos_sin_cache: dict[ir.Value, tuple[ir.Value, ir.Value]] = {}
5150
self._reshape = reshape
5251
self._cast = cast
5352
self._const_freqs = const_freqs
5453

55-
def _set_max_position_ids_from_model(self, model: ir.Model):
56-
"""Extract max_position_ids value from the metadata of an ONNX model."""
57-
# TODO: Determine what the correct metadata key is for max_position_ids
58-
if model.metadata_props["max_position_id"] is not None:
59-
self._max_pos_id = model.metadata_props["max_position_id"] # type: ignore[assignment]
60-
self._max_pos_id_from_model = True
54+
@property
55+
def max_pos_id(self) -> int | None:
56+
return self._max_pos_id
57+
58+
@max_pos_id.setter
59+
def max_pos_id(self, max_pos_id: int):
60+
self._max_pos_id = max_pos_id # type: ignore[assignment]
6161

6262
def _compute_const_freqs(self, op, freqs):
6363
"""Compute cos/sin values when frequencies are constant."""
@@ -70,18 +70,16 @@ def _compute_const_freqs(self, op, freqs):
7070

7171
def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
7272
"""Compute cos/sin values dynamically based on inv_freq and position_ids."""
73-
if self._max_pos_id_from_model:
73+
if self._max_pos_id is not None:
7474
# Use the preset max_pos_id (assigned via the max_pos_id property)
7575
max_pos_id = self._max_pos_id
7676
elif position_ids.const_value is not None:
7777
# Calculate max_pos_id from the position_ids tensor
7878
max_pos_id = int(np.max(position_ids.const_value.numpy()))
79-
self._max_pos_id = max_pos_id
8079
else:
8180
# Dynamically compute max_pos_id from position_ids using ONNX ops
8281
inv_freq = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1]))
83-
max_pos_id = op.ReduceMax(position_ids)
84-
max_pos_id = op.Squeeze(max_pos_id, op.Constant(value_ints=[0]))
82+
max_pos_id = op.ReduceMax(position_ids, keepdims=0)
8583
max_pos_id = op.Add(max_pos_id, op.Constant(value_int=1))
8684
pos_id_range = op.Range(
8785
op.Constant(value_int=0),
@@ -210,8 +208,4 @@ def rewrite(
210208
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])
211209

212210

213-
def fuse_cos_sin_cache(model: ir.Model) -> int:
214-
for rule in cos_sin_cache_rules:
215-
if isinstance(rule, CosSinCacheFusion):
216-
rule._set_max_position_ids_from_model(model)
217-
return _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)(model)
211+
fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)

0 commit comments

Comments
 (0)