@@ -44,20 +44,20 @@ def __init__(
44
44
# pass to remove unused nodes.
45
45
super ().__init__ (name , remove_nodes = False )
46
46
# TODO: Determine what should be the default max_pos_id value
47
- self ._max_pos_id = 2048 # Set a default max_pos_id value
48
- self ._max_pos_id_from_model = False
47
+ self ._max_pos_id = None
49
48
# map from inv_freq to (cos, sin) values for transformed graph
50
49
self ._inv_freq_cos_sin_cache : dict [ir .Value , tuple [ir .Value , ir .Value ]] = {}
51
50
self ._reshape = reshape
52
51
self ._cast = cast
53
52
self ._const_freqs = const_freqs
54
53
55
- def _set_max_position_ids_from_model (self , model : ir .Model ):
56
- """Extract max_position_ids value from the metadata of an ONNX model."""
57
- # TODO: Determine what the correct metadata key is for max_position_ids
58
- if model .metadata_props ["max_position_id" ] is not None :
59
- self ._max_pos_id = model .metadata_props ["max_position_id" ] # type: ignore[assignment]
60
- self ._max_pos_id_from_model = True
54
+ @property
55
+ def max_pos_id (self ) -> int | None :
56
+ return self ._max_pos_id
57
+
58
+ @max_pos_id .setter
59
+ def max_pos_id (self , max_pos_id : int ):
60
+ self ._max_pos_id = max_pos_id # type: ignore[assignment]
61
61
62
62
def _compute_const_freqs (self , op , freqs ):
63
63
"""Compute cos/sin values when frequencies are constant."""
@@ -70,18 +70,16 @@ def _compute_const_freqs(self, op, freqs):
70
70
71
71
def _compute_dynamic_freqs (self , op , inv_freq , position_ids , dtype ):
72
72
"""Compute cos/sin values dynamically based on inv_freq and position_ids."""
73
- if self ._max_pos_id_from_model :
73
+ if self ._max_pos_id is not None :
74
74
# Use max_pos_id from the model metadata
75
75
max_pos_id = self ._max_pos_id
76
76
elif position_ids .const_value is not None :
77
77
# Calculate max_pos_id from the position_ids tensor
78
78
max_pos_id = int (np .max (position_ids .const_value .numpy ()))
79
- self ._max_pos_id = max_pos_id
80
79
else :
81
80
# Dynamically compute max_pos_id from position_ids using ONNX ops
82
81
inv_freq = op .Reshape (inv_freq , op .Constant (value_ints = [1 , - 1 ]))
83
- max_pos_id = op .ReduceMax (position_ids )
84
- max_pos_id = op .Squeeze (max_pos_id , op .Constant (value_ints = [0 ]))
82
+ max_pos_id = op .ReduceMax (position_ids , keepdims = 0 )
85
83
max_pos_id = op .Add (max_pos_id , op .Constant (value_int = 1 ))
86
84
pos_id_range = op .Range (
87
85
op .Constant (value_int = 0 ),
@@ -210,8 +208,4 @@ def rewrite(
210
208
cos_sin_cache_rules = pattern .RewriteRuleSet ([_cast , _cast_const_freqs , _const_freqs , _basic ])
211
209
212
210
213
- def fuse_cos_sin_cache (model : ir .Model ) -> int :
214
- for rule in cos_sin_cache_rules :
215
- if isinstance (rule , CosSinCacheFusion ):
216
- rule ._set_max_position_ids_from_model (model )
217
- return _fusion_utils .apply_fusion_rules (cos_sin_cache_rules )(model )
211
+ fuse_cos_sin_cache = _fusion_utils .apply_fusion_rules (cos_sin_cache_rules )
0 commit comments