Skip to content

Commit a804a5a

Browse files
committed
[DSL] Replace old fence with cute.arch.fence_view_async_shared()
1 parent 48af662 commit a804a5a

File tree

4 files changed

+19
-35
lines changed

4 files changed

+19
-35
lines changed

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,9 +2277,7 @@ def compute_loop(
22772277

22782278
if const_expr(not self.use_smem_dS_for_mma_dK):
22792279
cute.arch.fence_view_async_tmem_store()
2280-
cute.arch.fence_proxy(
2281-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2282-
)
2280+
cute.arch.fence_view_async_shared()
22832281
self.compute_sync_barrier.arrive_and_wait()
22842282

22852283
# with cute.arch.elect_one():
@@ -2528,9 +2526,7 @@ def dQacc_reduce(
25282526
)
25292527
cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s)
25302528
# Fence and barrier to make sure shared memory store is visible to TMA store
2531-
cute.arch.fence_proxy(
2532-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2533-
)
2529+
cute.arch.fence_view_async_shared()
25342530
# semaphore acquire
25352531
if const_expr(self.deterministic and stage == 0):
25362532
if const_expr(self.spt):
@@ -2886,9 +2882,7 @@ def epilogue_dK_or_dV_tma(
28862882
# RMEM -> SMEM -- copy, fence and barrier
28872883
tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)
28882884
cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s)
2889-
cute.arch.fence_proxy(
2890-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2891-
)
2885+
cute.arch.fence_view_async_shared()
28922886
cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
28932887

28942888
# SMEM -> GMEM
@@ -2910,9 +2904,7 @@ def epilogue_dK_or_dV_tma(
29102904
)
29112905

29122906
# Barrier since all warps need to wait for SMEM to be freed
2913-
cute.arch.fence_proxy(
2914-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2915-
)
2907+
cute.arch.fence_view_async_shared()
29162908
cute.arch.barrier(
29172909
barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
29182910
)

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import cutlass.cute as cute
99
import cutlass.utils.hopper_helpers as sm90_utils_basic
1010
from cutlass.cute.nvgpu import cpasync, warpgroup
11-
from cutlass.cute.arch import ProxyKind, SharedSpace
1211
from cutlass.cute import FastDivmodDivisor
1312
from cutlass import Float32, Int32, Boolean, const_expr
1413
from cutlass.utils import LayoutEnum
@@ -1409,7 +1408,7 @@ def mma_one_m_block(
14091408
# This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and
14101409
# (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
14111410
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
1412-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1411+
cute.arch.fence_view_async_shared()
14131412
cute.arch.barrier(
14141413
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
14151414
)
@@ -1427,7 +1426,7 @@ def mma_one_m_block(
14271426
mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)
14281427

14291428
# smem fence to make sure sdS is written before it's read by WGMMA
1430-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1429+
cute.arch.fence_view_async_shared()
14311430
cute.arch.barrier(
14321431
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
14331432
)
@@ -1451,7 +1450,7 @@ def mma_one_m_block(
14511450
)
14521451
tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
14531452
cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
1454-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1453+
cute.arch.fence_view_async_shared()
14551454
cute.arch.barrier_arrive(
14561455
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
14571456
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
@@ -1524,7 +1523,7 @@ def epilogue_dKV(
15241523
sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV)
15251524
taccdVsdV = smem_thr_copy_dV.partition_D(sdV)
15261525
cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
1527-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1526+
cute.arch.fence_view_async_shared()
15281527
cute.arch.barrier(
15291528
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
15301529
)
@@ -1534,7 +1533,7 @@ def epilogue_dKV(
15341533
sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK)
15351534
taccdKsdK = smem_thr_copy_dK.partition_D(sdK)
15361535
cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)
1537-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1536+
cute.arch.fence_view_async_shared()
15381537
cute.arch.barrier(
15391538
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
15401539
)
@@ -1573,7 +1572,7 @@ def epilogue_dKV(
15731572
acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape)
15741573
)
15751574
cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum)
1576-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1575+
cute.arch.fence_view_async_shared()
15771576
cute.arch.barrier(
15781577
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
15791578
)
@@ -1597,7 +1596,7 @@ def epilogue_dKV(
15971596
acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape)
15981597
)
15991598
cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum)
1600-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
1599+
cute.arch.fence_view_async_shared()
16011600
cute.arch.barrier(
16021601
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
16031602
)

flash_attn/cute/flash_fwd.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@
1616
import cutlass.cute as cute
1717
from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
1818
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
19-
from cutlass.cute.arch import ProxyKind, SharedSpace
2019
import cutlass.utils as utils_basic
2120
from cutlass.utils import LayoutEnum
2221
import cutlass.utils.hopper_helpers as sm90_utils_basic
2322

24-
from quack import copy_utils as quack_copy_utils
23+
from quack import copy_utils
24+
from quack import sm90_utils
2525

2626
from flash_attn.cute import ampere_helpers as sm80_utils
2727
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
2828
from flash_attn.cute import hopper_helpers as sm90_utils
2929
from flash_attn.cute import utils
30-
from flash_attn.cute import copy_utils
3130
from flash_attn.cute.mask import AttentionMask
3231
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
3332
from flash_attn.cute.seqlen_info import SeqlenInfoQK
@@ -357,7 +356,7 @@ def epilogue(
357356
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
358357
taccOrO = smem_thr_copy_O.retile(rO)
359358
taccOsO = smem_thr_copy_O.partition_D(sO)
360-
# taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
359+
# taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
361360
# copy acc O from rmem to smem with the smem copy atom
362361
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
363362

@@ -406,7 +405,7 @@ def epilogue(
406405
# sync to make sure all smem stores are done
407406
if const_expr(self.use_tma_O):
408407
# ensure smem writes are visible to TMA
409-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
408+
cute.arch.fence_view_async_shared()
410409
cute.arch.barrier_arrive(
411410
barrier_id=int(NamedBarrierFwd.Epilogue),
412411
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
@@ -1220,7 +1219,6 @@ def _get_tiled_mma(self):
12201219
return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs
12211220

12221221
def _get_shared_storage_cls(self):
1223-
# If we use cp.async to load Q, we want sQ to align to 1024 bytes
12241222
sQ_struct, sK_struct, sV_struct = [
12251223
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
12261224
for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
@@ -2247,9 +2245,7 @@ def first_half_block_overlap(
22472245
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
22482246
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
22492247
# Fence and barrier to make smem store visible to WGMMA
2250-
cute.arch.fence_proxy(
2251-
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
2252-
)
2248+
cute.arch.fence_view_async_shared()
22532249
cute.arch.sync_warp()
22542250

22552251
return kv_consumer_state
@@ -2320,7 +2316,7 @@ def mma_one_n_block(
23202316
softmax.rescale_O(acc_O, row_scale)
23212317
if const_expr(not self.mma_pv_is_rs):
23222318
# Fence and barrier to make sure smem store is visible to WGMMA
2323-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
2319+
cute.arch.fence_view_async_shared()
23242320
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
23252321
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
23262322
self.warp_scheduler_barrier_sync()
@@ -2387,7 +2383,7 @@ def mma_one_n_block_intrawg_overlap(
23872383
softmax.rescale_O(acc_O, row_scale)
23882384
if const_expr(not self.mma_pv_is_rs):
23892385
# Fence and barrier to make sure smem store is visible to WGMMA
2390-
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
2386+
cute.arch.fence_view_async_shared()
23912387
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
23922388
return smem_pipe_read
23932389

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,10 +2428,7 @@ def correction_epilogue(
24282428
tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype))
24292429
cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i)
24302430
# fence view async shared
2431-
cute.arch.fence_proxy(
2432-
cute.arch.ProxyKind.async_shared,
2433-
space=cute.arch.SharedSpace.shared_cta,
2434-
)
2431+
cute.arch.fence_view_async_shared()
24352432

24362433
if const_expr(self.use_correction_warps_for_epi):
24372434
assert(not self.use_tma_O)

0 commit comments

Comments
 (0)