Skip to content

Conversation

ggerganov
Copy link
Member

cont #15687

Mark this case as unsupported until actual support is implemented.

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Sep 8, 2025
@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Sep 8, 2025

The value ne12 is used in the CUDA code, but I think the indices are being calculated incorrectly. In the CPU code:

const int64_t i12 = i03%ne12;
const int64_t i11 = i02%ne11;
const int64_t i10 = i;

In the CUDA code:

const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12; // gridDim.z == ne11*ne12
const int i12 = blockIdx.z % ne12;

In the CUDA code the same values are used for i11/i01 and i12/i02.

Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it fixes an immediate issue it would still be fine to merge this for now. But please add a FIXME comment.

@ggerganov
Copy link
Member Author

Ok, I didn't look in the implementation and assumed it was not implemented. So, will update the PR to fix implementation.

In the CUDA code the same values are used for i11/i01 and i12/i02.

The intention of the operator is that i10 queries rows from src0, hence it corresponds to i01. Respectively:

i10 -> i01
i11 -> i02
i12 -> i03

So I think the CPU implementation is correct. Looking into this.

@ggerganov
Copy link
Member Author

The CUDA implementation is correct. The problem is that in one of the new GET_ROWS tests, the number of blocks along the 3rd dimension of the kernel exceeds 65536:

const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);

Here ne11*n12 > 2^16 and it causes the kernel launch to fail.

For now, I updated the support_op condition to bail out in such cases. Will leave it to you to add proper support for larger sizes.

@ggerganov ggerganov changed the title cuda : fix supports_op condition for get_rows when src1->ne2 > 1 cuda : fix supports_op condition for get_rows when number of blocks is too large Sep 8, 2025
@ggerganov ggerganov merged commit b0d5299 into master Sep 8, 2025
51 of 55 checks passed
@ggerganov ggerganov deleted the gg/cuda-fix-supports-get-rows branch September 8, 2025 10:56
njsyw1997 pushed a commit to aizip/llama.cpp that referenced this pull request Sep 10, 2025
…s too large (ggml-org#15868)

* cuda : fix supports_op condition for get_rows when src1->ne2 > 1

ggml-ci

* ggml : add comment about ggml_get_rows

ggml-ci

* cuda : add FIXME [no ci]

* cuda : update support condition

ggml-ci
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants