@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
1010 const float * beta,
1111 const float * curr_state,
1212 float * dst,
13+ float * state,
1314 int64_t H,
1415 int64_t n_tokens,
1516 int64_t n_seqs,
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
2526 const uint3 neqk1_magic,
2627 const uint3 rq3_magic,
2728 float scale,
29+ int64_t state_slot_stride,
2830 int K) {
2931 const uint32_t h_idx = blockIdx .x ;
3032 const uint32_t sequence = blockIdx .y ;
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
3537 const uint32_t iq1 = fastmodulo (h_idx, neqk1_magic);
3638 const uint32_t iq3 = fastdiv (sequence, rq3_magic);
3739
38- const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
3940 float * attn_data = dst;
40- float * state = dst + attn_score_elems;
4141
4242 // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
4343 // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
145145 if constexpr (keep_rs_t ) {
146146 // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
147147 // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
148- const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
149148 const int target_slot = (int ) n_tokens - 1 - t;
150149 if (target_slot >= 0 && target_slot < K) {
151- float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset ;
150+ float * curr_state = state + target_slot * state_slot_stride ;
152151#pragma unroll
153152 for (int r = 0 ; r < rows_per_lane; r++) {
154153 const int i = r * warp_size + lane;
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
171170static void launch_gated_delta_net (
172171 const float * q_d, const float * k_d, const float * v_d,
173172 const float * g_d, const float * b_d, const float * s_d,
174- float * dst_d,
173+ float * dst_d, float * state_d,
175174 int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
176175 int64_t sq1, int64_t sq2, int64_t sq3,
177176 int64_t sv1, int64_t sv2, int64_t sv3,
178177 int64_t sb1, int64_t sb2, int64_t sb3,
179178 int64_t neqk1, int64_t rq3,
180- float scale, int K, cudaStream_t stream) {
179+ float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
181180 // TODO: Add chunked kernel for even faster pre-fill
182181 const int warp_size = ggml_cuda_info ().devices [ggml_cuda_get_device ()].warp_size ;
183182 const int num_warps = 4 ;
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
187186 const uint3 neqk1_magic = init_fastdiv_values (neqk1);
188187 const uint3 rq3_magic = init_fastdiv_values (rq3);
189188
190- int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
191-
192189 const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params (grid_dims, block_dims, 0 , stream);
193190 switch (S_v) {
194191 case 16 :
195192 ggml_cuda_kernel_launch (gated_delta_net_cuda<16 , KDA , keep_rs_t >, launch_params,
196- q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
193+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
197194 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
198- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
195+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
199196 break ;
200197 case 32 :
201198 ggml_cuda_kernel_launch (gated_delta_net_cuda<32 , KDA , keep_rs_t >, launch_params,
202- q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
199+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
203200 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
204- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
201+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
205202 break ;
206203 case 64 : {
207204 ggml_cuda_kernel_launch (gated_delta_net_cuda<64 , KDA , keep_rs_t >, launch_params,
208- q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
205+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
209206 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
210- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
207+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
211208 break ;
212209 }
213210 case 128 : {
214211 ggml_cuda_kernel_launch (gated_delta_net_cuda<128 , KDA , keep_rs_t >, launch_params,
215- q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
212+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
216213 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
217- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
214+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
218215 break ;
219216 }
220217 default :
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
223220 }
224221}
225222
226- void ggml_cuda_op_gated_delta_net (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
223+ static void ggml_cuda_op_gated_delta_net_impl (
224+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
227225 ggml_tensor * src_q = dst->src [0 ];
228226 ggml_tensor * src_k = dst->src [1 ];
229227 ggml_tensor * src_v = dst->src [2 ];
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
288286 const int K = ggml_get_op_params_i32 (dst, 0 );
289287 const bool keep_rs = K > 1 ;
290288
289+ // recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
290+ float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
291+ int64_t state_slot_stride = S_v * S_v * H * n_seqs;
292+ if (cache != nullptr ) {
293+ state_d = cache->data ;
294+ state_slot_stride = cache->slot_stride ;
295+ }
296+
291297 if (kda) {
292298 if (keep_rs) {
293- launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
299+ launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
294300 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
295- sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
301+ sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
296302 } else {
297- launch_gated_delta_net<true , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
303+ launch_gated_delta_net<true , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
298304 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
299- sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
305+ sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
300306 }
301307 } else {
302308 if (keep_rs) {
303- launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
309+ launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
304310 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
305- sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
311+ sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
306312 } else {
307- launch_gated_delta_net<false , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
313+ launch_gated_delta_net<false , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
308314 S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
309- sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
315+ sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
310316 }
311317 }
312318}
319+
320+ void ggml_cuda_op_gated_delta_net (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
321+ ggml_cuda_op_gated_delta_net_impl (ctx, dst, nullptr );
322+ }
323+
324+ void ggml_cuda_op_gated_delta_net_fused_cache (
325+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
326+ ggml_cuda_op_gated_delta_net_impl (ctx, dst, &cache);
327+ }
0 commit comments