Skip to content

Commit b0d5299

Browse files
authored
cuda : fix supports_op condition for get_rows when number of blocks is too large (#15868)
* cuda : fix supports_op condition for get_rows when src1->ne2 > 1 ggml-ci
* ggml : add comment about ggml_get_rows ggml-ci
* cuda : add FIXME [no ci]
* cuda : update support condition ggml-ci
1 parent f28d4f4 commit b0d5299

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

ggml/include/ggml.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,11 @@ extern "C" {
15291529
struct ggml_context * ctx,
15301530
struct ggml_tensor * a);
15311531

1532-
// supports 3D: a->ne[2] == b->ne[1]
1532+
// supports 4D a:
1533+
// a [n_embd, ne1, ne2, ne3]
1534+
// b I32 [n_rows, ne2, ne3, 1]
1535+
//
1536+
// return [n_embd, n_rows, ne2, ne3]
15331537
GGML_API struct ggml_tensor * ggml_get_rows(
15341538
struct ggml_context * ctx,
15351539
struct ggml_tensor * a, // data

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
33923392
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
33933393
case GGML_OP_GET_ROWS:
33943394
{
3395+
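// FIXME: workaround - report get_rows as unsupported when src1 ne[1]*ne[2] > 65535 (number of blocks is too large); see https://github.com/ggml-org/llama.cpp/pull/15868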
3396+
if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
3397+
return false;
3398+
}
33953399
switch (op->src[0]->type) {
33963400
case GGML_TYPE_F16:
33973401
case GGML_TYPE_F32:

ggml/src/ggml.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
36233623
struct ggml_tensor * a,
36243624
struct ggml_tensor * b) {
36253625
GGML_ASSERT(a->ne[2] == b->ne[1]);
3626+
GGML_ASSERT(a->ne[3] == b->ne[2]);
36263627
GGML_ASSERT(b->ne[3] == 1);
36273628
GGML_ASSERT(b->type == GGML_TYPE_I32);
36283629

0 commit comments

Comments
 (0)