Skip to content

Commit 4c14027

Browse files
author
root
committed
optimize kernel
1 parent 5aec290 commit 4c14027

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

aiter/ops/triton/gluon/pa_decode_gluon.py

Lines changed: 39 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -63,23 +63,25 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
6363
stride_v_nh,
6464
stride_v_hz,
6565
stride_bt_s,
66+
head_size,
67+
query_group_size,
6668
compute_type: gl.constexpr,
67-
HEAD_SZ: gl.constexpr,
68-
QUERY_GRP_SZ: gl.constexpr,
69-
KV_BLK_SZ: gl.constexpr,
69+
HEAD_SZ_POW2: gl.constexpr,
70+
QUERY_GRP_SZ_POW2: gl.constexpr,
71+
KV_BLK_SZ_POW2: gl.constexpr,
7072
SEQ_PARTITION_SZ: gl.constexpr,
7173
SLIDING_WINDOW: gl.constexpr,
7274
):
7375
"""
7476
#TODO: Add Doc
7577
"""
76-
if QUERY_GRP_SZ <= 16:
77-
QUERY_GRP_SZ_POW2: gl.constexpr = 16
78-
else:
79-
QUERY_GRP_SZ_POW2: gl.constexpr = triton.next_power_of_2(QUERY_GRP_SZ)
78+
# if QUERY_GRP_SZ <= 16:
79+
# QUERY_GRP_SZ_POW2: gl.constexpr = 16
80+
# else:
81+
# QUERY_GRP_SZ_POW2: gl.constexpr = triton.next_power_of_2(QUERY_GRP_SZ)
8082

81-
HEAD_SZ_POW2: gl.constexpr = triton.next_power_of_2(HEAD_SZ)
82-
KV_BLK_SZ_POW2: gl.constexpr = triton.next_power_of_2(KV_BLK_SZ)
83+
# HEAD_SZ_POW2: gl.constexpr = triton.next_power_of_2(HEAD_SZ)
84+
# KV_BLK_SZ_POW2: gl.constexpr = triton.next_power_of_2(KV_BLK_SZ)
8385
seq_idx = gl.program_id(0)
8486
kv_head_idx = gl.program_id(1)
8587
seq_part_idx = gl.program_id(2)
@@ -94,8 +96,10 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
9496
return
9597

9698
seq_end_idx = gl.minimum(seq_start_idx + SEQ_PARTITION_SZ, seq_len)
97-
MAX_NUM_KV_BLKS: gl.constexpr = (SEQ_PARTITION_SZ + KV_BLK_SZ - 1) // KV_BLK_SZ
98-
num_kv_blks = gl.cdiv(seq_end_idx - seq_start_idx, KV_BLK_SZ)
99+
MAX_NUM_KV_BLKS: gl.constexpr = (
100+
SEQ_PARTITION_SZ + KV_BLK_SZ_POW2 - 1
101+
) // KV_BLK_SZ_POW2
102+
num_kv_blks = gl.cdiv(seq_end_idx - seq_start_idx, KV_BLK_SZ_POW2)
99103

100104
# # 1 x QUERY_GRP_SZ_POW2 x HEAD_SZ_POW2
101105
# # 1 x 8(mdim) x 128(kdim)
@@ -184,9 +188,9 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
184188
alibi_slope = gl.zeros([QUERY_GRP_SZ_POW2], dtype=gl.float32)
185189
else:
186190
alibi_slope = gl.amd.cdna3.buffer_load(
187-
ptr=alibi_slopes + kv_head_idx * QUERY_GRP_SZ,
191+
ptr=alibi_slopes + kv_head_idx * query_group_size,
188192
offsets=qk_row_offs,
189-
mask=qk_row_offs < QUERY_GRP_SZ,
193+
mask=qk_row_offs < query_group_size,
190194
)
191195

192196
# load all kv blocks in one time
@@ -200,10 +204,12 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
200204
# load q[QUERY_GRP_SZ_POW2, HEAD_SZ_POW2]
201205
q_offs = (
202206
seq_idx * stride_q_s
203-
+ (kv_head_idx * QUERY_GRP_SZ + q_grp_offs[:, None]) * stride_q_nh
207+
+ (kv_head_idx * query_group_size + q_grp_offs[:, None]) * stride_q_nh
204208
+ head_sz_offs[None, :]
205209
)
206-
q_mask = (q_grp_offs[:, None] < QUERY_GRP_SZ) & (head_sz_offs[None, :] < HEAD_SZ)
210+
q_mask = (q_grp_offs[:, None] < query_group_size) & (
211+
head_sz_offs[None, :] < head_size
212+
)
207213
q = gl.amd.cdna3.buffer_load(ptr=q_ptr, offsets=q_offs, mask=q_mask)
208214
q_shared = gl.allocate_shared_memory(q.dtype, q.shape, shared_a_layout, q)
209215
# k_blk_offs[MAX_NUM_KV_BLKS, K_HEAD_SZ_POW2_SPLIT, KV_BLK_SZ_POW2, CONTIGUOUS_KV_ELEMS_16B_LOAD]
@@ -244,7 +250,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
244250
)
245251

246252
qk = gl.where(
247-
(qk_row_offs[:, None] < QUERY_GRP_SZ) & (qk_col_offs[None, :] < seq_len),
253+
(qk_row_offs[:, None] < query_group_size) & (qk_col_offs[None, :] < seq_len),
248254
qk,
249255
float("-inf"),
250256
)
@@ -266,7 +272,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
266272
+ seq_part_idx * stride_max_logits_p
267273
+ m_l_base_offs
268274
)
269-
m_l_grp_mask = m_l_base_offs < QUERY_GRP_SZ
275+
m_l_grp_mask = m_l_base_offs < query_group_size
270276
gl.amd.cdna3.buffer_store(
271277
stored_value=max_logit_new,
272278
ptr=max_logits_ptr,
@@ -276,6 +282,7 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
276282
gl.amd.cdna3.buffer_store(
277283
stored_value=exp_sum, ptr=exp_sums_ptr, offsets=m_l_offs, mask=m_l_grp_mask
278284
)
285+
279286
# MAX_NUM_KV_BLKS x HEAD_SZ_POW2 x KV_BLK_SZ_POW2
280287
# 16(kdim0) x 128(ndim) x 16(kdim1)
281288
blocked_v_layout: gl.constexpr = gl.DistributedLinearLayout( # 256x128
@@ -346,14 +353,17 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk_gluon(
346353
o_head_sz_offs = gl.arange(
347354
0, HEAD_SZ_POW2, layout=gl.SliceLayout(0, pv_mfma_layout)
348355
)
349-
o_mask = (o_grp_offs[:, None] < QUERY_GRP_SZ) & (o_head_sz_offs[None, :] < HEAD_SZ)
356+
o_mask = (o_grp_offs[:, None] < query_group_size) & (
357+
o_head_sz_offs[None, :] < head_size
358+
)
350359
logits_offs = seq_idx * stride_logits_s
351360
logits_offs += kv_head_idx * stride_logits_nh
352361
logits_offs += (
353362
seq_part_idx * stride_logits_p
354363
+ o_grp_offs[:, None] * stride_logits_g
355364
+ o_head_sz_offs[None, :]
356365
)
366+
357367
gl.amd.cdna3.buffer_store(
358368
stored_value=acc, ptr=logits_ptr, offsets=logits_offs, mask=o_mask
359369
)
@@ -419,6 +429,10 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_wrapper(
419429
# pdb.set_trace()
420430

421431
# if 1:
432+
if QUERY_GRP_SZ < 16:
433+
QUERY_GRP_SZ_POW2 = 16
434+
else:
435+
QUERY_GRP_SZ_POW2 = triton.next_power_of_2(QUERY_GRP_SZ)
422436
if 0:
423437
with open(ttgir_file_path, "r") as f:
424438
ttgir_content = f.read()
@@ -488,12 +502,16 @@ def _paged_attn_decode_v2_w_dot_kernel_reshape_wrapper(
488502
stride_v_hz,
489503
stride_v_bz,
490504
stride_bt_s,
505+
HEAD_SZ,
506+
QUERY_GRP_SZ,
491507
compute_type=compute_type,
492-
HEAD_SZ=HEAD_SZ,
493-
QUERY_GRP_SZ=QUERY_GRP_SZ,
494-
KV_BLK_SZ=KV_BLK_SZ,
508+
HEAD_SZ_POW2=triton.next_power_of_2(HEAD_SZ),
509+
QUERY_GRP_SZ_POW2=QUERY_GRP_SZ_POW2,
510+
KV_BLK_SZ_POW2=triton.next_power_of_2(KV_BLK_SZ),
495511
SEQ_PARTITION_SZ=SEQ_PARTITION_SZ,
496512
SLIDING_WINDOW=SLIDING_WINDOW,
513+
waves_per_eu=4,
514+
num_stages=1,
497515
)
498516

499517

0 commit comments

Comments
 (0)