forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: append_attention_c8_impl.cuh
More file actions
1630 lines (1558 loc) · 66.6 KB
/
append_attention_c8_impl.cuh
File metadata and controls
1630 lines (1558 loc) · 66.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "append_attention_func.cuh"
#include "append_attention_kernel.h"
// Append attention over an 8-bit (int8 / FP8) paged KV cache in which
// NUM_WARPS warps split the query rows of one tile.
//
// Grid / block layout, as read by this kernel:
//   blockIdx.x  : index into batch_ids / tile_ids_per_batch (one query tile);
//                 blocks with blockIdx.x >= num_blocks_x_cpu exit immediately.
//   blockIdx.y  : KV chunk index (gridDim.y == num_chunks).
//   blockIdx.z  : KV head index (gridDim.z == kv_num_heads); the block serves
//                 the GROUP_SIZE query heads mapped onto that KV head.
//   threadIdx.x : lane inside a warp; threadIdx.y : warp id.
//
// Dynamic shared memory: a Q tile of NUM_WARPS * num_frags_x * 16 rows of
// HEAD_DIM T values, followed by a K tile and then a V tile of
// num_frags_z * 16 * HEAD_DIM CacheT values each.
//
// Output:
//   partition_kv == false : o is normalized by d and written to `out`.
//   partition_kv == true  : the un-normalized partial o is written to
//                           tmp_workspace and the per-row m / d state to
//                           tmp_m / tmp_d, to be merged across chunks later.
template <typename T,
          typename CacheT,
          bool partition_kv,
          uint32_t GROUP_SIZE,
          bool CAUSAL,
          uint32_t NUM_WARPS,
          uint32_t NUM_WARP_Q,
          uint32_t NUM_WARP_KV,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t num_frags_x,
          uint32_t num_frags_z,
          uint32_t num_frags_y,
          typename OutT = T,
          bool ENABLE_PREFILL = true,
          bool is_scale_channel_wise = false,
          bool IsFP8 = false>
__global__ void multi_query_append_attention_c8_kernel(
    T *__restrict__ q,  // [token_num, (num_heads + 2* kv_num_head) * head_dim]
    CacheT *__restrict__ cache_k,  // [max_block_num, num_heads, block_size,
                                   // head_dim]
    CacheT *__restrict__ cache_v,
    const T *__restrict__ cache_k_scale,  // [num_kv_heads], or
                                          // [num_kv_heads * head_dim] when
                                          // is_scale_channel_wise
    const T *__restrict__ cache_v_scale,  // same layout as cache_k_scale
    const T *__restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
    const T *__restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
    const int *__restrict__ seq_lens,
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int *__restrict__ mask_offset,
    const int max_seq_len,
    const int max_dec_len,
    const int max_block_num_per_seq,
    const float scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const uint32_t chunk_size,
    const int num_blocks_x_cpu,
    T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                    // num_heads, head_dim]
    float *__restrict__ tmp_m,      // [token_num, num_chunks, num_heads]
    float *__restrict__ tmp_d,      // [token_num, num_chunks, num_heads]
    OutT *__restrict__ out,
    const int speculate_max_draft_token_num = 5) {
  constexpr uint32_t num_vecs_per_head =
      HEAD_DIM / num_elems_per_128b<T>();  // e.g. 128 / 8 = 16
  constexpr uint32_t num_vecs_per_head_k =
      HEAD_DIM / num_elems_per_128b<CacheT>();  // e.g. 128 / 16 = 8
  constexpr uint32_t num_vecs_per_blocksize =
      BLOCK_SIZE / num_elems_per_128b<CacheT>();  // e.g. 64 / 16 = 4
  constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k;
  constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize;

  const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
  // When a CUDA graph is captured for prefill, gridDim.x may exceed the
  // number of valid work items. Exit before reading batch_ids /
  // tile_ids_per_batch so surplus blocks never index past the valid entries.
  // btid is uniform across the block, so no thread is left at a barrier.
  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
    return;
  }
  const uint32_t kv_num_heads = gridDim.z;
  const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
  const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE;
  const uint32_t tid = threadIdx.x, wid = threadIdx.y;
  const uint32_t num_chunks = gridDim.y;
  const uint32_t chunk_idx = blockIdx.y;
  const uint32_t batch_id = batch_ids[btid];
  const uint32_t tile_id = tile_ids_per_batch[btid];
  const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
  const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

  // Check the length while it is still signed, so a non-positive entry is
  // rejected instead of wrapping to a huge unsigned value.
  const int q_len_raw = seq_lens[batch_id];
  if (q_len_raw <= 0) {
    return;
  }
  const uint32_t q_len = static_cast<uint32_t>(q_len_raw);

  // Dequantization scales kept in registers. The channel-wise variant loads
  // the per-channel values this lane needs; otherwise only element [0] holds
  // the per-KV-head scale.
  T cache_k_scale_reg[num_frags_y * 4];
  T cache_v_scale_reg[num_frags_y * 2];
  if constexpr (is_scale_channel_wise) {
    // K: columns (lane % 4) * 2 + {0, 1, 8, 9} in each 16-wide slice.
    const T *cache_k_scale_cur_head =
        cache_k_scale + (threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM);
    for (uint32_t i = 0; i < num_frags_y; ++i) {
      const uint32_t scale_idx = i * 16;
      cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
      cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
      cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
      cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
    }
    // V: columns lane / 4 + {0, 8} in each 16-wide slice.
    const T *cache_v_scale_cur_head =
        cache_v_scale + (threadIdx.x / 4 + kv_head_idx * HEAD_DIM);
    for (uint32_t i = 0; i < num_frags_y; ++i) {
      const uint32_t scale_idx = i * 16;
      cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
      cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
    }
  } else {
    cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
    cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
  }

  const uint32_t q_end =
      min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

  // The attended length is cached tokens plus the tokens being appended.
  // In the non-prefill path an empty cache means there is nothing to do.
  const int kv_len_cached = seq_lens_kv[batch_id];
  if constexpr (ENABLE_PREFILL) {
    if (kv_len_cached + q_len_raw <= 0) {
      return;
    }
  } else {
    if (kv_len_cached <= 0) {
      return;
    }
  }
  uint32_t kv_len = static_cast<uint32_t>(kv_len_cached) + q_len;

  const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size);
  if (chunk_idx >= num_chunks_this_seq) {
    return;
  }
  const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0;
  const uint32_t chunk_end =
      partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
  const uint32_t chunk_len = chunk_end - chunk_start;

  extern __shared__ uint8_t smem[];
  float s_frag[num_frags_x][num_frags_z][8];
  float o_frag[num_frags_x][num_frags_y][8];
  float m_frag[num_frags_x][2];
  float d_frag[num_frags_x][2];
  init_states<T, num_frags_x, num_frags_y>(o_frag, m_frag, d_frag);

  const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
  // The q input packs q, k and v heads per token, so its row stride is wider.
  const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
  const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
  const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
  const uint32_t kv_b_stride = HEAD_DIM;
  const uint32_t kv_d_stride = BLOCK_SIZE;
  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
  const uint32_t q_base_seq_id_this_block =
      (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
  // Each lane addresses one 128-bit vector within a head row (tid % 8).
  const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                            q_head_idx * HEAD_DIM +
                            tid % 8 * num_elems_per_128b<T>();
  const uint32_t o_offset = q_start_seq_id * q_n_stride +
                            q_head_idx * HEAD_DIM +
                            tid % 8 * num_elems_per_128b<T>();
  T *q_base_ptr = q + q_offset;
  T *o_base_ptr_T = nullptr;         // split-KV scratch destination
  OutT *o_base_ptr_out = nullptr;    // final output destination
  if constexpr (partition_kv) {
    if constexpr (ENABLE_PREFILL) {
      o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                     tid % 8 * num_elems_per_128b<T>();
    } else {
      o_base_ptr_T =
          tmp_workspace +
          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
          tid % 8 * num_elems_per_128b<T>();
    }
  } else {
    o_base_ptr_out = out + o_offset;
  }
  // mask_offset is indexed with two ints per query token (q_start_seq_id * 2).
  const int *mask_offset_this_seq =
      mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;

  // Stage Q into shared memory and pre-multiply it by the softmax scale.
  smem_t qo_smem(smem);
  uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
      wid * num_frags_x * 16 + tid % 16, tid / 16);  // 16 * 16
  load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
      q_base_ptr,
      &qo_smem,
      q_base_seq_id_this_block,
      q_end,
      q_ori_n_stride,
      HEAD_DIM);
  commit_group();
  wait_group<0>();
  __syncthreads();
  q_smem_inplace_multiply_sm_scale<num_frags_x, num_frags_y, T>(&qo_smem,
                                                                 scale);

  smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
      v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
             num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));

  // With CAUSAL, KV positions beyond the last query row of this tile are never
  // visible, so the loop stops early.
  const uint32_t num_iterations = div_up(
      CAUSAL
          ? (min(chunk_len,
                 sub_if_greater_or_zero(
                     kv_len - q_len +
                         div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
                     chunk_start)))
          : chunk_len,
      num_frags_z * 16);
  // Iterations before this index need no masking; later ones call mask_s.
  const uint32_t mask_check_iteration =
      (CAUSAL ? (min(chunk_len,
                     sub_if_greater_or_zero(
                         kv_len - q_len +
                             tile_id * num_rows_per_block / GROUP_SIZE,
                         chunk_start)))
              : mask_offset ? 0 : chunk_len) /
      (num_frags_z * 16);

  uint32_t k_smem_offset_r =
      smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
          8 * (tid / 16) + tid % 8, (tid % 16) / 8);
  uint32_t v_smem_offset_r =
      smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
          8 * (tid / 16) + tid % 8, (tid % 16) / 8);
  uint32_t k_smem_offset_w =
      smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
          wid * 4 + tid / 8, tid % 8);
  uint32_t v_smem_offset_w =
      smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
          wid * 8 + tid / 4, tid % 4);  // 4 * 128 / 8 = 64

  uint32_t kv_idx_base = chunk_start;
  const uint32_t const_k_offset = kv_head_idx * kv_h_stride +
                                  (wid * 4 + tid / 8) * kv_b_stride +
                                  tid % 8 * num_elems_per_128b<CacheT>();
  const uint32_t const_v_offset = kv_head_idx * kv_h_stride +
                                  (wid * 8 + tid / 4) * kv_d_stride +
                                  tid % 4 * num_elems_per_128b<CacheT>();

  // Prologue: issue the async loads of the first K and V tiles.
  produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
                         NUM_WARPS,
                         BLOCK_SIZE,
                         num_frags_y,
                         num_frags_z,
                         NUM_WARP_Q>(k_smem,
                                     &k_smem_offset_w,
                                     cache_k,
                                     block_table_now,
                                     kv_head_idx,
                                     kv_n_stride,
                                     kv_h_stride,
                                     kv_b_stride,
                                     kv_idx_base,
                                     chunk_end,
                                     const_k_offset);
  commit_group();
  produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                         NUM_WARPS,
                         BLOCK_SIZE,
                         num_frags_y,
                         num_frags_z,
                         NUM_WARP_Q>(v_smem,
                                     &v_smem_offset_w,
                                     cache_v,
                                     block_table_now,
                                     kv_head_idx,
                                     kv_n_stride,
                                     kv_h_stride,
                                     kv_d_stride,
                                     kv_idx_base,
                                     chunk_end,
                                     const_v_offset);
  commit_group();

#pragma unroll 1
  for (uint32_t iter = 0; iter < num_iterations; ++iter) {
    // Wait for the current K tile (the V group may still be in flight).
    wait_group<1>();
    __syncthreads();
    // s = q * k^T
    compute_qk_c8<num_frags_x,
                  num_frags_y,
                  num_frags_z,
                  T,
                  CacheT,
                  is_scale_channel_wise,
                  IsFP8>(&qo_smem,
                         &q_smem_offset_r,
                         &k_smem,
                         &k_smem_offset_r,
                         cache_k_scale_reg,
                         s_frag);
    // Mask according to kv_idx and q_idx.
    if (iter >= mask_check_iteration) {
      mask_s<T,
             partition_kv,
             CAUSAL,
             GROUP_SIZE,
             NUM_WARPS,
             num_frags_x,
             num_frags_y,
             num_frags_z>(nullptr,
                          q_base_seq_id_this_block,
                          kv_idx_base,
                          q_len,
                          kv_len,
                          chunk_end,
                          -1,
                          s_frag,
                          mask_offset_this_seq);
    }
    // Update the running m / d / o state with this tile's scores.
    update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
        s_frag, o_frag, m_frag, d_frag);
    __syncthreads();

    // Prefetch the next K tile into the now-free K buffer.
    kv_idx_base += num_frags_z * 16;
    produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
                           NUM_WARPS,
                           BLOCK_SIZE,
                           num_frags_y,
                           num_frags_z,
                           NUM_WARP_Q>(k_smem,
                                       &k_smem_offset_w,
                                       cache_k,
                                       block_table_now,
                                       kv_head_idx,
                                       kv_n_stride,
                                       kv_h_stride,
                                       kv_b_stride,
                                       kv_idx_base,
                                       chunk_end,
                                       const_k_offset);
    commit_group();
    // Wait for the current V tile.
    wait_group<1>();
    __syncthreads();
    // o += softmax(s) * v
    compute_sfm_v_c8<num_frags_x,
                     num_frags_y,
                     num_frags_z,
                     BLOCK_SIZE,
                     T,
                     CacheT,
                     is_scale_channel_wise,
                     IsFP8>(
        &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
    __syncthreads();

    // Prefetch the next V tile into the now-free V buffer.
    produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                           NUM_WARPS,
                           BLOCK_SIZE,
                           num_frags_y,
                           num_frags_z,
                           NUM_WARP_Q>(v_smem,
                                       &v_smem_offset_w,
                                       cache_v,
                                       block_table_now,
                                       kv_head_idx,
                                       kv_n_stride,
                                       kv_h_stride,
                                       kv_d_stride,
                                       kv_idx_base,
                                       chunk_end,
                                       const_v_offset);
    commit_group();
  }
  wait_group<0>();
  __syncthreads();

  // Write o. Fragment layout: [num_frags_x, 16, num_frags_y, 16].
  if constexpr (partition_kv) {
    // Partial (un-normalized) result plus m / d for the later chunk merge.
    write_o_reg_gmem_shift_smooth_quant<GROUP_SIZE,
                                        num_frags_x,
                                        num_frags_y,
                                        partition_kv>(
        o_frag,
        &qo_smem,
        o_base_ptr_T,
        shift_bias,
        smooth_weight,
        q_base_seq_id_this_block,
        q_head_idx,
        quant_max_bound,
        quant_min_bound,
        in_scale,
        q_len,
        q_n_stride * num_chunks,
        HEAD_DIM);
#pragma unroll
    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
      for (uint32_t j = 0; j < 2; ++j) {
        const uint32_t qo_idx_now =
            q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16;
        const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
        const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
        if (qo_idx - q_start_seq_id < q_len) {
          uint32_t offset;
          if constexpr (ENABLE_PREFILL) {
            offset =
                (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
          } else {
            offset = ((batch_id * speculate_max_draft_token_num +
                       qo_idx_now / GROUP_SIZE) *
                          num_chunks +
                      chunk_idx) *
                         q_num_heads +
                     qo_head_idx;
          }
          tmp_m[offset] = m_frag[fx][j];
          tmp_d[offset] = d_frag[fx][j];
        }
      }
    }
  } else {
    // Single pass over the whole KV range: normalize and write the result.
    normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
    write_o_reg_gmem_shift_smooth_quant<GROUP_SIZE,
                                        num_frags_x,
                                        num_frags_y,
                                        partition_kv>(
        o_frag,
        &qo_smem,
        o_base_ptr_out,
        shift_bias,
        smooth_weight,
        q_base_seq_id_this_block,
        q_head_idx,
        quant_max_bound,
        quant_min_bound,
        in_scale,
        q_len,
        q_n_stride,
        HEAD_DIM);
  }
}
// Append attention over an 8-bit (int8 / FP8) paged KV cache in which one warp
// spans the query tile (NUM_WARP_Q == 1) and four warps split each KV step
// (NUM_WARP_KV == 4). The per-warp partial states are combined through shared
// memory by merge_block_res_v2 before write-back.
//
// Grid / block layout, as read by this kernel:
//   blockIdx.x  : index into batch_ids / tile_ids_per_batch (one query tile);
//                 blocks with blockIdx.x >= num_blocks_x_cpu exit immediately.
//   blockIdx.y  : KV chunk index (gridDim.y == num_chunks).
//   blockIdx.z  : KV head index (gridDim.z == kv_num_heads).
//   threadIdx.x : lane inside a warp; threadIdx.y : warp id (KV slice).
//
// Dynamic shared memory: a Q tile of num_frags_x * 16 rows of HEAD_DIM T
// values, followed by K and V tiles of NUM_WARP_KV * num_frags_z * 16 rows of
// HEAD_DIM CacheT values each.
//
// Output is chosen at run time: if the sequence fits in one chunk, o is
// normalized and written to `out`. Otherwise the partial o goes to
// tmp_workspace, and warp 0 writes m / d to tmp_m / tmp_d.
template <typename T,
          typename CacheT,
          bool partition_kv,
          uint32_t GROUP_SIZE,
          bool CAUSAL,
          uint32_t NUM_WARPS,
          uint32_t NUM_WARP_Q,
          uint32_t NUM_WARP_KV,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t num_frags_x,
          uint32_t num_frags_z,
          uint32_t num_frags_y,
          typename OutT = T,
          bool ENABLE_PREFILL = true,
          bool is_scale_channel_wise = false,
          bool IsFP8 = false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
    T *__restrict__ q,  // [token_num, (num_heads + 2* kv_num_head) * head_dim]
    CacheT *__restrict__ cache_k,  // [max_block_num, num_heads, block_size,
                                   // head_dim]
    CacheT *__restrict__ cache_v,
    const T *__restrict__ cache_k_scale,  // [num_kv_heads], or
                                          // [num_kv_heads * head_dim] when
                                          // is_scale_channel_wise
    const T *__restrict__ cache_v_scale,  // same layout as cache_k_scale
    const T *__restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
    const T *__restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
    const int *__restrict__ seq_lens,
    const int *__restrict__ seq_lens_kv,
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids_per_batch,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_table,  // [bsz, block_num_per_seq]
    const int *__restrict__ mask_offset,
    const bool *__restrict__ attn_mask,  // [bsz, max_q, max_q] for tree-mask
    const int max_seq_len,
    const int max_dec_len,
    const int max_block_num_per_seq,
    const float scale,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const uint32_t chunk_size,
    const int num_blocks_x_cpu,
    T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                    // num_heads, head_dim]
    float *__restrict__ tmp_m,      // [token_num, num_chunks, num_heads]
    float *__restrict__ tmp_d,      // [token_num, num_chunks, num_heads]
    OutT *__restrict__ out,
    const int speculate_max_draft_token_num = 5,
    const uint32_t attn_mask_len = -1) {
  constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
  constexpr uint32_t num_vecs_per_head_k =
      HEAD_DIM / num_elems_per_128b<CacheT>();
  constexpr uint32_t num_vecs_per_blocksize =
      BLOCK_SIZE / num_elems_per_128b<CacheT>();
  constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k;
  constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize;
  static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
  static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");

  const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
  // When a CUDA graph is captured for prefill, gridDim.x may exceed the
  // number of valid work items. Exit before reading batch_ids /
  // tile_ids_per_batch so surplus blocks never index past the valid entries.
  // btid is uniform across the block, so no thread is left at a barrier.
  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
    return;
  }
  const uint32_t kv_num_heads = gridDim.z;
  const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
  const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE;
  const uint32_t tid = threadIdx.x, wid = threadIdx.y;
  const uint32_t num_chunks = gridDim.y;
  const uint32_t chunk_idx = blockIdx.y;
  const uint32_t batch_id = batch_ids[btid];
  const uint32_t tile_id = tile_ids_per_batch[btid];
  const uint32_t num_rows_per_block = num_frags_x * 16;
  const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

  // Check the length while it is still signed, so a non-positive entry is
  // rejected instead of wrapping to a huge unsigned value.
  const int q_len_raw = seq_lens[batch_id];
  if (q_len_raw <= 0) {
    return;
  }
  const uint32_t q_len = static_cast<uint32_t>(q_len_raw);

  // Dequantization scales kept in registers. The channel-wise variant loads
  // the per-channel values this lane needs; otherwise only element [0] holds
  // the per-KV-head scale.
  T cache_k_scale_reg[num_frags_y * 4];
  T cache_v_scale_reg[num_frags_y * 2];
  if constexpr (is_scale_channel_wise) {
    // K: columns (lane % 4) * 2 + {0, 1, 8, 9} in each 16-wide slice.
    const T *cache_k_scale_cur_head =
        cache_k_scale + (threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM);
    for (uint32_t i = 0; i < num_frags_y; ++i) {
      const uint32_t scale_idx = i * 16;
      cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
      cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
      cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
      cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
    }
    // V: columns lane / 4 + {0, 8} in each 16-wide slice.
    const T *cache_v_scale_cur_head =
        cache_v_scale + (threadIdx.x / 4 + kv_head_idx * HEAD_DIM);
    for (uint32_t i = 0; i < num_frags_y; ++i) {
      const uint32_t scale_idx = i * 16;
      cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
      cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
    }
  } else {
    cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
    cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
  }

  const uint32_t q_end =
      min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

  // The attended length is cached tokens plus the tokens being appended.
  // In the non-prefill path an empty cache means there is nothing to do.
  const int kv_len_cached = seq_lens_kv[batch_id];
  if constexpr (ENABLE_PREFILL) {
    if (kv_len_cached + q_len_raw <= 0) {
      return;
    }
  } else {
    if (kv_len_cached <= 0) {
      return;
    }
  }
  uint32_t kv_len = static_cast<uint32_t>(kv_len_cached) + q_len;

  const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size);
  if (chunk_idx >= num_chunks_this_seq) {
    return;
  }
  const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0;
  const uint32_t chunk_end =
      partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
  const uint32_t chunk_len = chunk_end - chunk_start;

  extern __shared__ uint8_t smem[];
  float s_frag[num_frags_x][num_frags_z][8];
  float o_frag[num_frags_x][num_frags_y][8];
  float m_frag[num_frags_x][2];
  float d_frag[num_frags_x][2];
  init_states<T, num_frags_x, num_frags_y>(o_frag, m_frag, d_frag);

  const uint32_t q_n_stride = q_num_heads * HEAD_DIM;
  // The q input packs q, k and v heads per token, so its row stride is wider.
  const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM;
  const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
  const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
  const uint32_t kv_b_stride = HEAD_DIM;
  const uint32_t kv_d_stride = BLOCK_SIZE;
  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
  const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
  // Each lane addresses one 128-bit vector within a head row (tid % 8).
  const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                            q_head_idx * HEAD_DIM +
                            tid % 8 * num_elems_per_128b<T>();
  const uint32_t o_offset = q_start_seq_id * q_n_stride +
                            q_head_idx * HEAD_DIM +
                            tid % 8 * num_elems_per_128b<T>();
  T *q_base_ptr = q + q_offset;
  T *o_base_ptr_T = nullptr;       // split-KV scratch destination
  OutT *o_base_ptr_out = nullptr;  // final output destination
  if (num_chunks_this_seq <= 1) {
    o_base_ptr_out = out + o_offset;
  } else {
    if constexpr (ENABLE_PREFILL) {
      // NOTE(review): this scratch slot (and the tmp_m / tmp_d offset below)
      // is indexed by batch_id alone, not by query token. That is only
      // collision-free when each batch has q_len == 1 whenever
      // num_chunks_this_seq > 1. Presumably this matches the merge kernel
      // that consumes it — confirm before using this path with longer queries.
      o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
                     tid % 8 * num_elems_per_128b<T>();
    } else {
      o_base_ptr_T =
          tmp_workspace +
          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
          tid % 8 * num_elems_per_128b<T>();
    }
  }
  // mask_offset is indexed with two ints per query token (q_start_seq_id * 2).
  const int *mask_offset_this_seq =
      mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;

  // Stage Q into shared memory (shared by all warps) and pre-scale it.
  smem_t qo_smem(smem);
  uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
      tid % 16, tid / 16);  // 16 * 16
  load_q_global_smem_multi_warps<GROUP_SIZE,
                                 num_frags_x,
                                 num_frags_y,
                                 HEAD_DIM,
                                 T>(q_base_ptr,
                                    &qo_smem,
                                    q_base_seq_id_this_block,
                                    q_end,
                                    q_ori_n_stride,
                                    HEAD_DIM);
  commit_group();
  wait_group<0>();
  __syncthreads();
  q_smem_inplace_multiply_sm_scale_multi_warps<num_frags_x, num_frags_y, T>(
      &qo_smem, scale);

  smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
      v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
             NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));

  // Each iteration consumes NUM_WARP_KV * num_frags_z * 16 KV positions.
  const uint32_t num_iterations = div_up(
      CAUSAL
          ? (min(chunk_len,
                 sub_if_greater_or_zero(
                     kv_len - q_len +
                         div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
                     chunk_start)))
          : chunk_len,
      NUM_WARP_KV * num_frags_z * 16);
  // Iterations before this index need no masking; later ones call mask_s.
  const uint32_t mask_check_iteration =
      (CAUSAL ? (min(chunk_len,
                     sub_if_greater_or_zero(
                         kv_len - q_len +
                             tile_id * num_rows_per_block / GROUP_SIZE,
                         chunk_start)))
              : mask_offset ? 0 : chunk_len) /
      (NUM_WARP_KV * num_frags_z * 16);

  // Each warp reads its own num_frags_z * 16 slice of the K tile.
  uint32_t k_smem_offset_r =
      smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
          wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
  uint32_t v_smem_offset_r =
      smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
          (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8,
          (wid % 2) * num_frags_z + (tid % 16) / 8);
  uint32_t k_smem_offset_w =
      smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
          wid * 4 + tid / 8, tid % 8);
  uint32_t v_smem_offset_w =
      smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
          wid * 8 + tid / 4, tid % 4);

  uint32_t kv_idx_base = chunk_start;
  const uint32_t const_k_offset = kv_head_idx * kv_h_stride +
                                  (wid * 4 + tid / 8) * kv_b_stride +
                                  tid % 8 * num_elems_per_128b<CacheT>();
  const uint32_t const_v_offset = kv_head_idx * kv_h_stride +
                                  (wid * 8 + tid / 4) * kv_d_stride +
                                  tid % 4 * num_elems_per_128b<CacheT>();

  // Prologue: issue the async loads of the first K and V tiles.
  produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
                         NUM_WARPS,
                         BLOCK_SIZE,
                         num_frags_y,
                         num_frags_z,
                         NUM_WARP_Q>(k_smem,
                                     &k_smem_offset_w,
                                     cache_k,
                                     block_table_now,
                                     kv_head_idx,
                                     kv_n_stride,
                                     kv_h_stride,
                                     kv_b_stride,
                                     kv_idx_base,
                                     chunk_end,
                                     const_k_offset);
  commit_group();
  produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                         NUM_WARPS,
                         BLOCK_SIZE,
                         num_frags_y,
                         num_frags_z,
                         NUM_WARP_Q>(v_smem,
                                     &v_smem_offset_w,
                                     cache_v,
                                     block_table_now,
                                     kv_head_idx,
                                     kv_n_stride,
                                     kv_h_stride,
                                     kv_d_stride,
                                     kv_idx_base,
                                     chunk_end,
                                     const_v_offset);
  commit_group();

#pragma unroll 1
  for (uint32_t iter = 0; iter < num_iterations; ++iter) {
    // Wait for the current K tile (the V group may still be in flight).
    wait_group<1>();
    __syncthreads();
    // s = q * k^T for this warp's KV slice.
    compute_qk_c8<num_frags_x,
                  num_frags_y,
                  num_frags_z,
                  T,
                  CacheT,
                  is_scale_channel_wise,
                  IsFP8>(&qo_smem,
                         &q_smem_offset_r,
                         &k_smem,
                         &k_smem_offset_r,
                         cache_k_scale_reg,
                         s_frag);
    // Mask according to kv_idx and q_idx (plus the optional tree mask).
    if (iter >= mask_check_iteration) {
      mask_s<T,
             partition_kv,
             CAUSAL,
             GROUP_SIZE,
             NUM_WARPS,
             num_frags_x,
             num_frags_y,
             num_frags_z>(
          attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
                    : nullptr,
          q_base_seq_id_this_block,
          kv_idx_base + wid * num_frags_z * 16,
          q_len,
          kv_len,
          chunk_end,
          attn_mask_len,
          s_frag,
          mask_offset_this_seq);
    }
    // Update this warp's running m / d / o state.
    update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
        s_frag, o_frag, m_frag, d_frag);
    __syncthreads();

    // Prefetch the next K tile into the now-free K buffer.
    kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
    produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
                           NUM_WARPS,
                           BLOCK_SIZE,
                           num_frags_y,
                           num_frags_z,
                           NUM_WARP_Q>(k_smem,
                                       &k_smem_offset_w,
                                       cache_k,
                                       block_table_now,
                                       kv_head_idx,
                                       kv_n_stride,
                                       kv_h_stride,
                                       kv_b_stride,
                                       kv_idx_base,
                                       chunk_end,
                                       const_k_offset);
    commit_group();
    // Wait for the current V tile.
    wait_group<1>();
    __syncthreads();
    // o += softmax(s) * v
    compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
                                  num_frags_y,
                                  num_frags_z,
                                  BLOCK_SIZE,
                                  T,
                                  CacheT,
                                  is_scale_channel_wise,
                                  IsFP8>(
        &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
    __syncthreads();

    // Prefetch the next V tile into the now-free V buffer.
    produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                           NUM_WARPS,
                           BLOCK_SIZE,
                           num_frags_y,
                           num_frags_z,
                           NUM_WARP_Q>(v_smem,
                                       &v_smem_offset_w,
                                       cache_v,
                                       block_table_now,
                                       kv_head_idx,
                                       kv_n_stride,
                                       kv_h_stride,
                                       kv_d_stride,
                                       kv_idx_base,
                                       chunk_end,
                                       const_v_offset);
    commit_group();
  }
  wait_group<0>();
  __syncthreads();

  // Combine the four warps' partial states through shared memory.
  merge_block_res_v2<num_frags_x, num_frags_y, T>(
      o_frag, reinterpret_cast<float *>(smem), m_frag, d_frag, wid, tid);

  // Write o. Fragment layout: [num_frags_x, 16, num_frags_y, 16].
  // num_chunks_this_seq is uniform across the block, so both branches are
  // taken by every thread together.
  if (num_chunks_this_seq <= 1) {
    normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
    write_o_reg_gmem_multi_warps_shift_smooth_quant<GROUP_SIZE,
                                                    num_frags_x,
                                                    num_frags_y,
                                                    false>(
        o_frag,
        &qo_smem,
        o_base_ptr_out,
        shift_bias,
        smooth_weight,
        q_base_seq_id_this_block,
        q_head_idx,
        quant_max_bound,
        quant_min_bound,
        in_scale,
        q_len,
        q_n_stride,
        HEAD_DIM);
  } else {
    write_o_reg_gmem_multi_warps_shift_smooth_quant<GROUP_SIZE,
                                                    num_frags_x,
                                                    num_frags_y,
                                                    partition_kv>(
        o_frag,
        &qo_smem,
        o_base_ptr_T,
        shift_bias,
        smooth_weight,
        q_base_seq_id_this_block,
        q_head_idx,
        quant_max_bound,
        quant_min_bound,
        in_scale,
        q_len,
        q_n_stride * num_chunks,
        HEAD_DIM);
    // Only warp 0 publishes the per-row m / d for the cross-chunk merge.
    if (wid == 0) {
#pragma unroll
      for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
          const uint32_t qo_idx_now =
              q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16;
          const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
          const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
          if (qo_idx - q_start_seq_id < q_len) {
            uint32_t offset;
            if constexpr (ENABLE_PREFILL) {
              offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
                       qo_head_idx;
            } else {
              offset = ((batch_id * speculate_max_draft_token_num +
                         qo_idx_now / GROUP_SIZE) *
                            num_chunks +
                        chunk_idx) *
                           q_num_heads +
                       qo_head_idx;
            }
            tmp_m[offset] = m_frag[fx][j];
            tmp_d[offset] = d_frag[fx][j];
          }
        }
      }
    }
  }
}
template <typename T,
uint32_t GROUP_SIZE,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
bool CAUSAL,
uint32_t BLOCK_SHAPE_Q,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8=false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::Tensor &cache_k_scale,
const paddle::Tensor &cache_v_scale,
const paddle::optional<paddle::Tensor> &shift_bias,
const paddle::optional<paddle::Tensor> &smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
const int num_blocks_x_cpu,
const int max_seq_len,
const int max_dec_len,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto token_num = meta_data.token_nums;
auto bsz = meta_data.batch_size;
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
constexpr uint32_t num_warps = 4;
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q);
constexpr uint32_t num_frags_y = HEAD_DIM / 16;
constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16;
auto *allocator = paddle::GetAllocator(qkv.place());
const float scale = 1.f / sqrt(HEAD_DIM);
bool is_scale_channel_wise = false;
if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) {
is_scale_channel_wise = true;
}
if constexpr (NUM_WARP_Q == 4) {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,