
Commit 7df7530

gemma2: add sliding window mask
1 parent 1c5eba6 commit 7df7530

File tree

1 file changed: +34 -2 lines changed


src/llama.cpp

Lines changed: 34 additions & 2 deletions
@@ -287,6 +287,7 @@ enum llm_kv {

     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
+    LLM_KV_CONTEXT_LENGTH_SWA,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
@@ -379,6 +380,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {

     { LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
     { LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
+    { LLM_KV_CONTEXT_LENGTH_SWA, "%s.context_length_swa" },
     { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
     { LLM_KV_BLOCK_COUNT, "%s.block_count" },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -2079,7 +2081,8 @@ struct llama_hparams {
     bool use_par_res;

     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train;    // context size the model was trained on
+    int32_t  n_ctx_swa = -1; // context size for sliding window attention (SWA)
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2661,6 +2664,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]

+    // KQ mask per layer, used by sliding window attention (gemma 2)
+    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -4709,6 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_ctx_swa = 4096; // default value
+                ml.get_key(LLM_KV_CONTEXT_LENGTH_SWA, hparams.n_ctx_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -11029,9 +11037,16 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask();
+        lctx.inp_KQ_mask_l.clear();

         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+            lctx.inp_KQ_mask_l.push_back(KQ_mask);
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
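
Note (not part of the commit): with the (il % 2 == 0) test above, even-numbered layers receive the sliding-window mask and odd-numbered layers the full causal mask, matching Gemma 2's alternation of local and global attention layers. A minimal standalone sketch of that assignment, using a made-up layer count:

// illustration only -- n_layer = 6 is a placeholder, not taken from any model
#include <cstdio>

int main() {
    const int n_layer = 6; // hypothetical layer count
    for (int il = 0; il < n_layer; ++il) {
        // even layers -> sliding window (SWA) mask, odd layers -> full causal mask
        const char * mask = (il % 2 == 0) ? "KQ_mask_SWA" : "KQ_mask_full";
        printf("layer %d uses %s\n", il, mask);
    }
    return 0;
}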
@@ -12671,6 +12686,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

             float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+
+            if (lctx.model.arch == LLM_ARCH_GEMMA2) {
+                GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+                GGML_ASSERT(hparams.n_ctx_swa > 0);
+                data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
+                data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+            }

             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -12692,6 +12715,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa && f != -INFINITY) {
+                            const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
+                            if (pos - lctx.kv_self.cells[i].pos > n_keep) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }
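
Note (not part of the commit): the cutoff above removes from the SWA mask any cached token that sits more than n_keep = n_ctx_swa - batch.n_tokens positions behind the current query, while the full mask still allows it. A standalone sketch with toy numbers (a window of 8 instead of the 4096 default) showing how the two masks diverge:

// illustration only -- all sizes below are made up, nothing is read from llama.cpp
#include <cmath>
#include <cstdio>

int main() {
    const int n_ctx_swa = 8;  // toy sliding window (the commit's default is 4096)
    const int n_batch   = 1;  // toy batch size
    const int n_keep    = n_ctx_swa - n_batch;
    const int pos       = 20; // position of the current query token

    for (int cell_pos = 12; cell_pos <= 21; ++cell_pos) {
        // full causal mask: only future tokens are masked
        float f_full = (cell_pos > pos) ? -INFINITY : 0.0f;
        // SWA mask: additionally mask tokens that fell out of the window
        float f_swa = f_full;
        if (f_swa != -INFINITY && pos - cell_pos > n_keep) {
            f_swa = -INFINITY;
        }
        printf("cell pos %2d: full=%-8s swa=%s\n", cell_pos,
               f_full == -INFINITY ? "masked" : "visible",
               f_swa  == -INFINITY ? "masked" : "visible");
    }
    return 0;
}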
