Skip to content

Commit 44570c1

Browse files
YIWENX14facebook-github-bot
authored andcommitted
Add coreml eager model compare (#11574)
Summary: Pull Request resolved: #11574 The snr is still low, currently under investigation. Reviewed By: limintang Differential Revision: D76456745
1 parent d719e8e commit 44570c1

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class ModelArgs:
6464
use_scaled_rope: bool = True # Use scaled RoPE, introduced in llama3.1.
6565
# Additional Model Metadata needed at runtime
6666
rope_scale_factor: int = 8
67+
high_freq_factor: int = 4
68+
6769
bos_idx: int = 1
6870
eos_idx: int = 3
6971
bos_count: int = -1 # i.e., a single EOS is used as BOS
@@ -74,6 +76,9 @@ class ModelArgs:
7476

7577
use_cache_list: bool = True
7678

79+
use_kv_cache: bool = False
80+
enable_dynamic_shape: bool = False
81+
7782
def __post_init__(self):
7883
if self.n_kv_heads is None:
7984
self.n_kv_heads = self.n_heads
@@ -160,10 +165,16 @@ def __init__(self, params: ModelArgs):
160165
super().__init__()
161166
self.params = params
162167
if self.params.use_hf_rope:
163-
self.precompute_freqs_cis = hf_precompute_freqs_cis
168+
self.precompute_freqs_cis = partial(
169+
hf_precompute_freqs_cis,
170+
partial_rotary_factor=self.params.partial_rotary_factor,
171+
)
164172
else:
165173
self.precompute_freqs_cis = partial(
166-
precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
174+
precompute_freqs_cis,
175+
use_scaled=self.params.use_scaled_rope,
176+
scale_factor=self.params.rope_scale_factor,
177+
high_freq_factor=self.params.high_freq_factor,
167178
)
168179
freqs_cos, freqs_sin = self.precompute_freqs_cis(
169180
self.params.head_dim,

0 commit comments

Comments
 (0)