Skip to content

Commit 2068908

Browse files
committed
Enable GDN also for prefill, move TODO for chunked_GDN
1 parent 1623bbc commit 2068908

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

ggml/src/ggml-cuda/gated_delta_net.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ static void launch_gated_delta_net(
145145
int64_t sb1, int64_t sb2, int64_t sb3,
146146
int64_t neqk1, int64_t rq3,
147147
float scale, cudaStream_t stream) {
148-
148+
//TODO: Add chunked kernel for even faster pre-fill
149149
constexpr uint32_t warp_size = ggml_cuda_get_physical_warp_size();
150150
const int num_warps = 4;
151151
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5001,7 +5001,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
50015001
#else
50025002
// KDA is faster using the AR kernel even when n_tokens >= 512
50035003
//TODO: Add chunked kernel
5004-
return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0];
5004+
return true;
50055005
#endif // GGML_USE_MUSA
50065006
case GGML_OP_FLASH_ATTN_EXT:
50075007
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);

0 commit comments

Comments
 (0)