
Commit 3352043

Added a new ggml_repeat2 API function in support of the "broadcast"-style KV duplication required for the n_head_kv>1 configuration of Falcon-40B
1 parent 27cf1ad commit 3352043

2 files changed: +122 −0 lines

include/ggml/ggml.h

Lines changed: 7 additions & 0 deletions
@@ -329,6 +329,8 @@ extern "C" {
         GGML_OP_MAP_BINARY,

         GGML_OP_COUNT,
+
+        GGML_OP_REPEAT2,
     };

@@ -638,6 +640,11 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);

+    GGML_API struct ggml_tensor * ggml_repeat2(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
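As a usage sketch (not part of the commit), the new declaration could be exercised as below. The tensor shapes, head counts, and the one-shot ggml_build_forward / ggml_graph_compute driver of this ggml era are illustrative assumptions for a Falcon-style grouped-KV setup, not something the commit specifies.

// Hedged sketch: duplicate K across query heads with ggml_repeat2.
// Shapes and dimension order are illustrative assumptions only.
#include "ggml/ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 64*1024*1024,  // scratch size chosen arbitrarily
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t head_dim  = 64;
    const int64_t seq_len   = 32;
    const int64_t n_head_kv = 8;    // hypothetical grouped KV heads
    const int64_t n_head    = 128;  // hypothetical query heads

    // K as stored: one slice per KV head
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, seq_len, n_head_kv);

    // shape template carrying the target (broadcast) number of heads;
    // as with ggml_repeat, only its shape is used
    struct ggml_tensor * shape = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, seq_len, n_head);

    // each KV head is repeated n_head/n_head_kv = 16 times, consecutively per head
    struct ggml_tensor * k_rep = ggml_repeat2(ctx, k, shape);

    struct ggml_cgraph gf = ggml_build_forward(k_rep);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}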

src/ggml.c

Lines changed: 115 additions & 0 deletions
@@ -5066,6 +5066,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }

+// ggml_repeat2
+
+struct ggml_tensor * ggml_repeat2(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(a, b));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT2;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs

 struct ggml_tensor * ggml_abs_impl(
@@ -8847,6 +8875,87 @@ static void ggml_compute_forward_repeat(
     }
 }

+// ggml_compute_forward_repeat2
+
+static void ggml_compute_forward_repeat2_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    int i2k2 = 0;
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++, i2k2 = 0) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++, i2k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (          k3)*nb03 + (i2k2 / nr2)*nb02 + (          k1)*nb01));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat2(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat2_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs

 static void ggml_compute_forward_abs_f32(
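The only behavioral difference from the existing repeat kernel is the source offset along dimension 2: ggml_repeat copies source slice k2 into destination slice i2*ne02 + k2, tiling the source, whereas the kernel above reads slice i2k2/nr2, so each source slice is duplicated nr2 times back to back. A small standalone sketch (hypothetical sizes, not from the commit) of how destination slices map to source slices along dimension 2:

// Standalone illustration of the dim-2 source index each destination slice reads,
// comparing plain tiling with the repeat2 rule. Sizes are hypothetical.
#include <stdio.h>

int main(void) {
    const int ne02 = 2;  // hypothetical: source KV heads
    const int nr2  = 3;  // hypothetical: repeats per KV head

    int i2k2 = 0;
    for (int i2 = 0; i2 < nr2; i2++) {
        for (int k2 = 0; k2 < ne02; k2++, i2k2++) {
            const int dst_slice   = i2*ne02 + k2;  // destination index along dim 2
            const int src_tiled   = k2;            // ggml_repeat-style tiling: 0 1 0 1 0 1
            const int src_repeat2 = i2k2 / nr2;    // ggml_repeat2 rule:        0 0 0 1 1 1
            printf("dst %d <- tiled %d, repeat2 %d\n", dst_slice, src_tiled, src_repeat2);
        }
    }
    return 0;
}

In the Falcon-40B case this means consecutive query heads share one KV head, which appears to be the "broadcast"-style duplication the commit message refers to.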
@@ -13336,6 +13445,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT2:
+            {
+                ggml_compute_forward_repeat2(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13684,6 +13797,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: implement
             } break;
         case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT2:
             {
                 // necessary for llama
                 if (src0->grad) {
@@ -14568,6 +14682,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
+                case GGML_OP_REPEAT2:
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
