|
16 | 16 | import cutlass.cute as cute |
17 | 17 | from cutlass import Constexpr, Float32, Int32, const_expr, Boolean |
18 | 18 | from cutlass.cute.nvgpu import cpasync, warp, warpgroup |
19 | | -from cutlass.cute.arch import ProxyKind, SharedSpace |
20 | 19 | import cutlass.utils as utils_basic |
21 | 20 | from cutlass.utils import LayoutEnum |
22 | 21 | import cutlass.utils.hopper_helpers as sm90_utils_basic |
23 | 22 |
|
24 | | -from quack import copy_utils as quack_copy_utils |
| 23 | +from quack import copy_utils |
| 24 | +from quack import sm90_utils |
25 | 25 |
|
26 | 26 | from flash_attn.cute import ampere_helpers as sm80_utils |
27 | 27 | from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned |
28 | 28 | from flash_attn.cute import hopper_helpers as sm90_utils |
29 | 29 | from flash_attn.cute import utils |
30 | | -from flash_attn.cute import copy_utils |
31 | 30 | from flash_attn.cute.mask import AttentionMask |
32 | 31 | from flash_attn.cute.softmax import Softmax, apply_score_mod_inner |
33 | 32 | from flash_attn.cute.seqlen_info import SeqlenInfoQK |
@@ -357,7 +356,7 @@ def epilogue( |
357 | 356 | smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) |
358 | 357 | taccOrO = smem_thr_copy_O.retile(rO) |
359 | 358 | taccOsO = smem_thr_copy_O.partition_D(sO) |
360 | | - # taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) |
| 359 | + # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) |
361 | 360 | # copy acc O from rmem to smem with the smem copy atom |
362 | 361 | cute.copy(smem_copy_atom_O, taccOrO, taccOsO) |
363 | 362 |
|
@@ -406,7 +405,7 @@ def epilogue( |
406 | 405 | # sync to make sure all smem stores are done |
407 | 406 | if const_expr(self.use_tma_O): |
408 | 407 | # ensure smem writes are visible to TMA |
409 | | - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) |
| 408 | + cute.arch.fence_view_async_shared() |
410 | 409 | cute.arch.barrier_arrive( |
411 | 410 | barrier_id=int(NamedBarrierFwd.Epilogue), |
412 | 411 | number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, |
@@ -1220,7 +1219,6 @@ def _get_tiled_mma(self): |
1220 | 1219 | return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs |
1221 | 1220 |
|
1222 | 1221 | def _get_shared_storage_cls(self): |
1223 | | - # If we use cp.async to load Q, we want sQ to align to 1024 bytes |
1224 | 1222 | sQ_struct, sK_struct, sV_struct = [ |
1225 | 1223 | cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes] |
1226 | 1224 | for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) |
@@ -2247,9 +2245,7 @@ def first_half_block_overlap( |
2247 | 2245 | tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) |
2248 | 2246 | cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) |
2249 | 2247 | # Fence and barrier to make smem store visible to WGMMA |
2250 | | - cute.arch.fence_proxy( |
2251 | | - cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta |
2252 | | - ) |
| 2248 | + cute.arch.fence_view_async_shared() |
2253 | 2249 | cute.arch.sync_warp() |
2254 | 2250 |
|
2255 | 2251 | return kv_consumer_state |
@@ -2320,7 +2316,7 @@ def mma_one_n_block( |
2320 | 2316 | softmax.rescale_O(acc_O, row_scale) |
2321 | 2317 | if const_expr(not self.mma_pv_is_rs): |
2322 | 2318 | # Fence and barrier to make sure smem store is visible to WGMMA |
2323 | | - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) |
| 2319 | + cute.arch.fence_view_async_shared() |
2324 | 2320 | cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV |
2325 | 2321 | pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) |
2326 | 2322 | self.warp_scheduler_barrier_sync() |
@@ -2387,7 +2383,7 @@ def mma_one_n_block_intrawg_overlap( |
2387 | 2383 | softmax.rescale_O(acc_O, row_scale) |
2388 | 2384 | if const_expr(not self.mma_pv_is_rs): |
2389 | 2385 | # Fence and barrier to make sure smem store is visible to WGMMA |
2390 | | - cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) |
| 2386 | + cute.arch.fence_view_async_shared() |
2391 | 2387 | cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV |
2392 | 2388 | return smem_pipe_read |
2393 | 2389 |
|
|
0 commit comments