Skip to content

Commit 81f2c2d

Browse files
committed
[Sm90] Use functions from quack.sm90_utils
1 parent d39b629 commit 81f2c2d

File tree

4 files changed

+60
-216
lines changed

4 files changed

+60
-216
lines changed

flash_attn/cute/flash_bwd_postprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from cutlass import Float32, const_expr
1515
from cutlass.utils import LayoutEnum
1616

17+
import quack.sm90_utils as sm90_utils
18+
1719
from flash_attn.cute import utils
1820
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
1921
from flash_attn.cute import copy_utils
2022
from flash_attn.cute import ampere_helpers as sm80_utils
21-
from flash_attn.cute import hopper_helpers as sm90_utils
2223
from flash_attn.cute.seqlen_info import SeqlenInfoQK
2324
import cutlass.cute.nvgpu.tcgen05 as tcgen05
2425
from flash_attn.cute.tile_scheduler import (

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 44 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from cutlass import Float32, Int32, Boolean, const_expr
1313
from cutlass.utils import LayoutEnum
1414

15-
from flash_attn.cute import hopper_helpers as sm90_utils
15+
import quack.sm90_utils as sm90_utils
16+
from quack.sm90_utils import gemm_zero_init, gemm_w_idx
17+
1618
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
1719
from flash_attn.cute import utils
1820
from flash_attn.cute import copy_utils
19-
from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx
2021
from flash_attn.cute.mask import AttentionMask
2122
from flash_attn.cute.seqlen_info import SeqlenInfoQK
2223
from flash_attn.cute.block_info import BlockInfo
@@ -33,21 +34,6 @@
3334
)
3435

3536

36-
def mma_partition_fragment_AB(
37-
thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool
38-
):
39-
if const_expr(not swap_AB):
40-
return (
41-
thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None,
42-
thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None,
43-
)
44-
else:
45-
return (
46-
thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None,
47-
thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None,
48-
)
49-
50-
5137
class FlashAttentionBackwardSm90:
5238
arch = 90
5339

@@ -1033,20 +1019,56 @@ def mma(
10331019
wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
10341020
wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx))
10351021
# S = Q @ K.T
1036-
tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB)
1022+
shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
1023+
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
1024+
wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB
1025+
)
1026+
mma_qk_fn = partial(
1027+
gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB
1028+
)
10371029
# dP = dO @ V.T
1038-
tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB)
1030+
shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv)
1031+
_, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC(
1032+
wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB
1033+
)
1034+
mma_dov_fn = partial(
1035+
gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB
1036+
)
10391037
# dV += P.T @ dO
10401038
sPt = utils.transpose_view(sP) if sP is not None else None
10411039
sdOt = utils.transpose_view(sdO)
1042-
tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB)
1040+
shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m)
1041+
acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC(
1042+
wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB
1043+
)
1044+
if const_expr(not self.mma_dkv_is_rs):
1045+
mma_pdo_fn = partial(
1046+
gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
1047+
)
1048+
else:
1049+
mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
10431050
# dK += dS.T @ Q
10441051
sdSt = utils.transpose_view(sdS)
10451052
sQt = utils.transpose_view(sQ)
1046-
tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB)
1053+
shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m)
1054+
acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC(
1055+
wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB
1056+
)
1057+
if const_expr(not self.mma_dkv_is_rs):
1058+
mma_dsq_fn = partial(
1059+
gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
1060+
)
1061+
else:
1062+
mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
10471063
# dQ = dS @ K
10481064
sKt = utils.transpose_view(sK)
1049-
tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB)
1065+
shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
1066+
_, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
1067+
wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
1068+
)
1069+
mma_dsk_fn = partial(
1070+
gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB
1071+
)
10501072

10511073
# Smem copy atom tiling
10521074
smem_copy_atom_PdS = utils.get_smem_store_atom(
@@ -1084,53 +1106,6 @@ def mma(
10841106
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
10851107
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
10861108

1087-
dV_shape = (self.tile_n, self.tile_hdimv)
1088-
acc_dV = cute.make_fragment(
1089-
tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]),
1090-
Float32,
1091-
)
1092-
dK_shape = (self.tile_n, self.tile_hdim)
1093-
acc_dK = cute.make_fragment(
1094-
tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]),
1095-
Float32,
1096-
)
1097-
1098-
mma_qk_fn = partial(
1099-
gemm_zero_init,
1100-
tiled_mma_SdP,
1101-
(self.tile_m, self.tile_n),
1102-
tSrQ,
1103-
tSrK,
1104-
swap_AB=self.SdP_swapAB,
1105-
)
1106-
mma_dov_fn = partial(
1107-
gemm_zero_init,
1108-
tiled_mma_SdP,
1109-
(self.tile_m, self.tile_n),
1110-
tdPrdO,
1111-
tdPrV,
1112-
swap_AB=self.SdP_swapAB,
1113-
)
1114-
if const_expr(not self.mma_dkv_is_rs):
1115-
mma_pdo_fn = partial(
1116-
gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
1117-
)
1118-
mma_dsq_fn = partial(
1119-
gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
1120-
)
1121-
else:
1122-
assert not self.dKV_swapAB
1123-
mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
1124-
mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
1125-
mma_dsk_fn = partial(
1126-
gemm_zero_init,
1127-
tiled_mma_dQ,
1128-
(self.tile_m, self.tile_hdim),
1129-
tdQrdS,
1130-
tdQrKt,
1131-
swap_AB=self.dQ_swapAB,
1132-
)
1133-
11341109
mma_one_m_block_all = partial(
11351110
self.mma_one_m_block,
11361111
warp_group_idx=warp_group_idx,

flash_attn/cute/flash_fwd.py

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from flash_attn.cute import ampere_helpers as sm80_utils
2727
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
28-
from flash_attn.cute import hopper_helpers as sm90_utils
2928
from flash_attn.cute import utils
3029
from flash_attn.cute.mask import AttentionMask
3130
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
@@ -1206,17 +1205,7 @@ def _get_tiled_mma(self):
12061205
if self.mma_pv_is_rs
12071206
else warpgroup.OperandSource.SMEM,
12081207
)
1209-
tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma(
1210-
self.dtype,
1211-
self.dtype,
1212-
warpgroup.OperandMajorMode.K,
1213-
warpgroup.OperandMajorMode.MN,
1214-
Float32,
1215-
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1216-
tiler_mn=(64, self.tile_hdimv),
1217-
a_source=warpgroup.OperandSource.RMEM,
1218-
)
1219-
return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs
1208+
return tiled_mma_qk, tiled_mma_pv
12201209

12211210
def _get_shared_storage_cls(self):
12221211
sQ_struct, sK_struct, sV_struct = [
@@ -1296,7 +1285,7 @@ def __call__(
12961285
LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
12971286
mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None
12981287

1299-
tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma()
1288+
tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
13001289
self.num_mma_threads = tiled_mma_qk.size
13011290
self.num_threads_per_warp_group = 128
13021291
self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group
@@ -1342,7 +1331,7 @@ def __call__(
13421331
self.sP_layout = None
13431332
if const_expr(not self.mma_pv_is_rs):
13441333
self.sP_layout = sm90_utils.make_smem_layout(
1345-
mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
1334+
mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
13461335
)
13471336

13481337
SharedStorage = self._get_shared_storage_cls()
@@ -1526,7 +1515,6 @@ def __call__(
15261515
self.gmem_tiled_copy_O,
15271516
tiled_mma_qk,
15281517
tiled_mma_pv,
1529-
tiled_mma_pv_rs,
15301518
tile_sched_params,
15311519
TileScheduler,
15321520
SharedStorage,
@@ -1572,7 +1560,6 @@ def kernel(
15721560
gmem_tiled_copy_O: cute.TiledCopy,
15731561
tiled_mma_qk: cute.TiledMma,
15741562
tiled_mma_pv: cute.TiledMma,
1575-
tiled_mma_pv_rs: cute.TiledMma,
15761563
tile_sched_params: ParamsBase,
15771564
TileScheduler: cutlass.Constexpr[Callable],
15781565
SharedStorage: cutlass.Constexpr[Callable],
@@ -1701,7 +1688,6 @@ def kernel(
17011688
self.mma(
17021689
tiled_mma_qk,
17031690
tiled_mma_pv,
1704-
tiled_mma_pv_rs,
17051691
mQ,
17061692
mO,
17071693
mLSE,
@@ -1855,7 +1841,6 @@ def mma(
18551841
self,
18561842
tiled_mma_qk: cute.TiledMma,
18571843
tiled_mma_pv: cute.TiledMma,
1858-
tiled_mma_pv_rs: cute.TiledMma,
18591844
# softmax: Softmax,
18601845
# acc_O: cute.Tensor,
18611846
mQ: cute.Tensor,
@@ -1891,46 +1876,32 @@ def mma(
18911876
thr_mma_qk = tiled_mma_qk.get_slice(tidx)
18921877
wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
18931878
wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
1894-
tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ))
1895-
tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK))
1896-
if const_expr(self.mma_pv_is_rs):
1897-
acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n))
1898-
tOrP = cute.make_fragment(
1899-
utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype
1900-
)
1901-
else:
1902-
tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP))
1903-
tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt))
1879+
_, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
1880+
wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
1881+
)
1882+
mma_qk_fn = partial(
1883+
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
1884+
)
1885+
acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
1886+
wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
1887+
)
1888+
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
19041889

19051890
# ///////////////////////////////////////////////////////////////////////////////
19061891
# Smem copy atom tiling
19071892
# ///////////////////////////////////////////////////////////////////////////////
19081893
smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype)
19091894
smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
1910-
# tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None
19111895
tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
1912-
# if cute.arch.thread_idx()[0] == 0:
1913-
# cute.printf(sP_pi.layout, sP_pi.iterator)
1914-
# cute.printf(sP.layout, sP.iterator)
1915-
# cute.printf(tPsP.layout, tPsP.iterator)
1916-
1917-
self.mma_init()
1918-
1919-
acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))
1920-
acc_O = cute.make_fragment(acc_shape_O, Float32)
19211896
smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
19221897

1923-
mma_qk_fn = partial(
1924-
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
1925-
)
1926-
mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
1898+
self.mma_init()
19271899

19281900
mma_one_n_block_all = partial(
19291901
self.mma_one_n_block_intrawg_overlap
19301902
if const_expr(self.intra_wg_overlap)
19311903
else self.mma_one_n_block,
19321904
mma_qk_fn=mma_qk_fn,
1933-
tiled_mma_pv_rs=tiled_mma_pv_rs,
19341905
pipeline_k=pipeline_k,
19351906
pipeline_v=pipeline_v,
19361907
acc_O=acc_O,
@@ -2273,7 +2244,6 @@ def mma_one_n_block(
22732244
n_block: Int32,
22742245
mma_qk_fn: Callable,
22752246
mma_pv_fn: Callable,
2276-
tiled_mma_pv_rs: cute.TiledMma,
22772247
pipeline_k: cutlass.pipeline.PipelineAsync,
22782248
pipeline_v: cutlass.pipeline.PipelineAsync,
22792249
acc_O: cute.Tensor,
@@ -2333,7 +2303,6 @@ def mma_one_n_block_intrawg_overlap(
23332303
n_block: Int32,
23342304
mma_qk_fn: Callable,
23352305
mma_pv_fn: Callable,
2336-
tiled_mma_pv_rs: cute.TiledMma,
23372306
pipeline_k: cutlass.pipeline.PipelineAsync,
23382307
pipeline_v: cutlass.pipeline.PipelineAsync,
23392308
acc_O: cute.Tensor,

0 commit comments

Comments
 (0)