Skip to content

Commit 844b10f

Browse files
authored
score-mod backward SM90 (#2137)
1 parent 27a3b54 commit 844b10f

File tree

6 files changed

+293
-31
lines changed

6 files changed

+293
-31
lines changed

flash_attn/cute/block_sparsity.py

Lines changed: 30 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
Block-sparsity utilities for FlexAttention
33
"""
44

5-
from typing import NamedTuple, Optional, Tuple
5+
from typing import Callable, NamedTuple, Tuple
66

77
import cutlass.cute as cute
88
import torch
@@ -17,8 +17,8 @@ def ceildiv(a: int, b: int) -> int:
1717
class BlockSparseTensors(NamedTuple):
1818
mask_block_cnt: cute.Tensor
1919
mask_block_idx: cute.Tensor
20-
full_block_cnt: Optional[cute.Tensor]
21-
full_block_idx: Optional[cute.Tensor]
20+
full_block_cnt: cute.Tensor | None
21+
full_block_idx: cute.Tensor | None
2222

2323
def __new_from_mlir_values__(self, values):
2424
if len(values) == 2:
@@ -29,34 +29,42 @@ def __new_from_mlir_values__(self, values):
2929
class BlockSparseTensorsTorch(NamedTuple):
3030
mask_block_cnt: torch.Tensor
3131
mask_block_idx: torch.Tensor
32-
full_block_cnt: Optional[torch.Tensor] = None
33-
full_block_idx: Optional[torch.Tensor] = None
32+
full_block_cnt: torch.Tensor | None = None
33+
full_block_idx: torch.Tensor | None = None
3434

3535

3636
def _expand_sparsity_tensor(
3737
tensor: torch.Tensor,
3838
expected_shape: Tuple[int, ...],
3939
tensor_name: str,
40+
context: str | None,
41+
hint: str | Callable[[], str] | None,
4042
) -> torch.Tensor:
4143
"""Check if we need to expand the tensor to expected shape, and do so if possible."""
4244
needs_expand = tensor.shape != expected_shape
4345
if not needs_expand:
4446
return tensor
4547
can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
4648
if not can_expand:
49+
context_clause = f" ({context})" if context else ""
50+
resolved_hint = hint() if callable(hint) else hint
51+
hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
4752
raise ValueError(
48-
f"{tensor_name} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
53+
f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
54+
f"{hint_clause}"
4955
)
5056
return tensor.expand(*expected_shape).contiguous()
5157

5258

5359
def _check_and_expand_block(
5460
name: str,
55-
cnt: Optional[torch.Tensor],
56-
idx: Optional[torch.Tensor],
61+
cnt: torch.Tensor | None,
62+
idx: torch.Tensor | None,
5763
expected_count_shape: Tuple[int, int, int],
5864
expected_index_shape: Tuple[int, int, int, int],
59-
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
65+
context: str | None,
66+
hint: str | Callable[[], str] | None,
67+
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
6068
if (cnt is None) != (idx is None):
6169
raise ValueError(
6270
f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
@@ -69,8 +77,12 @@ def _check_and_expand_block(
6977
raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
7078
if not cnt.is_cuda or not idx.is_cuda:
7179
raise ValueError(f"{name}_block tensors must live on CUDA")
72-
expanded_cnt = _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt")
73-
expanded_idx = _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx")
80+
expanded_cnt = _expand_sparsity_tensor(
81+
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
82+
)
83+
expanded_idx = _expand_sparsity_tensor(
84+
idx, expected_index_shape, f"{name}_block_idx", context, hint
85+
)
7486
return expanded_cnt, expanded_idx
7587

7688

@@ -120,6 +132,8 @@ def normalize_block_sparse_tensors(
120132
*,
121133
expected_count_shape: Tuple[int, int, int],
122134
expected_index_shape: Tuple[int, int, int, int],
135+
context: str | None = None,
136+
hint: str | Callable[[], str] | None = None,
123137
) -> BlockSparseTensorsTorch:
124138
if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
125139
raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
@@ -130,6 +144,8 @@ def normalize_block_sparse_tensors(
130144
tensors.mask_block_idx,
131145
expected_count_shape,
132146
expected_index_shape,
147+
context,
148+
hint,
133149
)
134150
if mask_cnt is None or mask_idx is None:
135151
raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
@@ -140,6 +156,8 @@ def normalize_block_sparse_tensors(
140156
tensors.full_block_idx,
141157
expected_count_shape,
142158
expected_index_shape,
159+
context,
160+
hint,
143161
)
144162
if full_cnt is not None and mask_cnt.device != full_cnt.device:
145163
raise ValueError("All block sparse tensors must be on the same device")
@@ -158,7 +176,7 @@ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
158176

159177
def to_cute_block_sparse_tensors(
160178
tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
161-
) -> Optional[BlockSparseTensors]:
179+
) -> BlockSparseTensors | None:
162180
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
163181
if not is_block_sparsity_enabled(tensors):
164182
return None

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 130 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from flash_attn.cute import pipeline
2424
from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase
2525
from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd
26+
from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
2627
from flash_attn.cute.block_sparsity import BlockSparseTensors
2728
from flash_attn.cute.block_sparse_utils import (
2829
get_total_q_block_count_bwd,
@@ -70,6 +71,8 @@ def __init__(
7071
AtomLayoutMdQ: int = 1,
7172
num_threads: int = 384,
7273
V_in_regs: bool = False,
74+
score_mod: cutlass.Constexpr | None = None,
75+
score_mod_bwd: cutlass.Constexpr | None = None,
7376
mask_mod: cutlass.Constexpr | None = None,
7477
has_aux_tensors: cutlass.Constexpr = False,
7578
subtile_factor: cutlass.Constexpr[int] = 1,
@@ -118,13 +121,16 @@ def __init__(
118121
self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
119122
self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
120123

124+
self.score_mod = score_mod
125+
self.score_mod_bwd = score_mod_bwd
121126
self.mask_mod = mask_mod
122127
self.has_aux_tensors = has_aux_tensors
123128
self.subtile_factor = subtile_factor
124129
if cutlass.const_expr(has_aux_tensors):
125130
self.vec_size: cutlass.Constexpr = 1
126131
else:
127132
self.vec_size: cutlass.Constexpr = 4
133+
self.qk_acc_dtype = Float32
128134

129135
@staticmethod
130136
def can_implement(
@@ -443,7 +449,10 @@ def __call__(
443449
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
444450

445451
LOG2_E = math.log2(math.e)
446-
softmax_scale_log2 = softmax_scale * LOG2_E
452+
if const_expr(self.score_mod is None):
453+
softmax_scale_log2 = softmax_scale * LOG2_E
454+
else:
455+
softmax_scale_log2 = LOG2_E
447456

448457
fastdiv_mods = None
449458
if const_expr(aux_tensors is not None):
@@ -856,6 +865,93 @@ def load(
856865
tile_scheduler.advance_to_next_work()
857866
work_tile = tile_scheduler.get_current_work()
858867

868+
@cute.jit
869+
def apply_score_mod(
870+
self,
871+
acc_S: cute.Tensor,
872+
thr_mma_SdP: cute.core.ThrMma,
873+
batch_idx,
874+
head_idx,
875+
m_block,
876+
n_block,
877+
softmax_scale,
878+
seqlen_info: SeqlenInfoQK,
879+
aux_tensors=None,
880+
fastdiv_mods=(None, None),
881+
):
882+
# [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
883+
cS = cute.make_identity_tensor(
884+
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
885+
)
886+
cS = cute.domain_offset(
887+
(n_block * self.tile_n, m_block * self.tile_m)
888+
if self.SdP_swapAB
889+
else (m_block * self.tile_m, n_block * self.tile_n),
890+
cS,
891+
)
892+
tScS = thr_mma_SdP.partition_C(cS)
893+
894+
apply_score_mod_inner(
895+
acc_S,
896+
tScS,
897+
self.score_mod,
898+
batch_idx,
899+
head_idx,
900+
softmax_scale,
901+
self.vec_size,
902+
self.qk_acc_dtype,
903+
aux_tensors,
904+
fastdiv_mods,
905+
seqlen_info,
906+
constant_q_idx=None,
907+
qhead_per_kvhead=self.qhead_per_kvhead,
908+
transpose_indices=self.SdP_swapAB,
909+
)
910+
911+
@cute.jit
912+
def apply_score_mod_bwd(
913+
self,
914+
grad_tensor: cute.Tensor,
915+
score_tensor: cute.Tensor,
916+
thr_mma_SdP: cute.core.ThrMma,
917+
batch_idx,
918+
head_idx,
919+
m_block,
920+
n_block,
921+
softmax_scale,
922+
seqlen_info: SeqlenInfoQK,
923+
aux_tensors=None,
924+
fastdiv_mods=(None, None),
925+
):
926+
cS = cute.make_identity_tensor(
927+
(self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
928+
)
929+
cS = cute.domain_offset(
930+
(n_block * self.tile_n, m_block * self.tile_m)
931+
if self.SdP_swapAB
932+
else (m_block * self.tile_m, n_block * self.tile_n),
933+
cS,
934+
)
935+
tScS = thr_mma_SdP.partition_C(cS)
936+
937+
apply_score_mod_bwd_inner(
938+
grad_tensor,
939+
score_tensor,
940+
tScS,
941+
self.score_mod_bwd,
942+
batch_idx,
943+
head_idx,
944+
softmax_scale,
945+
self.vec_size,
946+
self.qk_acc_dtype,
947+
aux_tensors,
948+
fastdiv_mods,
949+
seqlen_info,
950+
constant_q_idx=None,
951+
qhead_per_kvhead=self.qhead_per_kvhead,
952+
transpose_indices=self.SdP_swapAB,
953+
)
954+
859955
@cute.jit
860956
def mma(
861957
self,
@@ -1196,6 +1292,24 @@ def mma_one_m_block(
11961292
)
11971293
acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)
11981294

1295+
if const_expr(self.score_mod_bwd is not None):
1296+
acc_S_pre = cute.make_fragment_like(acc_S)
1297+
cute.autovec_copy(acc_S, acc_S_pre)
1298+
1299+
if const_expr(self.score_mod is not None):
1300+
self.apply_score_mod(
1301+
acc_S,
1302+
thr_mma_SdP,
1303+
batch_idx,
1304+
head_idx,
1305+
m_block,
1306+
n_block,
1307+
softmax_scale,
1308+
seqlen,
1309+
aux_tensors,
1310+
fastdiv_mods,
1311+
)
1312+
11991313
# (3) [Pointwise 1] P = exp(S - LSE)
12001314
if cutlass.const_expr(mask_fn is not None):
12011315
mask_fn(acc_S, m_block=m_block)
@@ -1226,6 +1340,21 @@ def mma_one_m_block(
12261340
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
12271341
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
12281342

1343+
if const_expr(self.score_mod_bwd is not None):
1344+
self.apply_score_mod_bwd(
1345+
acc_dP,
1346+
acc_S_pre,
1347+
thr_mma_SdP,
1348+
batch_idx,
1349+
head_idx,
1350+
m_block,
1351+
n_block,
1352+
softmax_scale,
1353+
seqlen,
1354+
aux_tensors,
1355+
fastdiv_mods,
1356+
)
1357+
12291358
# Convert dS from f32 -> f16
12301359
tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
12311360

flash_attn/cute/interface.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -713,7 +713,6 @@ def _flash_attn_bwd(
713713
assert cu_seqlens_q is None and cu_seqlens_k is None, (
714714
"varlen + score_mod not supported in bwd yet"
715715
)
716-
assert compute_capability in [10, 11], "score_mod in bwd only supported on SM100/SM110 for now"
717716

718717
device = q.device
719718
out_torch_dtype = q.dtype
@@ -910,7 +909,6 @@ def _flash_attn_bwd(
910909
num_aux_tensors,
911910
use_block_sparsity,
912911
)
913-
cute_aux_tensors = None
914912
else:
915913
compile_key = (
916914
compute_capability,
@@ -999,6 +997,8 @@ def _flash_attn_bwd(
999997
AtomLayoutMdQ,
1000998
num_threads,
1001999
V_in_regs=V_in_regs,
1000+
score_mod=score_mod,
1001+
score_mod_bwd=score_mod_bwd,
10021002
mask_mod=mask_mod,
10031003
has_aux_tensors=aux_tensors is not None,
10041004
subtile_factor=subtile_factor,
@@ -1034,6 +1034,12 @@ def _flash_attn_bwd(
10341034
block_sparse_tensors,
10351035
expected_count_shape=expected_count_shape,
10361036
expected_index_shape=expected_index_shape,
1037+
context="_flash_attn_bwd",
1038+
hint=lambda: (
1039+
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
1040+
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
1041+
f"(sparse_block_size_q={sparse_block_size_q})."
1042+
),
10371043
)
10381044
sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)
10391045

@@ -1076,6 +1082,12 @@ def _flash_attn_bwd(
10761082
block_sparse_tensors,
10771083
expected_count_shape=expected_count_shape,
10781084
expected_index_shape=expected_index_shape,
1085+
context="_flash_attn_bwd",
1086+
hint=lambda: (
1087+
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
1088+
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
1089+
f"(sparse_block_size_q={sparse_block_size_q})."
1090+
),
10791091
)
10801092

10811093
_flash_attn_bwd.compile_cache[compile_key](

flash_attn/cute/mask.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,6 @@ def apply_mask(
139139
): # FlexAttention mask mod
140140
nrow = const_expr(cute.size(tScS_mn.shape[0]))
141141
ncol = const_expr(cute.size(tScS_mn.shape[1]))
142-
thr_col_offset = tScS_mn[0, 0][1]
143142
has_fastdiv = const_expr(
144143
fastdiv_mods is not None
145144
and fastdiv_mods[0] is not None
@@ -150,7 +149,9 @@ def apply_mask(
150149
)
151150

152151
for r in cutlass.range_constexpr(nrow):
153-
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
152+
# Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
153+
local_row = tScS_mn[r, 0][ROW]
154+
global_row_idx = local_row + m_block * self.tile_m
154155
row_for_mod = global_row_idx
155156
head_idx_for_mod = head_idx
156157
if const_expr(self.qhead_per_kvhead_packgqa != 1):
@@ -162,7 +163,7 @@ def apply_mask(
162163
_, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])
163164

164165
for col in cutlass.range_constexpr(ncol):
165-
col_idx_local = t0ScS_mn[0, col][1]
166+
col_idx_local = t0ScS_mn[0, col][COL]
166167
# Convert to absolute column index
167168
global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
168169
col_for_mod = global_col_idx
@@ -354,7 +355,7 @@ def apply_mask_sm100(
354355
mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)
355356

356357
elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
357-
# Block sparse w/ mask_mod
358+
# Block sparse case w/ mask_mod
358359
has_fastdiv = const_expr(
359360
fastdiv_mods is not None
360361
and fastdiv_mods[0] is not None

0 commit comments

Comments
 (0)