
Commit 1bf7194

ggerganov authored and Nexesenex committed
llama : fix n_rot default (ggml-org#8348)
ggml-ci
1 parent 6030c8b commit 1bf7194

File tree

1 file changed: +15, -14 lines changed


llama.cpp

Lines changed: 15 additions & 14 deletions
@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
 
     // non-transformer models do not have attention heads
     if (hparams.n_head() > 0) {
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd / hparams.n_head();
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
-            }
-        }
         // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
         // gpt-j n_rot = rotary_dim
 
@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
 
         hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
         ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
+        // sanity check for n_rot (optional)
+        hparams.n_rot = hparams.n_embd_head_k;
+
+        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+            if (hparams.n_rot != hparams.n_embd_head_k) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+            }
+        }
     } else {
         hparams.n_rot = 0;
         hparams.n_embd_head_k = 0;
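
The relocated block now seeds n_rot from n_embd_head_k instead of n_embd / n_head(), so models whose metadata overrides the per-head key size get a matching rotary dimension by default, with the optional LLM_KV_ROPE_DIMENSION_COUNT key still able to override it. A minimal standalone sketch of that defaulting order, using hypothetical struct and function names in place of llama.cpp's hparams and GGUF reader (not the actual implementation):

#include <cstdint>
#include <cstdio>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for the relevant llama.cpp hparams fields.
struct hparams_sketch {
    uint32_t n_embd        = 4096;
    uint32_t n_head        = 32;
    uint32_t n_embd_head_k = 128; // may differ from n_embd / n_head if overridden by metadata
};

// Resolve n_rot the way the patched code does: default to the per-head key
// dimension, let an optional metadata value override it, then validate for
// architectures that require rotating the full head.
uint32_t resolve_n_rot(const hparams_sketch & hp,
                       std::optional<uint32_t> rope_dimension_count,
                       bool arch_requires_full_rot) {
    uint32_t n_rot = hp.n_embd_head_k;              // new default (was n_embd / n_head)
    if (rope_dimension_count) {
        n_rot = *rope_dimension_count;              // optional override from model metadata
    }
    if (arch_requires_full_rot && n_rot != hp.n_embd_head_k) {
        throw std::runtime_error("invalid n_rot"); // mirrors the LLAMA/FALCON sanity check
    }
    return n_rot;
}

int main() {
    hparams_sketch hp;
    // No metadata override: n_rot falls back to the key head size (128).
    std::printf("n_rot = %u\n", resolve_n_rot(hp, std::nullopt, /*arch_requires_full_rot=*/true));
    // Partial-rotation style override (e.g. rotary_pct * head_dim), no strict check.
    std::printf("n_rot = %u\n", resolve_n_rot(hp, 32u, /*arch_requires_full_rot=*/false));
}
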
@@ -11650,7 +11651,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                    n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11659,7 +11660,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -11763,7 +11764,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
-                    n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11772,7 +11773,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
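
In the four hunks above, the rope calls previously passed n_embd_head_k as the number of rotated dimensions; they now pass n_rot, so when n_rot is smaller than the key head size only the first n_rot components of each head are rotated and the rest pass through unchanged. A simplified, self-contained illustration of that partial rotation over a single head vector (plain C++, using one common pairing convention; the real ggml_rope_ext kernel additionally handles batching, rope modes, and frequency/extension scaling that this sketch omits):

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate only the first n_rot dimensions of one head vector; the remaining
// (head_dim - n_rot) components are left untouched. Pairs are (2*i, 2*i + 1),
// one common RoPE layout; other pairing modes exist.
void rope_partial(std::vector<float> & x, int n_rot, int pos, float freq_base = 10000.0f) {
    for (int i = 0; i < n_rot / 2; ++i) {
        const float theta = pos * std::pow(freq_base, -2.0f * i / n_rot);
        const float c  = std::cos(theta);
        const float s  = std::sin(theta);
        const float x0 = x[2*i + 0];
        const float x1 = x[2*i + 1];
        x[2*i + 0] = x0*c - x1*s;
        x[2*i + 1] = x0*s + x1*c;
    }
}

int main() {
    std::vector<float> head(8, 1.0f);               // head_dim = 8
    rope_partial(head, /*n_rot=*/4, /*pos=*/3);     // only dims 0..3 are rotated
    for (float v : head) std::printf("%.3f ", v);
    std::printf("\n");
}
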