
Commit 27cf1ad

Experimental support for Falcon-40B (and Falcon-7B). Breaks the 7B GGML conversion format in an obvious way because of the new hparam n_head_kv; does not match 40B output exactly because of a wrong Kcur broadcast; poor implementation using (possibly unnecessary) memcpys due to the weird memory layout of the QKV weights when n_head_kv > 1.
1 parent a25229d commit 27cf1ad
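
The "weird memory layout" mentioned in the commit message is the per-group interleaving of query, key and value heads inside the fused query_key_value weight. Below is a minimal sketch (not part of the commit) of that grouping as inferred from the ggml_view_4d strides in the main.cpp diff; the toy dimensions are made up purely for illustration.

# Sketch: apparent row grouping of the fused QKV weight when n_head_kv > 1,
# inferred from the view strides in the diff below. Toy dimensions, illustration only.
n_head, n_head_kv, head_dim = 4, 2, 64   # hypothetical config
q_per_group = n_head // n_head_kv        # query heads per KV group

rows = []
for group in range(n_head_kv):
    rows += [f"q{group}.{m}" for m in range(q_per_group)]  # this group's query heads
    rows += [f"k{group}", f"v{group}"]                     # its shared key and value head

# Each label stands for head_dim consecutive rows, so one group spans
# (q_per_group + 2) * head_dim rows of the fused matrix.
print(rows)  # ['q0.0', 'q0.1', 'k0', 'v0', 'q1.0', 'q1.1', 'k1', 'v1']

With n_head_kv == 1 (Falcon-7B) this degenerates to all query heads followed by a single key and value head, which matches the old single-group views.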

2 files changed: +154 -56 lines changed


examples/falcon/convert-hf-to-ggml.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def bytes_to_unicode():
 fout.write(struct.pack("i", hparams["vocab_size"]))
 fout.write(struct.pack("i", hparams["hidden_size"]))
 fout.write(struct.pack("i", hparams["n_head"]))
+fout.write(struct.pack("i", hparams["n_head_kv"] if "n_head_kv" in hparams else 1))
 fout.write(struct.pack("i", hparams["n_layer"]))
 fout.write(struct.pack("i", ftype))

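For context on the format break called out in the commit message: the new n_head_kv field is written between n_head and n_layer. Below is a hedged sketch of just the fields visible in this hunk (whatever the converter writes before or after them, such as the file magic and the vocabulary, is omitted, and the helper name is made up):

import struct

def write_visible_hparams(fout, hparams, ftype):
    # Field order as it appears in this hunk; an older loader expects n_layer
    # where n_head_kv now sits, hence the incompatibility with existing 7B files.
    fout.write(struct.pack("i", hparams["vocab_size"]))
    fout.write(struct.pack("i", hparams["hidden_size"]))
    fout.write(struct.pack("i", hparams["n_head"]))
    fout.write(struct.pack("i", hparams.get("n_head_kv", 1)))  # new in this commit
    fout.write(struct.pack("i", hparams["n_layer"]))
    fout.write(struct.pack("i", ftype))

Existing 7B GGML files therefore have to be reconverted with the updated script.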

examples/falcon/main.cpp

Lines changed: 153 additions & 56 deletions
@@ -14,20 +14,22 @@
 #include <vector>
 
 // default hparams (Falcon 7B)
-// TODO add n_head_kv to support 40B
 struct falcon_hparams {
     int32_t n_vocab = 65024;
     int32_t n_ctx = 2048;
     int32_t n_embd = 4544;
     int32_t n_head = 71;
+    int32_t n_head_kv = 1;
     int32_t n_layer = 32;
     int32_t ftype = 1;
 };
 
 struct falcon_layer {
     // normalization
-    struct ggml_tensor* attention_norm;
-    struct ggml_tensor* attention_norm_b;
+    struct ggml_tensor* input_layernorm;
+    struct ggml_tensor* input_layernorm_b;
+    struct ggml_tensor* attention_norm;   // Falcon-40B only
+    struct ggml_tensor* attention_norm_b; // Falcon-40B only
 
     // attention
     struct ggml_tensor* query_key_value;
@@ -83,17 +85,19 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
         fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
         fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
         fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_head_kv, sizeof(hparams.n_head_kv));
         fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
 
         const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
 
-        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
-        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
-        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
-        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
-        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
-        printf("%s: qntvr   = %d\n", __func__, qntvr);
+        printf("%s: n_vocab   = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_embd    = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head    = %d\n", __func__, hparams.n_head);
+        printf("%s: n_head_kv = %d\n", __func__, hparams.n_head_kv);
+        printf("%s: n_layer   = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype     = %d\n", __func__, hparams.ftype);
+        printf("%s: qntvr     = %d\n", __func__, qntvr);
 
         hparams.ftype %= GGML_QNT_VERSION_FACTOR;
     }
@@ -136,6 +140,7 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
 
         const int n_embd = hparams.n_embd;
         const int n_head = hparams.n_head;
+        const int n_head_kv = hparams.n_head_kv;
         const int n_layer = hparams.n_layer;
         const int n_ctx = hparams.n_ctx;
         const int n_ff = 4 * model.hparams.n_embd;
@@ -152,13 +157,22 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
 
         ctx_size +=
             n_layer *
-            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // input_layernorm
         ctx_size +=
             n_layer *
-            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm_b
+            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // input_layernorm_b
+
+        if (n_head_kv > 1) { // Falcon-40B
+            ctx_size +=
+                n_layer *
+                (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+            ctx_size +=
+                n_layer *
+                (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm_b
+        }
 
-        ctx_size += n_layer * (n_embd * (n_embd + 2 * (n_embd / n_head)) *
-                               ggml_type_sizef(wtype)); // query_key_value
+        ctx_size += n_layer * n_embd * (n_head_kv * 2 + n_head) * head_dim *
+                    ggml_type_sizef(wtype); // query_key_value
         ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wo
 
         ctx_size +=
@@ -171,9 +185,9 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
         ctx_size +=
             n_layer * (n_ff * n_embd * ggml_type_sizef(wtype)); // ffn_down
 
-        ctx_size += n_ctx * n_layer * head_dim *
+        ctx_size += n_ctx * n_layer * n_head_kv * head_dim *
                     ggml_type_sizef(GGML_TYPE_F32); // memory_k
-        ctx_size += n_ctx * n_layer * head_dim *
+        ctx_size += n_ctx * n_layer * n_head_kv * head_dim *
                     ggml_type_sizef(GGML_TYPE_F32); // memory_v
 
         ctx_size += (5 + 10 * n_layer) * 256; // object overhead TODO:
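As a quick sanity check on the resized buffers (a sketch, not part of the commit): the 7B numbers below come from the defaults in this file, while the 40B shape (n_embd=8192, n_head=128, n_head_kv=8) is assumed here for comparison.

def fused_qkv_elems(n_embd, n_head, n_head_kv):
    head_dim = n_embd // n_head
    return n_embd * (n_head + 2 * n_head_kv) * head_dim

def kv_cache_elems_per_layer(n_ctx, n_embd, n_head, n_head_kv):
    head_dim = n_embd // n_head
    return n_ctx * n_head_kv * head_dim  # one of memory_k / memory_v

# Falcon-7B defaults from this file: n_embd=4544, n_head=71, n_head_kv=1
print(fused_qkv_elems(4544, 71, 1))                 # 4544 * 73 * 64 = 21,229,568 weights per layer
print(kv_cache_elems_per_layer(2048, 4544, 71, 1))  # 2048 * 64 = 131,072 elements per layer

# Assumed Falcon-40B shape for comparison: n_embd=8192, n_head=128, n_head_kv=8
print(fused_qkv_elems(8192, 128, 8))                # 8192 * 144 * 64 = 75,497,472 weights per layer
print(kv_cache_elems_per_layer(2048, 8192, 128, 8)) # 2048 * 8 * 64 = 1,048,576 elements per layer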
@@ -201,9 +215,11 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
 
         const int n_embd = hparams.n_embd;
         const int n_head = hparams.n_head;
+        const int n_head_kv = hparams.n_head_kv;
         const int n_layer = hparams.n_layer;
         const int n_ff = 4 * model.hparams.n_embd;
         const int n_vocab = hparams.n_vocab;
+        const int head_dim = hparams.n_embd / hparams.n_head;
 
         model.layers.resize(n_layer);
 
@@ -224,24 +240,42 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
         for (int i = 0; i < n_layer; ++i) {
             auto& layer = model.layers[i];
 
-            layer.attention_norm =
+            layer.input_layernorm =
                 ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-            layer.attention_norm_b =
+            layer.input_layernorm_b =
                 ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
+            if (n_head_kv > 1) { // Falcon-40B
+                layer.attention_norm =
+                    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+                layer.attention_norm_b =
+                    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            }
+
             // query_key_value shape for config.multi_query == True:
             layer.query_key_value = ggml_new_tensor_2d(
-                ctx, wtype, n_embd, n_embd + 2 * (n_embd / n_head));
+                ctx, wtype, n_embd, (n_head_kv * 2 + n_head) * head_dim);
             layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
 
             layer.ffn_up = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
             layer.ffn_down = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
 
             // map by name
+            // Falcon-7B:
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".input_layernorm.weight"] = layer.input_layernorm;
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".input_layernorm.bias"] = layer.input_layernorm_b;
+
+            // Falcon-40B:
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".ln_mlp.weight"] = layer.input_layernorm;
             model.tensors["transformer.h." + std::to_string(i) +
-                          ".input_layernorm.weight"] = layer.attention_norm;
+                          ".ln_mlp.bias"] = layer.input_layernorm_b;
             model.tensors["transformer.h." + std::to_string(i) +
-                          ".input_layernorm.bias"] = layer.attention_norm_b;
+                          ".ln_attn.weight"] = layer.attention_norm;
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".ln_attn.bias"] = layer.attention_norm_b;
 
             model.tensors["transformer.h." + std::to_string(i) +
                           ".self_attention.query_key_value.weight"] =
@@ -262,13 +296,14 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
 
         const int n_layer = hparams.n_layer;
         const int n_ctx = hparams.n_ctx;
+        const int n_head_kv = hparams.n_head_kv;
         const int head_dim = hparams.n_embd / hparams.n_head;
 
         const int64_t n_mem = n_layer*n_ctx;
         const int64_t n_elements = head_dim*n_mem;
 
-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head_kv * n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head_kv * n_elements);
 
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
@@ -378,6 +413,7 @@ bool falcon_eval(
     const int n_layer = hparams.n_layer;
    const int n_ctx = hparams.n_ctx;
     const int n_head = hparams.n_head;
+    const int n_head_kv = hparams.n_head_kv;
     const int n_vocab = hparams.n_vocab;
     const size_t head_dim = n_embd / n_head;
 
@@ -420,7 +456,10 @@ bool falcon_eval(
 
     // wte
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-    struct ggml_tensor* repeat_dummy = ggml_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head);
+    struct ggml_tensor* repeat_dummy = ggml_new_tensor_4d(ctx0, inpL->type, head_dim, N + n_past, n_head, 1);
+
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    const int sizeof_wtype = ggml_type_sizef(wtype);
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
@@ -430,7 +469,15 @@ bool falcon_eval(
 
         // self-attention
         {
-            {
+            layernorm_output = ggml_norm(ctx0, inpL);
+
+            layernorm_output = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].input_layernorm, layernorm_output),
+                    layernorm_output),
+                ggml_repeat(ctx0, model.layers[il].input_layernorm_b, layernorm_output));
+
+            if (n_head_kv > 1) { // Falcon-40B only
                 cur = ggml_norm(ctx0, inpL);
 
                 cur = ggml_add(ctx0,
@@ -439,71 +486,120 @@ bool falcon_eval(
                         cur),
                     ggml_repeat(ctx0, model.layers[il].attention_norm_b, cur));
             }
-            layernorm_output = cur;
+            else {
+                cur = layernorm_output;
+            }
 
             // compute QKV
             cur = ggml_mul_mat(ctx0, model.layers[il].query_key_value, cur);
 
-            size_t fused_qkv_row_nb =
-                (n_embd + 2 * (n_embd / n_head)) * sizeof(float);
+            struct ggml_tensor* Qcur = ggml_view_4d(
+                ctx0, cur, head_dim, n_head / n_head_kv, n_head_kv, N,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * n_head_kv * sizeof_wtype,
+                0);
+
+            struct ggml_tensor* Kcur = ggml_view_4d(
+                ctx0, cur, head_dim, 1, N, n_head_kv,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * N * sizeof_wtype,
+                head_dim * (n_head / n_head_kv) * sizeof_wtype);
+
+            struct ggml_tensor* Vcur = ggml_view_4d(
+                ctx0, cur, head_dim, 1, N, n_head_kv,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * N * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 1) * sizeof_wtype);
+
+            // TODO: The crazy piecewise copying below works (well, until GGML_MAX_NODES is hit),
+            // but it surely cannot remain so in production.
+            // As for the necessity of it, consider for example n_head_kv=2
+            // and n_head=4, head_dim=64. Unfortunately with that config we have addressing like
+            // offset = i * 512 + kv_group * 256 + head_in_group * 64
+            // required to collect each individual query vector in Qcur.
+            // I don't think it can be expressed using view or reshape.
+            // Maybe the GGML conversion could do something to alleviate it
+            // so that we can get rid of it.
 
-            struct ggml_tensor* Qcur =
-                ggml_view_3d(ctx0, cur, head_dim, n_head, N,
-                             head_dim * sizeof(float), fused_qkv_row_nb, 0);
+            struct ggml_tensor * Q =
+                ggml_new_tensor_3d(ctx0, wtype, head_dim, n_head, N);
 
-            struct ggml_tensor* Kcur = ggml_view_3d(
-                ctx0, cur, head_dim, 1, N, head_dim * sizeof(float),
-                fused_qkv_row_nb, n_embd * sizeof(float));
+            for (int i = 0; i < N; i++) {
+                for (int group = 0; group < n_head_kv; group++) {
+                    for (int member = 0; member < n_head / n_head_kv; member++) {
+                        size_t src_offset =
+                            (i * (n_head + 2 * n_head_kv) * head_dim +
+                             group * (n_head / n_head_kv + 2) * head_dim +
+                             member * head_dim) * sizeof_wtype;
 
-            struct ggml_tensor* Vcur = ggml_view_3d(
-                ctx0, cur, head_dim, 1, N, head_dim * sizeof(float),
-                fused_qkv_row_nb, (n_embd + head_dim) * sizeof(float));
+                        size_t dst_offset =
+                            (Q->nb[1] * (group * n_head_kv + member) +
+                             Q->nb[2] * i);
 
+                        struct ggml_tensor* src = ggml_view_1d(
+                            ctx0, Qcur, head_dim, src_offset);
+
+                        struct ggml_tensor* dst = ggml_view_1d(
+                            ctx0, Q, head_dim, dst_offset);
+
+                        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, src, dst));
+                    }
+                }
+            }
+
             // using mode = 2 for neox mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, head_dim, 2);
+            Q = ggml_rope_inplace(ctx0, Q, n_past, head_dim, 2);
             Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, head_dim, 2);
 
             // store key and value to memory
             {
                 struct ggml_tensor* k = ggml_view_1d(
-                    ctx0, model.memory_k, N * head_dim,
-                    (ggml_element_size(model.memory_k) * head_dim) *
+                    ctx0, model.memory_k, N * n_head_kv * head_dim,
+                    (ggml_element_size(model.memory_k) * n_head_kv * head_dim) *
                         (il * n_ctx + n_past));
                 struct ggml_tensor* v = ggml_view_1d(
-                    ctx0, model.memory_v, N * head_dim,
-                    (ggml_element_size(model.memory_v) * head_dim) *
+                    ctx0, model.memory_v, N * n_head_kv * head_dim,
+                    (ggml_element_size(model.memory_v) * n_head_kv * head_dim) *
                         (il * n_ctx + n_past));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }
 
-            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                    Qcur,
-                    0, 2, 1, 3);
-
             struct ggml_tensor* K = ggml_permute(
                 ctx0,
                 ggml_reshape_3d(
                     ctx0,
-                    ggml_view_1d(ctx0, model.memory_k, (n_past + N) * head_dim,
+                    ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_head_kv * head_dim,
                                  il * n_ctx *
                                      ggml_element_size(model.memory_k) *
+                                     n_head_kv *
                                      head_dim),
-                    head_dim, 1, n_past + N),
+                    head_dim, n_head_kv, n_past + N),
                 0, 2, 1, 3);
 
             // K * Q
+
+            // TODO Unfortunately this ggml_repeat does not do what we need it to do:
+            // [ K1, K2 ] will be broadcast into [ [K1, K2], [K1, K2] ], while we actually
+            // need them to become [ [K1, K1], [K2, K2] ] ... And I suppose there will be same
+            // problem with V below as well.
+            // Here too perhaps GGML conversion could do some preprocessing to obtain
+            // a more GGML-friendly memory format.
+
             K = ggml_cont(ctx0, ggml_repeat(ctx0, K, repeat_dummy));
+            Q = ggml_permute(ctx0, Q, 0, 2, 1, 3);
+
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             struct ggml_tensor * KQ_scaled =
                 ggml_scale_inplace(ctx0,
                     KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                    ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim)))
                 );
 
             // KQ_masked = mask_past(KQ_scaled)
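The addressing formula in the Qcur TODO above can be sanity-checked in a few lines (illustration only, using the example config the comment itself gives: n_head=4, n_head_kv=2, head_dim=64):

n_head, n_head_kv, head_dim = 4, 2, 64
q_per_group = n_head // n_head_kv

def q_src_offset(i, group, member):
    # element offset of one query head vector inside the fused QKV activation,
    # mirroring src_offset in the loop above (without the sizeof_wtype factor)
    return (i * (n_head + 2 * n_head_kv) * head_dim
            + group * (q_per_group + 2) * head_dim
            + member * head_dim)

# matches the comment's "offset = i * 512 + kv_group * 256 + head_in_group * 64"
assert q_src_offset(3, 1, 1) == 3 * 512 + 1 * 256 + 1 * 64

The member stride (head_dim) and the group stride ((q_per_group + 2) * head_dim) differ, which is the irregular addressing the comment says cannot be expressed with a single view or reshape.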
@@ -515,15 +611,16 @@ bool falcon_eval(
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor* V = ggml_permute(
                 ctx0,
-                ggml_reshape_3d(
+                ggml_reshape_4d(
                     ctx0,
-                    ggml_view_1d(ctx0, model.memory_v, (n_past + N) * head_dim,
+                    ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_head_kv * head_dim,
                                  il * n_ctx *
                                      ggml_element_size(model.memory_v) *
+                                     n_head_kv *
                                      head_dim),
-                    head_dim, 1, n_past + N),
-                0, 2, 1, 3);
-
+                    head_dim, 1, n_head_kv, n_past + N),
+                0, 3, 2, 1);
+
             V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat(ctx0, V, repeat_dummy)));
 
             // KQV = transpose(V) * KQ_soft_max
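The broadcast problem flagged in the earlier TODO for K (and, as it notes, equally relevant to V here) boils down to tiling versus per-group repetition; a tiny illustration, not part of the commit:

K = ["K1", "K2"]                          # two KV heads

tiled   = K * 2                           # ['K1', 'K2', 'K1', 'K2']  -- what the repeat produces
grouped = [k for k in K for _ in (0, 1)]  # ['K1', 'K1', 'K2', 'K2']  -- what the query groups need

print(tiled)
print(grouped)

Until the broadcast (or the memory layout produced by the GGML conversion) is adjusted accordingly, the 40B output cannot match the reference implementation exactly, which is what the commit message notes.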
@@ -757,4 +854,4 @@ int main(int argc, char ** argv) {
     ggml_free(model.ctx);
 
     return 0;
-}
+}
