Skip to content
3 changes: 3 additions & 0 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def consume_block_sparse_loads(
batch_idx,
head_idx,
m_block,
seqlen,
kv_consumer_state,
mma_pv_fn,
mma_one_n_block,
Expand Down Expand Up @@ -380,6 +381,7 @@ def consume_block_sparse_loads(
kv_consumer_state = process_first_half_block(
n_block=mask_n_block,
kv_consumer_state=kv_consumer_state,
seqlen=seqlen,
mask_fn=partial(
mask_fn,
mask_mod=mask_mod,
Expand All @@ -405,6 +407,7 @@ def consume_block_sparse_loads(
kv_consumer_state = process_first_half_block(
n_block=full_n_block,
kv_consumer_state=kv_consumer_state,
seqlen=seqlen,
mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
score_mod_fn=score_mod_fn,
is_first_block=True,
Expand Down
Loading
Loading