Skip to content

Commit 5a460de

Browse files
authored
Remove redundant CUDA copies after gated_delta_net. (#23940)
* Remove redundant CUDA copies after gated_delta_net. Currently, GDN writes recurrent state snapshots into its output tail, then the graph immediately copies those snapshots into ssm_states_all. With MTP draft length 3, target decode uses K=4, so that becomes 4 extra ggml_cuda_cpy calls. The change detects that gated_delta_net -> view -> cpy pattern and makes the CUDA GDN kernel write the state snapshot(s) directly into the recurrent cache, skipping the intermediate tail writes and copy kernels when safe. * Address review comments
1 parent c8ae9a7 commit 5a460de

3 files changed

Lines changed: 135 additions & 27 deletions

File tree

ggml/src/ggml-cuda/gated_delta_net.cu

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
1010
const float * beta,
1111
const float * curr_state,
1212
float * dst,
13+
float * state,
1314
int64_t H,
1415
int64_t n_tokens,
1516
int64_t n_seqs,
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
2526
const uint3 neqk1_magic,
2627
const uint3 rq3_magic,
2728
float scale,
29+
int64_t state_slot_stride,
2830
int K) {
2931
const uint32_t h_idx = blockIdx.x;
3032
const uint32_t sequence = blockIdx.y;
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
3537
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
3638
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
3739

38-
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
3940
float * attn_data = dst;
40-
float * state = dst + attn_score_elems;
4141

4242
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
4343
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
145145
if constexpr (keep_rs_t) {
146146
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
147147
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
148-
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
149148
const int target_slot = (int) n_tokens - 1 - t;
150149
if (target_slot >= 0 && target_slot < K) {
151-
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
150+
float * curr_state = state + target_slot * state_slot_stride;
152151
#pragma unroll
153152
for (int r = 0; r < rows_per_lane; r++) {
154153
const int i = r * warp_size + lane;
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
171170
static void launch_gated_delta_net(
172171
const float * q_d, const float * k_d, const float * v_d,
173172
const float * g_d, const float * b_d, const float * s_d,
174-
float * dst_d,
173+
float * dst_d, float * state_d,
175174
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
176175
int64_t sq1, int64_t sq2, int64_t sq3,
177176
int64_t sv1, int64_t sv2, int64_t sv3,
178177
int64_t sb1, int64_t sb2, int64_t sb3,
179178
int64_t neqk1, int64_t rq3,
180-
float scale, int K, cudaStream_t stream) {
179+
float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
181180
//TODO: Add chunked kernel for even faster pre-fill
182181
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
183182
const int num_warps = 4;
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
187186
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
188187
const uint3 rq3_magic = init_fastdiv_values(rq3);
189188

190-
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
191-
192189
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
193190
switch (S_v) {
194191
case 16:
195192
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
196-
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
193+
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
197194
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
198-
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
195+
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
199196
break;
200197
case 32:
201198
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
202-
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
199+
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
203200
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
204-
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
201+
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
205202
break;
206203
case 64: {
207204
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
208-
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
205+
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
209206
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
210-
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
207+
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
211208
break;
212209
}
213210
case 128: {
214211
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
215-
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
212+
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
216213
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
217-
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
214+
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
218215
break;
219216
}
220217
default:
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
223220
}
224221
}
225222

226-
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
223+
static void ggml_cuda_op_gated_delta_net_impl(
224+
ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
227225
ggml_tensor * src_q = dst->src[0];
228226
ggml_tensor * src_k = dst->src[1];
229227
ggml_tensor * src_v = dst->src[2];
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
288286
const int K = ggml_get_op_params_i32(dst, 0);
289287
const bool keep_rs = K > 1;
290288

289+
// recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
290+
float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
291+
int64_t state_slot_stride = S_v * S_v * H * n_seqs;
292+
if (cache != nullptr) {
293+
state_d = cache->data;
294+
state_slot_stride = cache->slot_stride;
295+
}
296+
291297
if (kda) {
292298
if (keep_rs) {
293-
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
299+
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
294300
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
295-
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
301+
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
296302
} else {
297-
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
303+
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
298304
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
299-
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
305+
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
300306
}
301307
} else {
302308
if (keep_rs) {
303-
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
309+
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
304310
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
305-
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
311+
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
306312
} else {
307-
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
313+
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
308314
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
309-
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
315+
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
310316
}
311317
}
312318
}
319+
320+
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
321+
ggml_cuda_op_gated_delta_net_impl(ctx, dst, nullptr);
322+
}
323+
324+
void ggml_cuda_op_gated_delta_net_fused_cache(
325+
ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
326+
ggml_cuda_op_gated_delta_net_impl(ctx, dst, &cache);
327+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
#include "common.cuh"
22
#include "ggml.h"
33

4+
// fused-kernel recurrent-state output; strides in elements (per-seq stride is always D, set in-kernel)
5+
struct ggml_cuda_gated_delta_net_fused_cache {
6+
float * data; // rollback slot 0
7+
int64_t slot_stride; // between rollback slots (0 when K==1)
8+
};
9+
410
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
11+
12+
// same op, but writes the snapshot(s) into the cache instead of dst (see ggml_cuda_try_gdn_cache_fusion)
13+
void ggml_cuda_op_gated_delta_net_fused_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
14+
ggml_cuda_gated_delta_net_fused_cache cache);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,6 +3251,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
32513251
GGML_UNUSED(backend);
32523252
}
32533253

3254+
static bool ggml_cuda_is_view_or_noop(const ggml_tensor * t) {
3255+
return ggml_is_empty(t) || t->op == GGML_OP_RESHAPE || t->op == GGML_OP_TRANSPOSE ||
3256+
t->op == GGML_OP_VIEW || t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
3257+
}
3258+
32543259
#ifdef USE_CUDA_GRAPH
32553260
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
32563261

@@ -3260,7 +3265,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
32603265
for (int i = 0; i < cgraph->n_nodes; i++) {
32613266
ggml_tensor * node = cgraph->nodes[i];
32623267

3263-
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
3268+
if (ggml_cuda_is_view_or_noop(node)) {
32643269
continue;
32653270
}
32663271

@@ -3403,6 +3408,70 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
34033408
return true;
34043409
}
34053410

3411+
// match gated_delta_net + the strided cpy that scatters its state snapshots into the cache
3412+
// (slot i -> rollback group i, slot 0 newest), so the kernel can write them and skip the cpy.
3413+
static int ggml_cuda_try_gdn_cache_fusion(
3414+
const ggml_cgraph * cgraph, int node_idx, ggml_cuda_gated_delta_net_fused_cache & fused_state_cpy) {
3415+
const ggml_tensor * gdn = cgraph->nodes[node_idx];
3416+
// the kernel skips the snapshot tail, so the gdn output must not be a graph output
3417+
if (gdn->op != GGML_OP_GATED_DELTA_NET || gdn->type != GGML_TYPE_F32 ||
3418+
(gdn->flags & GGML_TENSOR_FLAG_OUTPUT)) {
3419+
return 0;
3420+
}
3421+
3422+
const ggml_tensor * src_v = gdn->src[2];
3423+
const int64_t S_v = src_v->ne[0];
3424+
const int64_t H = src_v->ne[1];
3425+
const int64_t n_tokens = src_v->ne[2];
3426+
const int64_t n_seqs = src_v->ne[3];
3427+
const int64_t D = S_v * S_v * H;
3428+
const int64_t K = ggml_get_op_params_i32(gdn, 0); // snapshot slot count
3429+
const int64_t n_written = std::min<int64_t>(n_tokens, K); // newest n_written slots are written
3430+
3431+
// snapshot tail starts right after the attention scores
3432+
const size_t tail_off = ggml_row_size(GGML_TYPE_F32, S_v * H * n_tokens * n_seqs);
3433+
3434+
// snapshot cpy is the first real node after the gdn (skip views/no-ops)
3435+
const ggml_tensor * cpy = nullptr;
3436+
int skip = 0;
3437+
for (int j = node_idx + 1; j < cgraph->n_nodes && cpy == nullptr; ++j) {
3438+
const ggml_tensor * n = cgraph->nodes[j];
3439+
if (ggml_cuda_is_view_or_noop(n)) {
3440+
continue;
3441+
}
3442+
if (n->op != GGML_OP_CPY || (n->flags & GGML_TENSOR_FLAG_OUTPUT)) {
3443+
return 0;
3444+
}
3445+
cpy = n;
3446+
skip = j - node_idx;
3447+
}
3448+
if (cpy == nullptr) {
3449+
return 0;
3450+
}
3451+
3452+
const ggml_tensor * src = cpy->src[0]; // view of the gdn snapshot tail
3453+
const ggml_tensor * dst = cpy->src[1]; // cache view the kernel writes to
3454+
3455+
// src must be this gdn's snapshot tail (contiguous, at the tail offset)
3456+
if (src->op != GGML_OP_VIEW || src->view_src != gdn || src->view_offs != tail_off ||
3457+
!ggml_is_contiguous(src)) {
3458+
return 0;
3459+
}
3460+
3461+
// dst is the [D, n_seqs, n_written] cache view; require nb[1] == D (the per-seq stride the kernel
3462+
// assumes). ggml_cpy pins src to the same element count.
3463+
const std::array<int64_t, GGML_MAX_DIMS> expected_ne = { D, n_seqs, n_written, 1 };
3464+
if (dst->op != GGML_OP_VIEW || dst->type != GGML_TYPE_F32 || dst->data == nullptr ||
3465+
!std::equal(expected_ne.begin(), expected_ne.end(), dst->ne) ||
3466+
dst->nb[0] != ggml_type_size(GGML_TYPE_F32) || dst->nb[1] != (size_t) ggml_row_size(GGML_TYPE_F32, D)) {
3467+
return 0;
3468+
}
3469+
3470+
fused_state_cpy.data = (float *) dst->data; // rollback group 0 (newest)
3471+
fused_state_cpy.slot_stride = K > 1 ? (int64_t) (dst->nb[2] / sizeof(float)) : 0;
3472+
return skip;
3473+
}
3474+
34063475
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
34073476
args.sigmoid = false;
34083477
args.softmax = false;
@@ -3844,6 +3913,20 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
38443913

38453914
ggml_tensor * node = cgraph->nodes[i];
38463915

3916+
// gated_delta_net -> cpy: scatter recurrent-state snapshots into the cache
3917+
if (node->op == GGML_OP_GATED_DELTA_NET) {
3918+
ggml_cuda_gated_delta_net_fused_cache fused_state_cpy;
3919+
const int nodes_to_skip = ggml_cuda_try_gdn_cache_fusion(cgraph, i, fused_state_cpy);
3920+
if (nodes_to_skip > 0) {
3921+
#ifdef GGML_CUDA_DEBUG
3922+
GGML_LOG_INFO("%s: fused gated_delta_net snapshot copies for %s (skipped %d nodes)\n",
3923+
__func__, node->name, nodes_to_skip);
3924+
#endif
3925+
ggml_cuda_op_gated_delta_net_fused_cache(*cuda_ctx, node, fused_state_cpy);
3926+
return nodes_to_skip;
3927+
}
3928+
}
3929+
38473930
//topk-moe
38483931
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
38493932
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
@@ -4372,7 +4455,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
43724455
#endif
43734456
prev_i = i;
43744457

4375-
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
4458+
if (ggml_cuda_is_view_or_noop(node)) {
43764459
continue;
43774460
}
43784461

0 commit comments

Comments
 (0)