@@ -64,6 +64,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +76,9 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -160,10 +165,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
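For context, a minimal sketch of the partial-binding pattern the diff moves to: the ModelArgs-driven keywords (use_scaled, scale_factor, high_freq_factor) are bound once with functools.partial, and the bound callable is later invoked with only the head dimension and sequence length. The precompute_freqs_cis stub below is a hypothetical stand-in, not the repo's implementation; its scaling branch is deliberately simplified and does not reproduce the exact llama3.1 frequency-interpolation rule.

```python
# Hypothetical sketch of the partial-binding pattern; the real
# precompute_freqs_cis in the repo carries the full llama3.1 scaling logic.
from functools import partial

import torch


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    scale_factor: int = 8,
    high_freq_factor: int = 4,
):
    """Toy stand-in: build (end, dim // 2) cos/sin RoPE tables.

    The scaling branch is a placeholder; llama3.1 only rescales
    low-frequency components and blends between bands using
    scale_factor and high_freq_factor.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    if use_scaled:
        freqs = freqs / scale_factor  # simplified; not the exact llama3.1 rule
    angles = torch.outer(torch.arange(end).float(), freqs)
    return torch.cos(angles), torch.sin(angles)


# Bind the config-driven keywords once (mirroring the else-branch in the diff),
# then call with only (head_dim, max_seq_len).
fn = partial(precompute_freqs_cis, use_scaled=True, scale_factor=8, high_freq_factor=4)
freqs_cos, freqs_sin = fn(64, 2048)  # each of shape (2048, 32)
```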