#include <vector>

// default hparams (Falcon 7B)
-// TODO add n_head_kv to support 40B
struct falcon_hparams {
    int32_t n_vocab = 65024;
    int32_t n_ctx   = 2048;
    int32_t n_embd  = 4544;
    int32_t n_head  = 71;
+    int32_t n_head_kv = 1;
    int32_t n_layer = 32;
    int32_t ftype   = 1;
};

struct falcon_layer {
    // normalization
-    struct ggml_tensor * attention_norm;
-    struct ggml_tensor * attention_norm_b;
+    struct ggml_tensor * input_layernorm;
+    struct ggml_tensor * input_layernorm_b;
+    struct ggml_tensor * attention_norm;   // Falcon-40B only
+    struct ggml_tensor * attention_norm_b; // Falcon-40B only

    // attention
    struct ggml_tensor * query_key_value;
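For orientation on the new n_head_kv field: Falcon-7B uses multi-query attention, i.e. a single shared key/value head (n_head_kv = 1), while Falcon-40B groups its query heads over several key/value heads. The fused query_key_value projection therefore emits, per KV group, n_head / n_head_kv query heads followed by one key and one value head, for a total width of (n_head + 2 * n_head_kv) * head_dim per token. A minimal sketch of that arithmetic (the 40B numbers are an illustrative assumption, not read from this file):

// Sketch only: width of the fused QKV output per token, assuming the grouped
// layout [q_0 .. q_{m-1}, k, v] repeated for each of the n_head_kv groups,
// with m = n_head / n_head_kv.
#include <cstdio>

static int fused_qkv_width(int n_embd, int n_head, int n_head_kv) {
    const int head_dim = n_embd / n_head;
    return (n_head + 2 * n_head_kv) * head_dim;
}

int main() {
    // Falcon-7B (defaults above): head_dim = 4544 / 71 = 64, width = (71 + 2) * 64 = 4672
    printf("7B : %d\n", fused_qkv_width(4544, 71, 1));
    // Falcon-40B (assumed hparams n_embd = 8192, n_head = 128, n_head_kv = 8):
    // head_dim = 64, width = (128 + 16) * 64 = 9216
    printf("40B: %d\n", fused_qkv_width(8192, 128, 8));
    return 0;
}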
@@ -83,17 +85,19 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_head_kv, sizeof(hparams.n_head_kv));
        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));

        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;

-        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
-        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
-        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
-        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
-        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
-        printf("%s: qntvr   = %d\n", __func__, qntvr);
+        printf("%s: n_vocab   = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_embd    = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head    = %d\n", __func__, hparams.n_head);
+        printf("%s: n_head_kv = %d\n", __func__, hparams.n_head_kv);
+        printf("%s: n_layer   = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype     = %d\n", __func__, hparams.ftype);
+        printf("%s: qntvr     = %d\n", __func__, qntvr);

        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
    }
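The ftype field read above packs the quantization version together with the file type; the version is recovered by integer division and the remainder (after the hparams.ftype %= GGML_QNT_VERSION_FACTOR line above) is the actual type. A small sketch of that unpacking, assuming ggml's GGML_QNT_VERSION_FACTOR of 1000:

// Sketch: unpacking a versioned ftype. GGML_QNT_VERSION_FACTOR is assumed to
// be 1000 here; the value 2002 is a hypothetical on-disk ftype, not taken from
// a real model file.
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t factor = 1000;          // assumed GGML_QNT_VERSION_FACTOR
    int32_t ftype = 2002;                 // hypothetical packed value
    const int32_t qntvr = ftype / factor; // -> 2 (quantization version)
    ftype %= factor;                      // -> 2 (the underlying ftype)
    printf("qntvr = %d, ftype = %d\n", qntvr, ftype);
    return 0;
}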
@@ -136,6 +140,7 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca

        const int n_embd  = hparams.n_embd;
        const int n_head  = hparams.n_head;
+        const int n_head_kv = hparams.n_head_kv;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
        const int n_ff    = 4 * model.hparams.n_embd;
@@ -152,13 +157,22 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca

        ctx_size +=
            n_layer *
-            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // input_layernorm
        ctx_size +=
            n_layer *
-            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm_b
+            (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // input_layernorm_b
+
+        if (n_head_kv > 1) { // Falcon-40B
+            ctx_size +=
+                n_layer *
+                (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+            ctx_size +=
+                n_layer *
+                (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // attention_norm_b
+        }

-        ctx_size += n_layer * (n_embd * (n_embd + 2 * (n_embd / n_head)) *
-                               ggml_type_sizef(wtype)); // query_key_value
+        ctx_size += n_layer * n_embd * (n_head_kv * 2 + n_head) * head_dim *
+                    ggml_type_sizef(wtype); // query_key_value
        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // wo

        ctx_size +=
@@ -171,9 +185,9 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
        ctx_size +=
            n_layer * (n_ff * n_embd * ggml_type_sizef(wtype)); // ffn_down

-        ctx_size += n_ctx * n_layer * head_dim *
+        ctx_size += n_ctx * n_layer * n_head_kv * head_dim *
                    ggml_type_sizef(GGML_TYPE_F32); // memory_k
-        ctx_size += n_ctx * n_layer * head_dim *
+        ctx_size += n_ctx * n_layer * n_head_kv * head_dim *
                    ggml_type_sizef(GGML_TYPE_F32); // memory_v

        ctx_size += (5 + 10 * n_layer) * 256; // object overhead TODO:
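The memory_k / memory_v terms above scale with n_head_kv: each layer now caches n_head_kv key and value heads per context position rather than a single one. A quick sketch of the resulting per-tensor cache size with the Falcon-7B defaults (F32 elements, as in the code above):

// Sketch: KV cache size per tensor, mirroring the memory_k / memory_v terms.
// With the Falcon-7B defaults: 2048 * 32 * 1 * 64 = 4,194,304 floats = 16 MiB
// each for memory_k and memory_v.
#include <cstdio>

int main() {
    const int n_ctx = 2048, n_layer = 32, n_head_kv = 1;
    const int head_dim = 4544 / 71; // n_embd / n_head = 64
    const size_t elems = (size_t) n_ctx * n_layer * n_head_kv * head_dim;
    printf("memory_k: %zu elements, %.1f MiB\n",
           elems, elems * sizeof(float) / (1024.0 * 1024.0));
    return 0;
}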
@@ -201,9 +215,11 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca

        const int n_embd  = hparams.n_embd;
        const int n_head  = hparams.n_head;
+        const int n_head_kv = hparams.n_head_kv;
        const int n_layer = hparams.n_layer;
        const int n_ff    = 4 * model.hparams.n_embd;
        const int n_vocab = hparams.n_vocab;
+        const int head_dim = hparams.n_embd / hparams.n_head;

        model.layers.resize(n_layer);

@@ -224,24 +240,42 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca
        for (int i = 0; i < n_layer; ++i) {
            auto & layer = model.layers[i];

-            layer.attention_norm =
+            layer.input_layernorm =
                ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-            layer.attention_norm_b =
+            layer.input_layernorm_b =
                ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

+            if (n_head_kv > 1) { // Falcon-40B
+                layer.attention_norm =
+                    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+                layer.attention_norm_b =
+                    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            }
+
            // query_key_value shape for config.multi_query == True:
            layer.query_key_value = ggml_new_tensor_2d(
-                ctx, wtype, n_embd, n_embd + 2 * (n_embd / n_head));
+                ctx, wtype, n_embd, (n_head_kv * 2 + n_head) * head_dim);
            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);

            layer.ffn_up   = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
            layer.ffn_down = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);

            // map by name
+            // Falcon-7B:
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".input_layernorm.weight"] = layer.input_layernorm;
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".input_layernorm.bias"] = layer.input_layernorm_b;
+
+            // Falcon-40B:
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".ln_mlp.weight"] = layer.input_layernorm;
            model.tensors["transformer.h." + std::to_string(i) +
-                          ".input_layernorm.weight"] = layer.attention_norm;
+                          ".ln_mlp.bias"] = layer.input_layernorm_b;
            model.tensors["transformer.h." + std::to_string(i) +
-                          ".input_layernorm.bias"] = layer.attention_norm_b;
+                          ".ln_attn.weight"] = layer.attention_norm;
+            model.tensors["transformer.h." + std::to_string(i) +
+                          ".ln_attn.bias"] = layer.attention_norm_b;

            model.tensors["transformer.h." + std::to_string(i) +
                          ".self_attention.query_key_value.weight"] =
@@ -262,13 +296,14 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt_voca

        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
+        const int n_head_kv = hparams.n_head_kv;
        const int head_dim = hparams.n_embd / hparams.n_head;

        const int64_t n_mem      = n_layer*n_ctx;
        const int64_t n_elements = head_dim*n_mem;

-        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head_kv * n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head_kv * n_elements);

        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

@@ -378,6 +413,7 @@ bool falcon_eval(
    const int n_layer = hparams.n_layer;
    const int n_ctx   = hparams.n_ctx;
    const int n_head  = hparams.n_head;
+    const int n_head_kv = hparams.n_head_kv;
    const int n_vocab = hparams.n_vocab;
    const size_t head_dim = n_embd / n_head;

@@ -420,7 +456,10 @@

    // wte
    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
-    struct ggml_tensor * repeat_dummy = ggml_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head);
+    struct ggml_tensor * repeat_dummy = ggml_new_tensor_4d(ctx0, inpL->type, head_dim, N + n_past, n_head, 1);
+
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    const int sizeof_wtype = ggml_type_sizef(wtype);

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * cur;
@@ -430,7 +469,15 @@

        // self-attention
        {
-            {
+            layernorm_output = ggml_norm(ctx0, inpL);
+
+            layernorm_output = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].input_layernorm, layernorm_output),
+                        layernorm_output),
+                    ggml_repeat(ctx0, model.layers[il].input_layernorm_b, layernorm_output));
+
+            if (n_head_kv > 1) { // Falcon-40B only
                cur = ggml_norm(ctx0, inpL);

                cur = ggml_add(ctx0,
@@ -439,71 +486,120 @@
                        cur),
                    ggml_repeat(ctx0, model.layers[il].attention_norm_b, cur));
            }
-            layernorm_output = cur;
+            else {
+                cur = layernorm_output;
+            }

            // compute QKV
            cur = ggml_mul_mat(ctx0, model.layers[il].query_key_value, cur);

-            size_t fused_qkv_row_nb =
-                (n_embd + 2 * (n_embd / n_head)) * sizeof(float);
+            struct ggml_tensor * Qcur = ggml_view_4d(
+                ctx0, cur, head_dim, n_head / n_head_kv, n_head_kv, N,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * n_head_kv * sizeof_wtype,
+                0);
+
+            struct ggml_tensor * Kcur = ggml_view_4d(
+                ctx0, cur, head_dim, 1, N, n_head_kv,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * N * sizeof_wtype,
+                head_dim * (n_head / n_head_kv) * sizeof_wtype);
+
+            struct ggml_tensor * Vcur = ggml_view_4d(
+                ctx0, cur, head_dim, 1, N, n_head_kv,
+                head_dim * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 2) * N * sizeof_wtype,
+                head_dim * (n_head / n_head_kv + 1) * sizeof_wtype);
+
+            // TODO: The crazy piecewise copying below works (well, until GGML_MAX_NODES is hit),
+            // but it surely cannot remain so in production.
+            // As for the necessity of it, consider for example n_head_kv=2
+            // and n_head=4, head_dim=64. Unfortunately with that config we have addressing like
+            // offset = i * 512 + kv_group * 256 + head_in_group * 64
+            // required to collect each individual query vector in Qcur.
+            // I don't think it can be expressed using view or reshape.
+            // Maybe the GGML conversion could do something to alleviate it
+            // so that we can get rid of it.

-            struct ggml_tensor * Qcur =
-                ggml_view_3d(ctx0, cur, head_dim, n_head, N,
-                             head_dim * sizeof(float), fused_qkv_row_nb, 0);
+            struct ggml_tensor * Q =
+                ggml_new_tensor_3d(ctx0, wtype, head_dim, n_head, N);

-            struct ggml_tensor * Kcur = ggml_view_3d(
-                ctx0, cur, head_dim, 1, N, head_dim * sizeof(float),
-                fused_qkv_row_nb, n_embd * sizeof(float));
+            for (int i = 0; i < N; i++) {
+                for (int group = 0; group < n_head_kv; group++) {
+                    for (int member = 0; member < n_head / n_head_kv; member++) {
+                        size_t src_offset =
+                            (i * (n_head + 2 * n_head_kv) * head_dim +
+                             group * (n_head / n_head_kv + 2) * head_dim +
+                             member * head_dim) * sizeof_wtype;

-            struct ggml_tensor * Vcur = ggml_view_3d(
-                ctx0, cur, head_dim, 1, N, head_dim * sizeof(float),
-                fused_qkv_row_nb, (n_embd + head_dim) * sizeof(float));
+                        size_t dst_offset =
+                            (Q->nb[1] * (group * n_head_kv + member) +
+                             Q->nb[2] * i);

+                        struct ggml_tensor * src = ggml_view_1d(
+                            ctx0, Qcur, head_dim, src_offset);
+
+                        struct ggml_tensor * dst = ggml_view_1d(
+                            ctx0, Q, head_dim, dst_offset);
+
+                        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, src, dst));
+                    }
+                }
+            }
+
            // using mode = 2 for neox mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, head_dim, 2);
+            Q = ggml_rope_inplace(ctx0, Q, n_past, head_dim, 2);
            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, head_dim, 2);

            // store key and value to memory
            {
                struct ggml_tensor * k = ggml_view_1d(
-                    ctx0, model.memory_k, N * head_dim,
-                    (ggml_element_size(model.memory_k) * head_dim) *
+                    ctx0, model.memory_k, N * n_head_kv * head_dim,
+                    (ggml_element_size(model.memory_k) * n_head_kv * head_dim) *
                        (il * n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(
-                    ctx0, model.memory_v, N * head_dim,
-                    (ggml_element_size(model.memory_v) * head_dim) *
+                    ctx0, model.memory_v, N * n_head_kv * head_dim,
+                    (ggml_element_size(model.memory_v) * n_head_kv * head_dim) *
                        (il * n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }

-            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        Qcur,
-                        0, 2, 1, 3);
-
            struct ggml_tensor * K = ggml_permute(
                ctx0,
                ggml_reshape_3d(
                    ctx0,
-                    ggml_view_1d(ctx0, model.memory_k, (n_past + N) * head_dim,
+                    ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_head_kv * head_dim,
                                 il * n_ctx *
                                     ggml_element_size(model.memory_k) *
+                                    n_head_kv *
                                     head_dim),
-                    head_dim, 1, n_past + N),
+                    head_dim, n_head_kv, n_past + N),
                0, 2, 1, 3);

            // K * Q
+
+            // TODO Unfortunately this ggml_repeat does not do what we need it to do:
+            // [ K1, K2 ] will be broadcast into [ [K1, K2], [K1, K2] ], while we actually
+            // need them to become [ [K1, K1], [K2, K2] ] ... And I suppose there will be same
+            // problem with V below as well.
+            // Here too perhaps GGML conversion could do some preprocessing to obtain
+            // a more GGML-friendly memory format.
+
            K = ggml_cont(ctx0, ggml_repeat(ctx0, K, repeat_dummy));
+            Q = ggml_permute(ctx0, Q, 0, 2, 1, 3);
+
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled =
                ggml_scale_inplace(ctx0,
                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim)))
                    );

            // KQ_masked = mask_past(KQ_scaled)
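The offsets in the TODO comment above follow directly from the fused layout: with n_head_kv = 2, n_head = 4, head_dim = 64, each token's fused QKV row is (4 + 2*2) * 64 = 512 elements wide, each KV group spans (4/2 + 2) * 64 = 256 of them, and each query head 64, so no single stride pattern reaches every query head. A standalone sketch of that addressing (element offsets, before the sizeof_wtype scaling used in the loop above):

// Sketch: the src-offset arithmetic from the piecewise copy loop, in elements.
// For n_head = 4, n_head_kv = 2, head_dim = 64 it reproduces
// offset = i * 512 + group * 256 + member * 64 from the TODO comment.
#include <cstdio>

int main() {
    const int n_head = 4, n_head_kv = 2, head_dim = 64, N = 2;
    for (int i = 0; i < N; i++) {                                 // token
        for (int group = 0; group < n_head_kv; group++) {         // KV group
            for (int member = 0; member < n_head / n_head_kv; member++) { // head within group
                const size_t off =
                    (size_t) i * (n_head + 2 * n_head_kv) * head_dim +
                    (size_t) group * (n_head / n_head_kv + 2) * head_dim +
                    (size_t) member * head_dim;
                printf("token %d, q head %d: element offset %zu\n",
                       i, group * (n_head / n_head_kv) + member, off);
            }
        }
    }
    return 0;
}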
@@ -515,15 +611,16 @@
            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            struct ggml_tensor * V = ggml_permute(
                ctx0,
-                ggml_reshape_3d(
+                ggml_reshape_4d(
                    ctx0,
-                    ggml_view_1d(ctx0, model.memory_v, (n_past + N) * head_dim,
+                    ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_head_kv * head_dim,
                                 il * n_ctx *
                                     ggml_element_size(model.memory_v) *
+                                    n_head_kv *
                                     head_dim),
-                    head_dim, 1, n_past + N),
-                0, 2, 1, 3);
-
+                    head_dim, 1, n_head_kv, n_past + N),
+                0, 3, 2, 1);
+
            V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat(ctx0, V, repeat_dummy)));

            // KQV = transpose(V) * KQ_soft_max
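The ggml_repeat limitation noted in the TODO earlier (and applying to the V repeat just above) is a tile-versus-interleave mismatch: each of the n_head query heads should attend to the cached K/V head of its own group, which means duplicating each cached head n_head / n_head_kv times in place, whereas a plain broadcast tiles the whole set. A tiny index sketch of the two mappings, assuming group-major query-head ordering:

// Sketch: which cached K/V head each query head reads, for n_head = 4, n_head_kv = 2.
// Tiled repeat (what a plain broadcast gives): q0->K0, q1->K1, q2->K0, q3->K1
// Interleaved repeat (what grouped KV needs):  q0->K0, q1->K0, q2->K1, q3->K1
#include <cstdio>

int main() {
    const int n_head = 4, n_head_kv = 2;
    for (int q = 0; q < n_head; q++) {
        const int tiled       = q % n_head_kv;            // [K0, K1, K0, K1]
        const int interleaved = q / (n_head / n_head_kv); // [K0, K0, K1, K1]
        printf("q%d: tiled -> K%d, needed -> K%d\n", q, tiled, interleaved);
    }
    return 0;
}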
@@ -757,4 +854,4 @@ int main(int argc, char ** argv) {
    ggml_free(model.ctx);

    return 0;
-}
+}