@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
4745
4745
4746
4746
// non-transformer models do not have attention heads
4747
4747
if (hparams.n_head() > 0) {
4748
- // sanity check for n_rot (optional)
4749
- hparams.n_rot = hparams.n_embd / hparams.n_head();
4750
-
4751
- ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4752
-
4753
- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4754
- if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4755
- throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4756
- }
4757
- }
4758
4748
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
4759
4749
// gpt-j n_rot = rotary_dim
4760
4750
@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
4763
4753
4764
4754
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
4765
4755
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4756
+
4757
+ // sanity check for n_rot (optional)
4758
+ hparams.n_rot = hparams.n_embd_head_k;
4759
+
4760
+ ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4761
+
4762
+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4763
+ if (hparams.n_rot != hparams.n_embd_head_k) {
4764
+ throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4765
+ }
4766
+ }
4766
4767
} else {
4767
4768
hparams.n_rot = 0;
4768
4769
hparams.n_embd_head_k = 0;
@@ -11650,7 +11651,7 @@ struct llm_build_context {
11650
11651
11651
11652
Qcur = ggml_rope_ext(
11652
11653
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11653
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11654
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11654
11655
ext_factor, attn_factor, beta_fast, beta_slow);
11655
11656
cb(Qcur, "Qcur", il);
11656
11657
@@ -11659,7 +11660,7 @@ struct llm_build_context {
11659
11660
11660
11661
Kcur = ggml_rope_ext(
11661
11662
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11662
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11663
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11663
11664
ext_factor, attn_factor, beta_fast, beta_slow);
11664
11665
cb(Kcur, "Kcur", il);
11665
11666
@@ -11763,7 +11764,7 @@ struct llm_build_context {
11763
11764
11764
11765
Qcur = ggml_rope_ext(
11765
11766
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11766
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11767
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11767
11768
ext_factor, attn_factor, beta_fast, beta_slow);
11768
11769
cb(Qcur, "Qcur", il);
11769
11770
@@ -11772,7 +11773,7 @@ struct llm_build_context {
11772
11773
11773
11774
Kcur = ggml_rope_ext(
11774
11775
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11775
- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11776
+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
11776
11777
ext_factor, attn_factor, beta_fast, beta_slow);
11777
11778
cb(Kcur, "Kcur", il);
11778
11779
0 commit comments