1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -19,7 +19,6 @@ repos:
interface|
pack_gqa|
testing|
tile_scheduler|
utils
)\.py$
- id: ruff-format
5 changes: 4 additions & 1 deletion flash_attn/cute/pyproject.toml
@@ -7,7 +7,7 @@ name = "flash-attn-cute"
version = "0.1.0"
description = "Flash Attention CUTE (CUDA Template Engine) implementation"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.10"
license = {text = "BSD 3-Clause License"}
authors = [
{name = "Tri Dao"},
@@ -16,13 +16,16 @@ classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

dependencies = [
"nvidia-cutlass-dsl==4.3.0.dev0",
"torch",
"einops",
"typing_extensions",
]

[project.optional-dependencies]
64 changes: 49 additions & 15 deletions flash_attn/cute/tile_scheduler.py
@@ -2,7 +2,11 @@

from typing import Optional, Tuple
from dataclasses import dataclass, fields
from typing import override

try:
@drisspg (Collaborator, Author) commented on Nov 9, 2025:
This is the only change; the rest was just adding this file to ruff, which I thought I would do while I'm here.

from typing import override
except ImportError: # Python < 3.12
from typing_extensions import override
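
(For context, not part of the diff: a minimal sketch of how this compatibility import is typically consumed, which is why `typing_extensions` was added to the dependencies above. The `TileSchedulerBase` and `MyScheduler` names below are hypothetical.)

```python
try:
    from typing import override  # Python 3.12+
except ImportError:  # Python < 3.12
    from typing_extensions import override


class TileSchedulerBase:
    def get_grid_shape(self) -> tuple[int, int, int]:
        return (1, 1, 1)


class MyScheduler(TileSchedulerBase):
    # @override makes static type checkers flag this method if the base-class
    # method is ever renamed or removed; at runtime it only tags the function.
    @override
    def get_grid_shape(self) -> tuple[int, int, int]:
        return (2, 1, 1)
```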

import cutlass
from cutlass._mlir import ir
@@ -120,7 +124,11 @@ def get_grid_shape(
) -> Tuple[Int32, Int32, Int32]:
# TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
return cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head * params.num_splits, params.num_batch
return (
cute.round_up(params.num_block, params.cluster_shape_mn[0]),
params.num_head * params.num_splits,
params.num_batch,
)

def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
block_idx, head_idx, batch_idx = self._blk_coord
@@ -231,7 +239,10 @@ def __extract_mlir_values__(self):

def __new_from_mlir_values__(self, values):
obj_list = []
for obj, n_items in zip([self.params, self._tile_idx], self._values_pos,):
for obj, n_items in zip(
[self.params, self._tile_idx],
self._values_pos,
):
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
values = values[n_items:]
return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)
@@ -382,7 +393,9 @@ def create(
num_hb_remainder = (args.num_head * args.num_batch) % swizzle
num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
return SingleTileLPTBwdScheduler.Params(
total_blocks=(num_block * args.cluster_shape_mn[0]) * args.num_head * args.num_batch,
total_blocks=(num_block * args.cluster_shape_mn[0])
* args.num_head
* args.num_batch,
num_head_divmod=FastDivmod.create(args.num_head),
l2_minor_divmod=FastDivmod.create(swizzle),
l2_major_divmod=FastDivmod.create(swizzle * num_block),
@@ -437,9 +450,7 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
is_valid = self._tile_idx < params.total_blocks
bidx_in_cluster = cute.arch.block_in_cluster_idx()
block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
return WorkTileInfo(
(Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
)
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)

def initial_work_tile_info(self, *, loc=None, ip=None):
return self.get_current_work(loc=loc, ip=ip)
@@ -488,7 +499,9 @@ def create(
args: TileSchedulerArguments, *, loc=None, ip=None
) -> "SingleTileVarlenScheduler.Params":
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1])
max_kvblock_in_l2 = size_l2 // (
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
)
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
)
@@ -610,16 +623,37 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
# the seqlen can vary per batch.
# TODO: is there any case where num_m_blocks is 0?
# TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
num_n_blocks = num_m_blocks * params.tile_shape_mn[0] // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1]
num_n_blocks = (
num_m_blocks
* params.tile_shape_mn[0]
// params.qhead_per_kvhead_packgqa
// params.tile_shape_mn[1]
)
# nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
# Seems faster to have this be a power of 2
nheads_in_l2 = 16 if num_n_blocks * 16 <= params.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= params.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= params.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)))
nheads_in_l2 = (
16
if num_n_blocks * 16 <= params.max_kvblock_in_l2
else (
8
if num_n_blocks * 8 <= params.max_kvblock_in_l2
else (
4
if num_n_blocks * 4 <= params.max_kvblock_in_l2
else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
)
)
)
nheads_in_l2 = min(nheads_in_l2, params.num_head)
mh_in_l2 = nheads_in_l2 * num_m_blocks
section_idx = mh_block // mh_in_l2
l2_mod = mh_block - section_idx * mh_in_l2
# Deal with tail section
nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= params.num_head else params.num_head - section_idx * nheads_in_l2
nheads_in_this_section = (
nheads_in_l2
if nheads_in_l2 * (section_idx + 1) <= params.num_head
else params.num_head - section_idx * nheads_in_l2
)
block = l2_mod // nheads_in_this_section
head_idx_residual = l2_mod - block * nheads_in_this_section
head_idx = section_idx * nheads_in_l2 + head_idx_residual
@@ -630,9 +664,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
is_valid = self._is_first_block and batch_idx < params.num_batch
# if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
return WorkTileInfo(
(Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid
)
return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)

def initial_work_tile_info(self, *, loc=None, ip=None):
return self.get_current_work(loc=loc, ip=ip)
Expand All @@ -654,7 +686,9 @@ def __extract_mlir_values__(self):

def __new_from_mlir_values__(self, values):
obj_list = []
for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos,
for obj, n_items in zip(
[self.params, self._tile_idx, self._split_idx],
self._values_pos,
):
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
values = values[n_items:]