@@ -63,23 +63,25 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
6363 stride_v_nh ,
6464 stride_v_hz ,
6565 stride_bt_s ,
66+ head_size ,
67+ query_group_size ,
6668 compute_type : gl .constexpr ,
67- HEAD_SZ : gl .constexpr ,
68- QUERY_GRP_SZ : gl .constexpr ,
69- KV_BLK_SZ : gl .constexpr ,
69+ HEAD_SZ_POW2 : gl .constexpr ,
70+ QUERY_GRP_SZ_POW2 : gl .constexpr ,
71+ KV_BLK_SZ_POW2 : gl .constexpr ,
7072 SEQ_PARTITION_SZ : gl .constexpr ,
7173 SLIDING_WINDOW : gl .constexpr ,
7274):
7375 """
7476 #TODO: Add Doc
7577 """
76- if QUERY_GRP_SZ <= 16 :
77- QUERY_GRP_SZ_POW2 : gl .constexpr = 16
78- else :
79- QUERY_GRP_SZ_POW2 : gl .constexpr = triton .next_power_of_2 (QUERY_GRP_SZ )
78+ # if QUERY_GRP_SZ <= 16:
79+ # QUERY_GRP_SZ_POW2: gl.constexpr = 16
80+ # else:
81+ # QUERY_GRP_SZ_POW2: gl.constexpr = triton.next_power_of_2(QUERY_GRP_SZ)
8082
81- HEAD_SZ_POW2 : gl .constexpr = triton .next_power_of_2 (HEAD_SZ )
82- KV_BLK_SZ_POW2 : gl .constexpr = triton .next_power_of_2 (KV_BLK_SZ )
83+ # HEAD_SZ_POW2: gl.constexpr = triton.next_power_of_2(HEAD_SZ)
84+ # KV_BLK_SZ_POW2: gl.constexpr = triton.next_power_of_2(KV_BLK_SZ)
8385 seq_idx = gl .program_id (0 )
8486 kv_head_idx = gl .program_id (1 )
8587 seq_part_idx = gl .program_id (2 )
@@ -94,8 +96,10 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
9496 return
9597
9698 seq_end_idx = gl .minimum (seq_start_idx + SEQ_PARTITION_SZ , seq_len )
97- MAX_NUM_KV_BLKS : gl .constexpr = (SEQ_PARTITION_SZ + KV_BLK_SZ - 1 ) // KV_BLK_SZ
98- num_kv_blks = gl .cdiv (seq_end_idx - seq_start_idx , KV_BLK_SZ )
99+ MAX_NUM_KV_BLKS : gl .constexpr = (
100+ SEQ_PARTITION_SZ + KV_BLK_SZ_POW2 - 1
101+ ) // KV_BLK_SZ_POW2
102+ num_kv_blks = gl .cdiv (seq_end_idx - seq_start_idx , KV_BLK_SZ_POW2 )
99103
100104 # # 1 x QUERY_GRP_SZ_POW2 x HEAD_SZ_POW2
101105 # # 1 x 8(mdim) x 128(kdim)
@@ -184,9 +188,9 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
184188 alibi_slope = gl .zeros ([QUERY_GRP_SZ_POW2 ], dtype = gl .float32 )
185189 else :
186190 alibi_slope = gl .amd .cdna3 .buffer_load (
187- ptr = alibi_slopes + kv_head_idx * QUERY_GRP_SZ ,
191+ ptr = alibi_slopes + kv_head_idx * query_group_size ,
188192 offsets = qk_row_offs ,
189- mask = qk_row_offs < QUERY_GRP_SZ ,
193+ mask = qk_row_offs < query_group_size ,
190194 )
191195
192196 # load all kv blocks in one time
@@ -200,10 +204,12 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
200204 # load q[QUERY_GRP_SZ_POW2, HEAD_SZ_POW2]
201205 q_offs = (
202206 seq_idx * stride_q_s
203- + (kv_head_idx * QUERY_GRP_SZ + q_grp_offs [:, None ]) * stride_q_nh
207+ + (kv_head_idx * query_group_size + q_grp_offs [:, None ]) * stride_q_nh
204208 + head_sz_offs [None , :]
205209 )
206- q_mask = (q_grp_offs [:, None ] < QUERY_GRP_SZ ) & (head_sz_offs [None , :] < HEAD_SZ )
210+ q_mask = (q_grp_offs [:, None ] < query_group_size ) & (
211+ head_sz_offs [None , :] < head_size
212+ )
207213 q = gl .amd .cdna3 .buffer_load (ptr = q_ptr , offsets = q_offs , mask = q_mask )
208214 q_shared = gl .allocate_shared_memory (q .dtype , q .shape , shared_a_layout , q )
209215 # k_blk_offs[MAX_NUM_KV_BLKS, K_HEAD_SZ_POW2_SPLIT, KV_BLK_SZ_POW2, CONTIGUOUS_KV_ELEMS_16B_LOAD]
@@ -244,7 +250,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
244250 )
245251
246252 qk = gl .where (
247- (qk_row_offs [:, None ] < QUERY_GRP_SZ ) & (qk_col_offs [None , :] < seq_len ),
253+ (qk_row_offs [:, None ] < query_group_size ) & (qk_col_offs [None , :] < seq_len ),
248254 qk ,
249255 float ("-inf" ),
250256 )
@@ -266,7 +272,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
266272 + seq_part_idx * stride_max_logits_p
267273 + m_l_base_offs
268274 )
269- m_l_grp_mask = m_l_base_offs < QUERY_GRP_SZ
275+ m_l_grp_mask = m_l_base_offs < query_group_size
270276 gl .amd .cdna3 .buffer_store (
271277 stored_value = max_logit_new ,
272278 ptr = max_logits_ptr ,
@@ -276,6 +282,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
276282 gl .amd .cdna3 .buffer_store (
277283 stored_value = exp_sum , ptr = exp_sums_ptr , offsets = m_l_offs , mask = m_l_grp_mask
278284 )
285+
279286 # MAX_NUM_KV_BLKS x HEAD_SZ_POW2 x KV_BLK_SZ_POW2
280287 # 16(kdim0) x 128(ndim) x 16(kdim1)
281288 blocked_v_layout : gl .constexpr = gl .DistributedLinearLayout ( # 256x128
@@ -346,14 +353,17 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
346353 o_head_sz_offs = gl .arange (
347354 0 , HEAD_SZ_POW2 , layout = gl .SliceLayout (0 , pv_mfma_layout )
348355 )
349- o_mask = (o_grp_offs [:, None ] < QUERY_GRP_SZ ) & (o_head_sz_offs [None , :] < HEAD_SZ )
356+ o_mask = (o_grp_offs [:, None ] < query_group_size ) & (
357+ o_head_sz_offs [None , :] < head_size
358+ )
350359 logits_offs = seq_idx * stride_logits_s
351360 logits_offs += kv_head_idx * stride_logits_nh
352361 logits_offs += (
353362 seq_part_idx * stride_logits_p
354363 + o_grp_offs [:, None ] * stride_logits_g
355364 + o_head_sz_offs [None , :]
356365 )
366+
357367 gl .amd .cdna3 .buffer_store (
358368 stored_value = acc , ptr = logits_ptr , offsets = logits_offs , mask = o_mask
359369 )
@@ -419,6 +429,10 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_wrapper(
419429 # pdb.set_trace()
420430
421431 # if 1:
432+ if QUERY_GRP_SZ < 16 :
433+ QUERY_GRP_SZ_POW2 = 16
434+ else :
435+ QUERY_GRP_SZ_POW2 = triton .next_power_of_2 (QUERY_GRP_SZ )
422436 if 0 :
423437 with open (ttgir_file_path , "r" ) as f :
424438 ttgir_content = f .read ()
@@ -488,12 +502,16 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_wrapper(
488502 stride_v_hz ,
489503 stride_v_bz ,
490504 stride_bt_s ,
505+ HEAD_SZ ,
506+ QUERY_GRP_SZ ,
491507 compute_type = compute_type ,
492- HEAD_SZ = HEAD_SZ ,
493- QUERY_GRP_SZ = QUERY_GRP_SZ ,
494- KV_BLK_SZ = KV_BLK_SZ ,
508+ HEAD_SZ_POW2 = triton . next_power_of_2 ( HEAD_SZ ) ,
509+ QUERY_GRP_SZ_POW2 = QUERY_GRP_SZ_POW2 ,
510+ KV_BLK_SZ_POW2 = triton . next_power_of_2 ( KV_BLK_SZ ) ,
495511 SEQ_PARTITION_SZ = SEQ_PARTITION_SZ ,
496512 SLIDING_WINDOW = SLIDING_WINDOW ,
513+ waves_per_eu = 4 ,
514+ num_stages = 1 ,
497515 )
498516
499517
0 commit comments