29
29
#
30
30
# This produces cos/sin values in a form that can be used by ORT's custom ops.
31
31
32
- # TODO: To apply the pattern-rewrite, we need to know the maximum position id.
33
- # Need to find a way to get this information from the model or its config.
34
-
35
32
36
33
class CosSinCacheFusion (pattern .RewriteRuleClassBase ):
37
34
def __init__ (
38
35
self ,
39
36
name : str ,
40
- max_pos_id : int ,
41
37
* ,
42
38
cast : bool = False ,
43
39
reshape : bool = False ,
@@ -47,13 +43,66 @@ def __init__(
47
43
# matched nodes as part of the rewrite-step. We apply a separate final
48
44
# pass to remove unused nodes.
49
45
super ().__init__ (name , remove_nodes = False )
50
- self ._max_pos_id = max_pos_id
46
+ # TODO: Determine what should be the default max_pos_id value
47
+ self ._max_pos_id = 2048 # Set a default max_pos_id value
48
+ self ._max_pos_id_from_model = False
51
49
# map from inv_freq to (cos, sin) values for transformed graph
52
50
self ._inv_freq_cos_sin_cache : dict [ir .Value , tuple [ir .Value , ir .Value ]] = {}
53
51
self ._reshape = reshape
54
52
self ._cast = cast
55
53
self ._const_freqs = const_freqs
56
54
55
+ def _set_max_position_ids_from_model (self , model : ir .Model ):
56
+ """Extract max_position_ids value from the metadata of an ONNX model."""
57
+ # TODO: Determine what the correct metadata key is for max_position_ids
58
+ if model .metadata_props ["max_position_id" ] is not None :
59
+ self ._max_pos_id = model .metadata_props ["max_position_id" ] # type: ignore[assignment]
60
+ self ._max_pos_id_from_model = True
61
+
62
+ def _compute_const_freqs (self , op , freqs ):
63
+ """Compute cos/sin values when frequencies are constant."""
64
+ angles = freqs .const_value .numpy ()
65
+ cos_value = np .cos (angles )
66
+ sin_value = np .sin (angles )
67
+ cos_2d = op .Constant (value = ir .tensor (cos_value ))
68
+ sin_2d = op .Constant (value = ir .tensor (sin_value ))
69
+ return cos_2d , sin_2d
70
+
71
+ def _compute_dynamic_freqs (self , op , inv_freq , position_ids , dtype ):
72
+ """Compute cos/sin values dynamically based on inv_freq and position_ids."""
73
+ if self ._max_pos_id_from_model :
74
+ # Use max_pos_id from the model metadata
75
+ max_pos_id = self ._max_pos_id
76
+ elif position_ids .const_value is not None :
77
+ # Calculate max_pos_id from the position_ids tensor
78
+ max_pos_id = int (np .max (position_ids .const_value .numpy ()))
79
+ self ._max_pos_id = max_pos_id
80
+ else :
81
+ # Dynamically compute max_pos_id from position_ids using ONNX ops
82
+ inv_freq = op .Reshape (inv_freq , op .Constant (value_ints = [1 , - 1 ]))
83
+ max_pos_id = op .ReduceMax (position_ids )
84
+ max_pos_id = op .Squeeze (max_pos_id , op .Constant (value_ints = [0 ]))
85
+ max_pos_id = op .Add (max_pos_id , op .Constant (value_int = 1 ))
86
+ pos_id_range = op .Range (
87
+ op .Constant (value_int = 0 ),
88
+ max_pos_id ,
89
+ op .Constant (value_int = 1 ),
90
+ )
91
+ pos_id_range = op .Reshape (pos_id_range , op .Constant (value_ints = [- 1 , 1 ]))
92
+ pos_id_range = op .Cast (pos_id_range , to = ir .DataType .FLOAT )
93
+ # Compute angles and cos/sin values
94
+ angles = op .MatMul (pos_id_range , inv_freq )
95
+ cos_2d = op .Cos (angles )
96
+ sin_2d = op .Sin (angles )
97
+ return cos_2d , sin_2d
98
+
99
+ # If we do not compute max_pos_id using ONNX ops, use inv_freq and position_ids
100
+ # to compute angles and cos/sin values
101
+ inv_freq_values = inv_freq .const_value .numpy ().reshape (1 , - 1 )
102
+ pos_id_range = np .arange (max_pos_id , dtype = np .float32 ).reshape (- 1 , 1 )
103
+ angles = np .matmul (pos_id_range , inv_freq_values )
104
+ return self ._compute_const_freqs (op , angles )
105
+
57
106
def cleanup (self ):
58
107
self ._inv_freq_cos_sin_cache .clear ()
59
108
@@ -128,16 +177,11 @@ def rewrite(
128
177
if inv_freq in self ._inv_freq_cos_sin_cache :
129
178
cos_2d , sin_2d = self ._inv_freq_cos_sin_cache [inv_freq ]
130
179
else :
180
+ # Compute cos/sin values based on whether frequencies are constant
131
181
if self ._const_freqs :
132
- angles = freqs . const_value . numpy ( )
182
+ cos_2d , sin_2d = self . _compute_const_freqs ( op , freqs )
133
183
else :
134
- inv_freq_values = inv_freq .const_value .numpy ().reshape (1 , - 1 )
135
- pos_id_range = np .arange (self ._max_pos_id , dtype = np .float32 ).reshape (- 1 , 1 )
136
- angles = np .matmul (pos_id_range , inv_freq_values )
137
- cos_value = np .cos (angles )
138
- sin_value = np .sin (angles )
139
- cos_2d = op .Constant (value = ir .tensor (cos_value ))
140
- sin_2d = op .Constant (value = ir .tensor (sin_value ))
184
+ cos_2d , sin_2d = self ._compute_dynamic_freqs (op , inv_freq , position_ids , dtype )
141
185
if self ._cast :
142
186
cos_2d = op .Cast (cos_2d , to = dtype )
143
187
sin_2d = op .Cast (sin_2d , to = dtype )
@@ -157,15 +201,17 @@ def rewrite(
157
201
158
202
159
203
_cast_const_freqs = CosSinCacheFusion .rule (
160
- "CosSinCache_cast_const_freqs" , 2048 , cast = True , const_freqs = True
161
- )
162
- _cast = CosSinCacheFusion .rule ("CosSinCache_cast" , 2048 , cast = True , const_freqs = False )
163
- _const_freqs = CosSinCacheFusion .rule (
164
- "CosSinCache_const_freqs" , 2048 , cast = False , const_freqs = True
204
+ "CosSinCache_cast_const_freqs" , cast = True , const_freqs = True
165
205
)
166
- _basic = CosSinCacheFusion .rule ("CosSinCache" , 2048 , cast = False )
206
+ _cast = CosSinCacheFusion .rule ("CosSinCache_cast" , cast = True , const_freqs = False )
207
+ _const_freqs = CosSinCacheFusion .rule ("CosSinCache_const_freqs" , cast = False , const_freqs = True )
208
+ _basic = CosSinCacheFusion .rule ("CosSinCache" , cast = False )
167
209
168
210
cos_sin_cache_rules = pattern .RewriteRuleSet ([_cast , _cast_const_freqs , _const_freqs , _basic ])
169
211
170
212
171
- fuse_cos_sin_cache = _fusion_utils .apply_fusion_rules (cos_sin_cache_rules )
213
+ def fuse_cos_sin_cache (model : ir .Model ) -> int :
214
+ for rule in cos_sin_cache_rules :
215
+ if isinstance (rule , CosSinCacheFusion ):
216
+ rule ._set_max_position_ids_from_model (model )
217
+ return _fusion_utils .apply_fusion_rules (cos_sin_cache_rules )(model )
0 commit comments