Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def consume_block_sparse_loads(
score_mod_fn,
O_should_accumulate,
mask_mod,
fastdiv_mods,
intra_wg_overlap: cutlass.Constexpr,
warp_scheduler_barrier_sync: Callable,
warp_scheduler_barrier_arrive: Callable,
Expand All @@ -309,7 +310,12 @@ def consume_block_sparse_loads(
kv_consumer_state,
n_block=mask_n_block,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=True),
mask_fn=partial(
mask_fn,
mask_mod=mask_mod,
mask_seqlen=True,
fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
),
is_first_n_block=True,
)
O_should_accumulate = True
Expand Down Expand Up @@ -374,7 +380,12 @@ def consume_block_sparse_loads(
kv_consumer_state = process_first_half_block(
n_block=mask_n_block,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(mask_fn, mask_mod=mask_mod),
mask_fn=partial(
mask_fn,
mask_mod=mask_mod,
mask_seqlen=True,
fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
),
score_mod_fn=score_mod_fn,
is_first_block=True,
)
Expand All @@ -394,7 +405,7 @@ def consume_block_sparse_loads(
kv_consumer_state = process_first_half_block(
n_block=full_n_block,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(mask_fn, mask_mod=None),
mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
score_mod_fn=score_mod_fn,
is_first_block=True,
)
Expand Down
5 changes: 4 additions & 1 deletion flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ def preprocess_Q():
thr_mma=thr_mma_qk,
mask_causal=self.is_causal,
mask_local=self.is_local,
fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
)

# First iteration with seqlen masking
Expand Down Expand Up @@ -1991,6 +1992,7 @@ def mma(
mask_causal=self.is_causal,
mask_local=self.is_local,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
score_mod_fn = None
if const_expr(self.score_mod is not None):
Expand Down Expand Up @@ -2131,11 +2133,12 @@ def mma(
score_mod_fn,
O_should_accumulate,
self.mask_mod,
fastdiv_mods,
self.intra_wg_overlap,
self.warp_scheduler_barrier_sync,
self.warp_scheduler_barrier_arrive,
)

# Handle empty case (when no blocks to process)
if not processed_any:
softmax.reset()
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,13 +1628,15 @@ def softmax_loop(
mask_fn = partial(
mask.apply_mask_sm100,
mask_mod=block_mask_mod,
fastdiv_mods=fastdiv_mods,
**shared_mask_kwargs,
)
if const_expr(self.use_block_sparsity):
# Full blocks dont need mask_mod
mask_fn_none = partial(
mask.apply_mask_sm100,
mask_mod=None,
fastdiv_mods=fastdiv_mods,
**shared_mask_kwargs,
)
else:
Expand Down
44 changes: 35 additions & 9 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def apply_mask(
mask_local: cutlass.Constexpr[bool] = False,
mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
) -> None:
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
Expand Down Expand Up @@ -131,24 +132,33 @@ def apply_mask(
nrow = const_expr(cute.size(tScS_mn.shape[0]))
ncol = const_expr(cute.size(tScS_mn.shape[1]))
thr_col_offset = tScS_mn[0, 0][1]
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None
and fastdiv_mods[1] is not None
)
wrap_aux_indices = const_expr(
has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
)

for r in cutlass.range_constexpr(nrow):
global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
row_for_mod = global_row_idx
if const_expr(wrap_aux_indices):
_, row_for_mod = fastdiv_mods[0].divmod(global_row_idx)

for col in cutlass.range_constexpr(ncol):
col_idx_local = t0ScS_mn[0, col][1]
# Convert to absolute column index
global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
col_for_mod = global_col_idx
if const_expr(wrap_aux_indices):
_, col_for_mod = fastdiv_mods[1].divmod(global_col_idx)

batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
q_idx_ssa = utils.scalar_to_ssa(
tScS_mn[r, 0][0] + m_block * self.tile_m, cutlass.Int32
)
kv_idx_ssa = utils.scalar_to_ssa(
thr_col_offset + t0ScS_mn[0, col][1] + n_block * self.tile_n,
cutlass.Int32,
)
q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
head_idx_ssa,
Expand Down Expand Up @@ -302,6 +312,7 @@ def apply_mask_sm100(
batch_idx: Int32 = None,
head_idx: Int32 = None,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
) -> None:
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
acc_shape = (self.tile_m, self.tile_n)
Expand All @@ -328,6 +339,14 @@ def apply_mask_sm100(

elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
# Block sparse case w/ mask_mod
has_fastdiv = const_expr(
fastdiv_mods is not None
and fastdiv_mods[0] is not None
and fastdiv_mods[1] is not None
)
wrap_aux_indices = const_expr(
has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
row_coord_first = tScS_t2r[0][0]
Expand All @@ -336,17 +355,24 @@ def apply_mask_sm100(
mask_row = global_row // self.qhead_per_kvhead_packgqa
else:
mask_row = global_row
mask_row_ssa = utils.scalar_to_ssa(mask_row, cutlass.Int32)
mask_row_for_mod = mask_row
if const_expr(wrap_aux_indices):
_, mask_row_for_mod = fastdiv_mods[0].divmod(mask_row)
mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)

ncol = const_expr(cute.size(tScS_t2r.shape))
for i in cutlass.range_constexpr(ncol):
col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
global_col = col_coord + n_block * self.tile_n
global_col_for_mod = global_col
if const_expr(wrap_aux_indices):
_, global_col_for_mod = fastdiv_mods[1].divmod(global_col)
kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
mask_value = mask_mod(
batch_idx_ssa,
head_idx_ssa,
mask_row_ssa,
utils.scalar_to_ssa(global_col, cutlass.Int32),
kv_idx_ssa,
aux_tensors,
)
cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
Expand Down
18 changes: 18 additions & 0 deletions flash_attn/cute/mask_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,23 @@ def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx):
return in_window & dilated


def flex_ima_mask(b, h, q_idx, kv_idx, bias):
    """Flex-attention mask_mod for the "ima" mask.

    Keeps a (q_idx, kv_idx) pair when the key position is at or beyond the
    per-key threshold stored in ``bias`` (indexed by ``kv_idx``). The batch,
    head, and query indices are accepted for signature compatibility with the
    flex-attention mask_mod convention but do not affect the result.
    """
    threshold = bias[kv_idx]
    return kv_idx >= threshold


# CuTe-DSL counterpart of flex_ima_mask: keep a position when the key index
# is at or beyond the per-key threshold read from the first aux tensor.
# NOTE(review): this is traced under @cute.jit, so arguments are TensorSSA
# values, not plain ints; bias is presumably a 1-D Int32 tensor of length
# >= seqlen_k — confirm against the callers that build aux_tensors.
@cute.jit
def cute_ima_mask(
    batch: cute.TensorSSA,
    head: cute.TensorSSA,
    m_idx: cute.TensorSSA,
    n_idx: cute.TensorSSA,
    aux_tensors,
) -> cute.TensorSSA:
    # The threshold tensor travels as aux_tensors[0] (see the "ima" entry in
    # the mask registry and the test that passes [bias]).
    bias = aux_tensors[0]
    # n_idx[0] extracts the scalar key index from the SSA value; the loaded
    # threshold is lifted back to an SSA Int32 so the comparison below stays
    # inside the traced DSL.
    threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32)
    return n_idx >= threshold


def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"):
doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device)
for b in range(batch):
Expand All @@ -226,6 +243,7 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"):
"prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask),
"dilated_sliding_window": (cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask),
"document": (cute_document_mask, flex_document_mask),
"ima": (cute_ima_mask, flex_ima_mask),
}

PARAMETERIZED_MASK_FACTORIES = {
Expand Down
26 changes: 26 additions & 0 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):
return original_flex_mask(b, h, q_idx, kv_idx, doc_ids)

aux_tensors_arg = [doc_ids]
elif mask_name == "ima":
bias_threshold = (seqlen_k // 4) * 3
bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device="cuda")
original_flex_mask = mask_mod_flex

def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
return original_flex_mask(b, h, q_idx, kv_idx, bias)

aux_tensors_arg = [bias]
causal = False

if causal and seqlen_k < seqlen_q:
Expand Down Expand Up @@ -347,6 +356,23 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):
)


def test_mask_mod_ima_partial_block():
    """Exercise the "ima" mask on a non-tile-aligned length (257 vs 128 tiles),
    so the kernel hits the partial-block masking path."""
    config = dict(
        seqlen_q=257,
        seqlen_k=257,
        nheads=1,
        kv_mode="mha",
        headdim=128,
        dtype=torch.bfloat16,
        mask_name="ima",
        window_size=None,
        window_left=None,
        window_right=None,
        tile_m=128,
        tile_n=128,
    )
    _run_mask_test(**config)


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE)
@pytest.mark.parametrize("nheads", [16])
@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"])
Expand Down