29
29
#
30
30
# This produces cos/sin values in a form that can be used by ORT's custom ops.
31
31
32
- # TODO: To apply the pattern-rewrite, we need to know the maximum position id.
33
- # Need to find a way to get this information from the model or its config.
34
-
35
32
36
33
class CosSinCacheFusion (pattern .RewriteRuleClassBase ):
37
34
def __init__ (
38
35
self ,
39
36
name : str ,
40
- max_pos_id : int ,
41
37
* ,
42
38
cast : bool = False ,
43
39
reshape : bool = False ,
@@ -47,13 +43,66 @@ def __init__(
47
43
# matched nodes as part of the rewrite-step. We apply a separate final
48
44
# pass to remove unused nodes.
49
45
super ().__init__ (name , remove_nodes = False )
50
- self ._max_pos_id = max_pos_id
46
+ # TODO: Determine what should be the default max_pos_id value
47
+ self ._max_pos_id = None
51
48
# map from inv_freq to (cos, sin) values for transformed graph
52
49
self ._inv_freq_cos_sin_cache : dict [ir .Value , tuple [ir .Value , ir .Value ]] = {}
53
50
self ._reshape = reshape
54
51
self ._cast = cast
55
52
self ._const_freqs = const_freqs
56
53
54
@property
def max_pos_id(self) -> int | None:
    """Return the maximum position id stored on this rule, or ``None`` if unset."""
    configured = self._max_pos_id
    return configured

@max_pos_id.setter
def max_pos_id(self, max_pos_id: int):
    """Store ``max_pos_id`` on this rule, replacing any previous value."""
    # The attribute is initialised to None, hence the assignment ignore.
    self._max_pos_id = max_pos_id  # type: ignore[assignment]
61
+
62
def _compute_const_freqs(self, op, freqs):
    """Build constant cos/sin nodes from precomputed rotation angles.

    Args:
        op: Builder used to emit the ``Constant`` nodes.
        freqs: Either a value whose ``const_value`` holds the angles, or the
            angles themselves as a NumPy array.

    Returns:
        A ``(cos, sin)`` pair of ``Constant`` outputs holding the element-wise
        cosine and sine of the angles.
    """
    # A raw ndarray has no ``const_value`` attribute; only unwrap graph values.
    # Without this check, passing already-computed angles raises AttributeError.
    if isinstance(freqs, np.ndarray):
        angles = freqs
    else:
        angles = freqs.const_value.numpy()
    cos_2d = op.Constant(value=ir.tensor(np.cos(angles)))
    sin_2d = op.Constant(value=ir.tensor(np.sin(angles)))
    return cos_2d, sin_2d
70
+
71
def _compute_dynamic_freqs(self, op, inv_freq, position_ids, dtype):
    """Compute cos/sin tables from ``inv_freq`` over the range of position ids.

    The upper bound of the position range is taken, in order of preference,
    from ``self._max_pos_id``, from a constant ``position_ids`` tensor, or
    from a ``ReduceMax`` over ``position_ids`` emitted into the graph.

    Args:
        op: Builder used to emit the replacement nodes.
        inv_freq: Inverse-frequency value. Its ``const_value`` is read whenever
            the bound is known statically, so it must be constant in that case.
        position_ids: Position-id value; its ``const_value`` is used when set.
        dtype: Not used by this method; kept so the signature stays stable.

    Returns:
        A ``(cos, sin)`` pair with one row per position id in
        ``[0, max_pos_id]`` (inclusive) and one column per ``inv_freq`` entry.
    """
    max_pos_id = self._max_pos_id
    if max_pos_id is None and position_ids.const_value is not None:
        # Fall back to the largest id present in the constant position_ids.
        max_pos_id = int(np.max(position_ids.const_value.numpy()))

    if max_pos_id is not None:
        # Bound known statically: fold everything into constants.
        # The +1 is because position ids are 0-indexed and the range
        # [0, max_pos_id] is inclusive of max_pos_id.
        inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
        pos_id_range = np.arange(max_pos_id + 1, dtype=np.float32).reshape(-1, 1)
        angles = np.matmul(pos_id_range, inv_freq_values)
        # Emit the constants here: ``angles`` is a NumPy array, not a graph
        # value, so there is no ``const_value`` to unwrap.
        cos_2d = op.Constant(value=ir.tensor(np.cos(angles)))
        sin_2d = op.Constant(value=ir.tensor(np.sin(angles)))
        return cos_2d, sin_2d

    # Bound unknown: derive it from position_ids with ONNX ops at runtime.
    inv_freq = op.Reshape(inv_freq, op.Constant(value_ints=[1, -1]))
    max_pos_id = op.ReduceMax(position_ids, keepdims=0)
    # Range's limit is exclusive, so add one to include the maximum id.
    max_pos_id = op.Add(max_pos_id, op.Constant(value_int=1))
    pos_id_range = op.Range(
        op.Constant(value_int=0),
        max_pos_id,
        op.Constant(value_int=1),
    )
    pos_id_range = op.Reshape(pos_id_range, op.Constant(value_ints=[-1, 1]))
    pos_id_range = op.Cast(pos_id_range, to=ir.DataType.FLOAT)
    # Outer product of positions and inverse frequencies gives the angles.
    angles = op.MatMul(pos_id_range, inv_freq)
    cos_2d = op.Cos(angles)
    sin_2d = op.Sin(angles)
    return cos_2d, sin_2d
105
+
57
106
def cleanup(self):
    """Empty the inv_freq -> (cos, sin) cache in place."""
    cached_pairs = self._inv_freq_cos_sin_cache
    cached_pairs.clear()
59
108
@@ -128,16 +177,11 @@ def rewrite(
128
177
if inv_freq in self ._inv_freq_cos_sin_cache :
129
178
cos_2d , sin_2d = self ._inv_freq_cos_sin_cache [inv_freq ]
130
179
else :
180
+ # Compute cos/sin values based on whether frequencies are constant
131
181
if self ._const_freqs :
132
- angles = freqs . const_value . numpy ( )
182
+ cos_2d , sin_2d = self . _compute_const_freqs ( op , freqs )
133
183
else :
134
- inv_freq_values = inv_freq .const_value .numpy ().reshape (1 , - 1 )
135
- pos_id_range = np .arange (self ._max_pos_id , dtype = np .float32 ).reshape (- 1 , 1 )
136
- angles = np .matmul (pos_id_range , inv_freq_values )
137
- cos_value = np .cos (angles )
138
- sin_value = np .sin (angles )
139
- cos_2d = op .Constant (value = ir .tensor (cos_value ))
140
- sin_2d = op .Constant (value = ir .tensor (sin_value ))
184
+ cos_2d , sin_2d = self ._compute_dynamic_freqs (op , inv_freq , position_ids , dtype )
141
185
if self ._cast :
142
186
cos_2d = op .Cast (cos_2d , to = dtype )
143
187
sin_2d = op .Cast (sin_2d , to = dtype )
@@ -157,13 +201,11 @@ def rewrite(
157
201
158
202
159
203
# One rule per (cast, const_freqs) combination; max_pos_id is no longer passed
# at construction time.
_cast_const_freqs = CosSinCacheFusion.rule(
    "CosSinCache_cast_const_freqs",
    cast=True,
    const_freqs=True,
)
_cast = CosSinCacheFusion.rule(
    "CosSinCache_cast",
    cast=True,
    const_freqs=False,
)
_const_freqs = CosSinCacheFusion.rule(
    "CosSinCache_const_freqs",
    cast=False,
    const_freqs=True,
)
_basic = CosSinCacheFusion.rule(
    "CosSinCache",
    cast=False,
)

# The rules are listed in this order: _cast, _cast_const_freqs, _const_freqs, _basic.
cos_sin_cache_rules = pattern.RewriteRuleSet(
    [
        _cast,
        _cast_const_freqs,
        _const_freqs,
        _basic,
    ]
)
169
211
0 commit comments