|
12 | 12 | from cutlass import Float32, Int32, Boolean, const_expr |
13 | 13 | from cutlass.utils import LayoutEnum |
14 | 14 |
|
15 | | -from flash_attn.cute import hopper_helpers as sm90_utils |
| 15 | +import quack.sm90_utils as sm90_utils |
| 16 | +from quack.sm90_utils import gemm_zero_init, gemm_w_idx |
| 17 | + |
16 | 18 | from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned |
17 | 19 | from flash_attn.cute import utils |
18 | 20 | from flash_attn.cute import copy_utils |
19 | | -from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx |
20 | 21 | from flash_attn.cute.mask import AttentionMask |
21 | 22 | from flash_attn.cute.seqlen_info import SeqlenInfoQK |
22 | 23 | from flash_attn.cute.block_info import BlockInfo |
|
33 | 34 | ) |
34 | 35 |
|
35 | 36 |
|
36 | | -def mma_partition_fragment_AB( |
37 | | - thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool |
38 | | -): |
39 | | - if const_expr(not swap_AB): |
40 | | - return ( |
41 | | - thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None, |
42 | | - thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None, |
43 | | - ) |
44 | | - else: |
45 | | - return ( |
46 | | - thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None, |
47 | | - thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None, |
48 | | - ) |
49 | | - |
50 | | - |
51 | 37 | class FlashAttentionBackwardSm90: |
52 | 38 | arch = 90 |
53 | 39 |
|
@@ -1033,20 +1019,56 @@ def mma( |
1033 | 1019 | wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) |
1034 | 1020 | wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) |
1035 | 1021 | # S = Q @ K.T |
1036 | | - tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB) |
| 1022 | + shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim) |
| 1023 | + _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( |
| 1024 | + wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB |
| 1025 | + ) |
| 1026 | + mma_qk_fn = partial( |
| 1027 | + gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB |
| 1028 | + ) |
1037 | 1029 | # dP = dO @ V.T |
1038 | | - tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB) |
| 1030 | + shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv) |
| 1031 | + _, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC( |
| 1032 | + wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB |
| 1033 | + ) |
| 1034 | + mma_dov_fn = partial( |
| 1035 | + gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB |
| 1036 | + ) |
1039 | 1037 | # dV += P.T @ dO |
1040 | 1038 | sPt = utils.transpose_view(sP) if sP is not None else None |
1041 | 1039 | sdOt = utils.transpose_view(sdO) |
1042 | | - tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB) |
| 1040 | + shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m) |
| 1041 | + acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC( |
| 1042 | + wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB |
| 1043 | + ) |
| 1044 | + if const_expr(not self.mma_dkv_is_rs): |
| 1045 | + mma_pdo_fn = partial( |
| 1046 | + gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB |
| 1047 | + ) |
| 1048 | + else: |
| 1049 | + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt) |
1043 | 1050 | # dK += dS.T @ Q |
1044 | 1051 | sdSt = utils.transpose_view(sdS) |
1045 | 1052 | sQt = utils.transpose_view(sQ) |
1046 | | - tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB) |
| 1053 | + shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m) |
| 1054 | + acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC( |
| 1055 | + wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB |
| 1056 | + ) |
| 1057 | + if const_expr(not self.mma_dkv_is_rs): |
| 1058 | + mma_dsq_fn = partial( |
| 1059 | + gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB |
| 1060 | + ) |
| 1061 | + else: |
| 1062 | + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt) |
1047 | 1063 | # dQ = dS @ K |
1048 | 1064 | sKt = utils.transpose_view(sK) |
1049 | | - tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB) |
| 1065 | + shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n) |
| 1066 | + _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC( |
| 1067 | + wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB |
| 1068 | + ) |
| 1069 | + mma_dsk_fn = partial( |
| 1070 | + gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB |
| 1071 | + ) |
1050 | 1072 |
|
1051 | 1073 | # Smem copy atom tiling |
1052 | 1074 | smem_copy_atom_PdS = utils.get_smem_store_atom( |
@@ -1084,53 +1106,6 @@ def mma( |
1084 | 1106 | smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) |
1085 | 1107 | tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) |
1086 | 1108 |
|
1087 | | - dV_shape = (self.tile_n, self.tile_hdimv) |
1088 | | - acc_dV = cute.make_fragment( |
1089 | | - tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]), |
1090 | | - Float32, |
1091 | | - ) |
1092 | | - dK_shape = (self.tile_n, self.tile_hdim) |
1093 | | - acc_dK = cute.make_fragment( |
1094 | | - tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]), |
1095 | | - Float32, |
1096 | | - ) |
1097 | | - |
1098 | | - mma_qk_fn = partial( |
1099 | | - gemm_zero_init, |
1100 | | - tiled_mma_SdP, |
1101 | | - (self.tile_m, self.tile_n), |
1102 | | - tSrQ, |
1103 | | - tSrK, |
1104 | | - swap_AB=self.SdP_swapAB, |
1105 | | - ) |
1106 | | - mma_dov_fn = partial( |
1107 | | - gemm_zero_init, |
1108 | | - tiled_mma_SdP, |
1109 | | - (self.tile_m, self.tile_n), |
1110 | | - tdPrdO, |
1111 | | - tdPrV, |
1112 | | - swap_AB=self.SdP_swapAB, |
1113 | | - ) |
1114 | | - if const_expr(not self.mma_dkv_is_rs): |
1115 | | - mma_pdo_fn = partial( |
1116 | | - gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB |
1117 | | - ) |
1118 | | - mma_dsq_fn = partial( |
1119 | | - gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB |
1120 | | - ) |
1121 | | - else: |
1122 | | - assert not self.dKV_swapAB |
1123 | | - mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt) |
1124 | | - mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt) |
1125 | | - mma_dsk_fn = partial( |
1126 | | - gemm_zero_init, |
1127 | | - tiled_mma_dQ, |
1128 | | - (self.tile_m, self.tile_hdim), |
1129 | | - tdQrdS, |
1130 | | - tdQrKt, |
1131 | | - swap_AB=self.dQ_swapAB, |
1132 | | - ) |
1133 | | - |
1134 | 1109 | mma_one_m_block_all = partial( |
1135 | 1110 | self.mma_one_m_block, |
1136 | 1111 | warp_group_idx=warp_group_idx, |
|
0 commit comments