Skip to content

Commit 2a8d39c

Browse files
committed
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
1 parent 17d2943 commit 2a8d39c

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,18 +1093,18 @@ def kernel(
10931093
# EMPTY
10941094
# (15)
10951095
if warp_idx == self.empty_warp_id:
1096-
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
1096+
cute.arch.setmaxregister_decrease(self.num_regs_empty)
10971097

10981098
# EPI
10991099
# (14)
11001100
if warp_idx == self.epi_warp_id:
11011101
# currently no-op, could use for tma store/reduce
1102-
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
1102+
cute.arch.setmaxregister_decrease(self.num_regs_empty)
11031103

11041104
# LOAD
11051105
# (13)
11061106
if warp_idx == self.load_warp_id:
1107-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1107+
cute.arch.setmaxregister_decrease(self.num_regs_other)
11081108
self.load(
11091109
thr_mma_S,
11101110
thr_mma_dP,
@@ -1141,7 +1141,7 @@ def kernel(
11411141
# MMA
11421142
# (12)
11431143
if warp_idx == self.mma_warp_id:
1144-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1144+
cute.arch.setmaxregister_decrease(self.num_regs_other)
11451145

11461146
# Alloc tmem buffer
11471147
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
@@ -1194,7 +1194,7 @@ def kernel(
11941194
# Compute
11951195
# (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
11961196
if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
1197-
cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps
1197+
cute.arch.setmaxregister_increase(self.num_regs_compute) # 8 warps
11981198
self.compute_loop(
11991199
thr_mma_S,
12001200
thr_mma_dP,
@@ -1239,7 +1239,7 @@ def kernel(
12391239
# Reduce
12401240
# (0, 1, 2, 3) - dQ
12411241
if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
1242-
cute.arch.warpgroup_reg_alloc(self.num_regs_reduce)
1242+
cute.arch.setmaxregister_increase(self.num_regs_reduce)
12431243
self.dQacc_reduce(
12441244
mdQaccum,
12451245
sdQaccum,

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def kernel(
640640
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
641641

642642
if warp_idx < 4:
643-
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
643+
cute.arch.setmaxregister_decrease(self.num_producer_regs)
644644
if warp_idx == 0:
645645
self.load(
646646
mQ,
@@ -682,7 +682,7 @@ def kernel(
682682
blocksparse_tensors,
683683
)
684684
else:
685-
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
685+
cute.arch.setmaxregister_increase(self.num_mma_regs)
686686
tidx, _, _ = cute.arch.thread_idx()
687687
tidx = tidx - 128
688688
self.mma(

flash_attn/cute/flash_fwd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,7 +1659,7 @@ def kernel(
16591659
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
16601660

16611661
if warp_idx < 4: # Producer
1662-
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
1662+
cute.arch.setmaxregister_decrease(self.num_producer_regs)
16631663
self.load(
16641664
mQ,
16651665
mK,
@@ -1680,7 +1680,7 @@ def kernel(
16801680
)
16811681

16821682
else: # Consumer
1683-
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
1683+
cute.arch.setmaxregister_increase(self.num_mma_regs)
16841684
# ///////////////////////////////////////////////////////////////////////////////
16851685
# Tile MMA compute thread partitions and allocate accumulators
16861686
# ///////////////////////////////////////////////////////////////////////////////

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -951,13 +951,13 @@ def kernel(
951951
# ///////////////////////////////////////////////////////////////////////////////
952952
for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
953953
if warp_idx == self.empty_warp_ids[i]:
954-
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
954+
cute.arch.setmaxregister_decrease(self.num_regs_empty)
955955

956956
# ///////////////////////////////////////////////////////////////////////////////
957957
# LOAD
958958
# ///////////////////////////////////////////////////////////////////////////////
959959
if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
960-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
960+
cute.arch.setmaxregister_decrease(self.num_regs_other)
961961
self.load(
962962
thr_mma_qk,
963963
thr_mma_pv,
@@ -985,7 +985,7 @@ def kernel(
985985
# ///////////////////////////////////////////////////////////////////////////////
986986
if warp_idx == self.mma_warp_id:
987987
# if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids:
988-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
988+
cute.arch.setmaxregister_decrease(self.num_regs_other)
989989
# Alloc tmem buffer
990990
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
991991
if warp_idx == self.mma_warp_id:
@@ -1028,7 +1028,7 @@ def kernel(
10281028
# ///////////////////////////////////////////////////////////////////////////////
10291029
if const_expr(not self.use_correction_warps_for_epi):
10301030
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
1031-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1031+
cute.arch.setmaxregister_decrease(self.num_regs_other)
10321032
self.epilogue_s2g(
10331033
mO,
10341034
sO,
@@ -1049,7 +1049,7 @@ def kernel(
10491049
(const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
10501050
):
10511051
# increase register after decreasing
1052-
cute.arch.warpgroup_reg_alloc(self.num_regs_softmax)
1052+
cute.arch.setmaxregister_increase(self.num_regs_softmax)
10531053
softmax_loop = partial(
10541054
self.softmax_loop,
10551055
softmax_scale_log2=softmax_scale_log2,
@@ -1096,7 +1096,7 @@ def kernel(
10961096
# Correction
10971097
# ///////////////////////////////////////////////////////////////////////////////
10981098
if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
1099-
cute.arch.warpgroup_reg_dealloc(self.num_regs_correction)
1099+
cute.arch.setmaxregister_decrease(self.num_regs_correction)
11001100
self.correction_loop(
11011101
thr_mma_qk,
11021102
thr_mma_pv,

0 commit comments

Comments
 (0)