Skip to content

Commit ef2d770

Browse files
ggml : fix broken CPU concat implementation for quantized types (#25247)
* ggml : fix broken CPU concat implementation for quantized types * tests : concat tests for quantized types --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
1 parent 2d97363 commit ef2d770

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

ggml/src/ggml-cpu/ops.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,16 +1913,20 @@ static void ggml_compute_forward_concat_any(
19131913
GGML_ASSERT(dim >= 0 && dim < 4);
19141914

19151915
int64_t o[4] = {0, 0, 0, 0};
1916-
o[dim] = src0->ne[dim];
1916+
if (dim == 0) {
1917+
o[dim] = src0->ne[dim]/ggml_blck_size(src0->type);
1918+
} else {
1919+
o[dim] = src0->ne[dim];
1920+
}
19171921

19181922
const char * x;
19191923

19201924
// TODO: smarter multi-theading
19211925
for (int i3 = 0; i3 < ne3; i3++) {
19221926
for (int i2 = ith; i2 < ne2; i2 += nth) {
19231927
for (int i1 = 0; i1 < ne1; i1++) {
1924-
for (int i0 = 0; i0 < ne0; i0++) {
1925-
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1928+
for (int i0 = 0; i0 < ne0/ggml_blck_size(dst->type); i0++) {
1929+
if (i0 < ne00/ggml_blck_size(src0->type) && i1 < ne01 && i2 < ne02 && i3 < ne03) {
19261930
x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
19271931
} else {
19281932
x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
@@ -2071,6 +2075,14 @@ void ggml_compute_forward_concat(
20712075
ggml_tensor * dst) {
20722076

20732077
const ggml_tensor * src0 = dst->src[0];
2078+
const ggml_tensor * src1 = dst->src[1];
2079+
2080+
if (ggml_is_quantized(src0->type)) {
2081+
GGML_ASSERT(ggml_is_contiguous(src0));
2082+
GGML_ASSERT(ggml_is_contiguous(src1));
2083+
GGML_ASSERT(src0->ne[0] % ggml_blck_size(src0->type) == 0);
2084+
GGML_ASSERT(src1->ne[0] % ggml_blck_size(src1->type) == 0);
2085+
}
20742086

20752087
switch (src0->type) {
20762088
case GGML_TYPE_F16:

tests/test-backend-ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8918,6 +8918,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
89188918
}
89198919
}
89208920

8921+
for (ggml_type type_a : { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0 }) {
8922+
for (int dim : { 0, 1, 2, 3, }) {
8923+
test_cases.emplace_back(new test_concat(type_a, {128, 12, 13, 14}, dim == 0 ? 256 : 7, dim, 0));
8924+
}
8925+
}
8926+
89218927
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
89228928
for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
89238929
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));

0 commit comments

Comments
 (0)