Commit 1868f6c
ggml, llama : avoid heavy V transpose + improvements
ggml :
- added ggml_view_3d()
- ggml_view_tensor() now inherits the stride too
- reimplement ggml_cpy() to account for dst stride
- no longer require tensor->data to be memory aligned

llama :
- compute RoPE on 32-bit tensors (should be more accurate)
- store RoPE-ed K in the KV cache
- store transposed V in the KV cache (significant speed-up)
- avoid unnecessary Q copy
1 parent 53dbba7 commit 1868f6c
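Only the ggml changes are shown below; the llama.cpp side of this commit is not included here. To make the last two llama bullets concrete, here is an illustrative sketch (not the actual llama.cpp code; the names cur, il, N, n_past, n_ctx, n_embd, gf and the kv_self.v buffer are assumptions borrowed from llama.cpp-style code): instead of keeping V row-major in the cache and transposing it on every eval, the new rows are written once into a transposed layout through a strided view, which the reworked ggml_cpy() can now target.

// Illustrative sketch only: store the current V rows transposed in the KV cache.
// Vcur becomes an N x n_embd view of the freshly computed V (non-contiguous).
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, N));

// Destination: N token slots per embedding dimension inside the big 1-D cache buffer,
// with a row stride of n_ctx elements, starting at the slots for layer il / position n_past.
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
        n_ctx*ggml_element_size(kv_self.v),
        il*n_ctx*n_embd*ggml_element_size(kv_self.v) + n_past*ggml_element_size(kv_self.v));

// ggml_cpy() now handles a strided (non-contiguous) destination, so no transpose of the
// whole V cache is needed later at attention time.
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));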

File tree

3 files changed: +214 -158 lines changed

ggml.c (+167 -128)
@@ -3219,7 +3219,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.pad =*/ { 0 },
     };
 
-    ggml_assert_aligned(result->data);
+    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+    //ggml_assert_aligned(result->data);
 
     for (int i = 0; i < n_dims; i++) {
         result->ne[i] = ne[i];
@@ -3620,7 +3621,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
-    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+
+    result->nb[0] = src->nb[0];
+    result->nb[1] = src->nb[1];
+    result->nb[2] = src->nb[2];
+    result->nb[3] = src->nb[3];
+
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
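With the change above, a view created by ggml_view_tensor() keeps the byte strides of its source instead of re-deriving contiguous ones from ne. A minimal sketch of the difference, assuming an initialized ggml_context * ctx (variable names are illustrative):

// a: 4 x 3 F32 tensor; t: its transpose (same data, swapped ne and nb)
struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
struct ggml_tensor * t = ggml_transpose(ctx, a);

// Previously the view recomputed contiguous strides from t->ne, so indexing it
// addressed the wrong elements; now it simply copies t->nb[0..3].
struct ggml_tensor * v = ggml_view_tensor(ctx, t);
// v->nb[0] == t->nb[0] && v->nb[1] == t->nb[1]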
@@ -4510,6 +4518,37 @@ struct ggml_tensor * ggml_view_2d(
     return result;
 }
 
+// ggml_view_3d
+
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                offset) {
+    if (a->grad) {
+        GGML_ASSERT(false); // gradient propagation is not supported
+    }
+
+    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = result->nb[2]*ne2;
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
 // ggml_permute
 
 struct ggml_tensor * ggml_permute(
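A hedged usage sketch of the new function, assuming an initialized ggml_context * ctx (the buffer size, shape, and offset are made up for illustration): it carves an ne0 x ne1 x ne2 window out of a larger flat buffer with caller-chosen byte strides.

// Illustrative only: view a 64 x 32 x 8 F32 block starting 1024 elements into a flat buffer.
struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1 << 20);

struct ggml_tensor * view = ggml_view_3d(ctx, buf, 64, 32, 8,
        64*sizeof(float),       // nb1: row stride in bytes
        64*32*sizeof(float),    // nb2: slice stride in bytes
        1024*sizeof(float));    // byte offset into buf->data

// view shares buf's memory; nb[3] is derived as nb[2]*ne2 by ggml_view_3d()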
@@ -4845,7 +4884,6 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -4862,85 +4900,96 @@ static void ggml_compute_forward_dup_f16(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
-        if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
-
-                        memcpy(dst_ptr, src0_ptr, rs);
-
-                        id++;
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                            id++;
-                        }
-                    }
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
+        return;
+    }
 
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
 
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                            id++;
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
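The rewritten kernel walks src0 in its natural i00..i03 order while a separate set of destination counters (i10..i13) advances like an odometer, so both source and destination may be non-contiguous and may even have different shapes as long as the element counts match. Note that this f16 variant wraps the counters against the source extents ne00..ne03, whereas the f32 variant below wraps them against dst->ne[*]; the two coincide when the shapes match per dimension. A standalone, simplified sketch of the pattern (plain C, not ggml API; names are illustrative):

#include <stddef.h>
#include <stdint.h>

// Copy src (shape sne, byte strides snb) into dst (shape dne, byte strides dnb).
// Assumes the total element counts match: prod(sne) == prod(dne).
static void copy_strided_f32(const char * src, const int64_t sne[4], const size_t snb[4],
                                   char * dst, const int64_t dne[4], const size_t dnb[4]) {
    int64_t i10 = 0, i11 = 0, i12 = 0, i13 = 0; // destination "odometer"
    for (int64_t i03 = 0; i03 < sne[3]; i03++)
    for (int64_t i02 = 0; i02 < sne[2]; i02++)
    for (int64_t i01 = 0; i01 < sne[1]; i01++)
    for (int64_t i00 = 0; i00 < sne[0]; i00++) {
        const float * s = (const float *) (src + i00*snb[0] + i01*snb[1] + i02*snb[2] + i03*snb[3]);
              float * d = (float *)       (dst + i10*dnb[0] + i11*dnb[1] + i12*dnb[2] + i13*dnb[3]);
        *d = *s;
        // advance the destination index by one element, carrying into higher dims
        if (++i10 == dne[0]) { i10 = 0;
            if (++i11 == dne[1]) { i11 = 0;
                if (++i12 == dne[2]) { i12 = 0;
                    if (++i13 == dne[3]) { i13 = 0; }
                }
            }
        }
    }
}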

@@ -4949,7 +4998,6 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -4966,85 +5014,76 @@ static void ggml_compute_forward_dup_f32(
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(float)) {
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
 
-                        memcpy(dst_ptr, src0_ptr, rs);
-
-                        id++;
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                            id++;
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
-
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
-                        }
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        for (int64_t i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                            id++;
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
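With the contiguity assertion dropped and strided destinations handled above, ggml_cpy() can write into a non-contiguous view. A minimal sketch, assuming an initialized ggml_context * ctx and the usual graph build/compute steps (the shapes and the idea of scattering into every other row are made up for illustration):

// Scatter an 8 x 4 F32 matrix into every other row of an 8 x 8 buffer.
struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
struct ggml_tensor * big = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);

struct ggml_tensor * dst_view = ggml_view_2d(ctx, big, 8, 4,
        2*big->nb[1],   // row stride: skip every other row of big
        0);             // byte offset

// Previously the dup kernels asserted ggml_is_contiguous(dst) and would abort on this;
// now the row-by-row memcpy path above handles it.
struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst_view);
// ... add `cpy` to a graph and compute it as usual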

ggml.h (+10 -0)
@@ -558,6 +558,16 @@ struct ggml_tensor * ggml_view_2d(
         size_t                nb1, // row stride in bytes
         size_t                offset);
 
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1, // row   stride in bytes
+        size_t                nb2, // slice stride in bytes
+        size_t                offset);
+
 struct ggml_tensor * ggml_permute(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
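The read side of the change is what this declaration enables: at attention time the cached, transposed V can be addressed per layer and per head directly, without materializing a transposed copy. A hedged sketch in llama.cpp style (the names ctx0, n_past, N, n_embd, n_head, n_ctx, il, kv_self.v and KQ_soft_max are assumptions, not part of this header):

// View the transposed V cache of layer il as (n_past + N) x (n_embd/n_head) x n_head.
struct ggml_tensor * V_trans = ggml_view_3d(ctx0, kv_self.v,
        n_past + N,                                          // ne0: token positions
        n_embd/n_head,                                       // ne1: dims per head
        n_head,                                              // ne2: heads
        n_ctx*ggml_element_size(kv_self.v),                  // nb1: one cache row per dim
        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,    // nb2: one block per head
        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);       // offset: this layer's slab

// The attention matmul can then consume the view directly:
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);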
