Commit 91ba87d

ruff all the smaller files (#2040)
1 parent 672381f commit 91ba87d

7 files changed: +193 -90 lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 9 deletions

@@ -7,19 +7,10 @@ repos:
         files: ^flash_attn/cute/.*\.py$
         exclude: &cute_exclude |
           (?x)^flash_attn/cute/(
-              __init__|
-              copy_utils|
-              cute_dsl_utils|
-              fast_math|
               flash_bwd|
               flash_fwd|
-              flash_fwd_combine|
               flash_fwd_sm100|
-              hopper_helpers|
               interface|
-              pack_gqa|
-              testing|
-              utils
           )\.py$
       - id: ruff-format
         files: ^flash_attn/cute/.*\.py$
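
With those entries dropped from the exclude list, the ruff hooks now cover __init__, copy_utils, cute_dsl_utils, fast_math, flash_fwd_combine, hopper_helpers, pack_gqa, testing, and utils; only four modules remain exempt. A minimal sketch of how the verbose exclude regex behaves after this commit, using Python's re directly (the compiled-pattern name is ours, not part of the config):

import re

# Sketch only: the exclude pattern as it stands after this commit, reduced to
# the four modules the ruff hooks still skip.
cute_exclude = re.compile(
    r"""(?x)^flash_attn/cute/(
        flash_bwd|
        flash_fwd|
        flash_fwd_sm100|
        interface
    )\.py$"""
)

assert cute_exclude.search("flash_attn/cute/flash_bwd.py")        # still excluded
assert not cute_exclude.search("flash_attn/cute/copy_utils.py")   # now checked by ruff
assert not cute_exclude.search("flash_attn/cute/pack_gqa.py")     # now checked by ruff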

flash_attn/cute/copy_utils.py

Lines changed: 3 additions & 3 deletions

@@ -1,11 +1,11 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
 import math
-from typing import Optional, Type, Tuple, Callable
+from typing import Optional, Type, Callable
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32, Int32, Boolean, const_expr
+from cutlass import Float32, Int32, const_expr
 from cutlass.cute.nvgpu import cpasync
 import cutlass.utils.blackwell_helpers as sm100_utils
 from cutlass.cutlass_dsl import T, dsl_user_op
@@ -279,7 +279,7 @@ def copy_bulk(src_idx, dst_idx, **new_kwargs):
             dst[None, dst_idx].iterator,
             size=size,
             **new_kwargs,
-            **kwargs
+            **kwargs,
         )
 
     def copy_bulk_single_stage(**new_kwargs):
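
The `**kwargs,` change above is ruff-format's trailing-comma convention: like Black, it appends a trailing comma to any argument list it splits across lines, and that "magic" trailing comma then keeps the list exploded one item per line on later runs. A small, self-contained illustration (the function is made up, not code from this repo):

# Hypothetical example of the convention only.
def scaled_sum(
    a,
    b,
    *,
    scale=1.0,  # the trailing comma keeps this signature one parameter per line
):
    return (a + b) * scale


assert scaled_sum(1, 2, scale=0.5) == 1.5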

flash_attn/cute/flash_fwd_combine.py

Lines changed: 107 additions & 47 deletions
@@ -55,8 +55,13 @@ def __init__(
 
     @staticmethod
     def can_implement(
-        dtype, dtype_partial, head_dim, m_block_size, k_block_size,
-        log_max_splits, num_threads,
+        dtype,
+        dtype_partial,
+        head_dim,
+        m_block_size,
+        k_block_size,
+        log_max_splits,
+        num_threads,
     ) -> bool:
         """Check if the kernel can be implemented with the given parameters."""
         if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
@@ -83,8 +88,7 @@ def _setup_attributes(self):
         assert self.k_block_size % async_copy_elems == 0
 
         k_block_gmem = (
-            128 if self.k_block_size % 128 == 0 else
-            (64 if self.k_block_size % 64 == 0 else 32)
+            128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
         )
         gmem_threads_per_row = k_block_gmem // async_copy_elems
         assert self.num_threads % gmem_threads_per_row == 0
@@ -111,16 +115,25 @@ def _setup_attributes(self):
             num_bits_per_copy=async_copy_elems * self.dtype.width,
         )
         self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
-            atom_universal_copy, tOpartial_layout, vOpartial_layout  # 4 vals per store
+            atom_universal_copy,
+            tOpartial_layout,
+            vOpartial_layout,  # 4 vals per store
         )
 
         # LSE copy setup with async copy (alignment = 1)
         lse_copy_bits = Float32.width  # 1 element per copy, width is in bits
         m_block_smem = (
-            128 if self.m_block_size % 128 == 0 else
-            (64 if self.m_block_size % 64 == 0 else
-            (32 if self.m_block_size % 32 == 0 else
-            (16 if self.m_block_size % 16 == 0 else 8)))
+            128
+            if self.m_block_size % 128 == 0
+            else (
+                64
+                if self.m_block_size % 64 == 0
+                else (
+                    32
+                    if self.m_block_size % 32 == 0
+                    else (16 if self.m_block_size % 16 == 0 else 8)
+                )
+            )
         )
         gmem_threads_per_row_lse = m_block_smem
         assert self.num_threads % gmem_threads_per_row_lse == 0
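
Both the k_block_gmem and m_block_smem expressions that ruff reformatted above implement the same idea: pick the largest copy width from a fixed candidate list that evenly divides the block size, falling back to the smallest candidate. A hedged, plain-Python restatement (the helper name is invented, not part of the kernel):

def largest_divisor(block_size: int, candidates=(128, 64, 32, 16, 8)) -> int:
    """Return the first candidate dividing block_size, else the smallest one."""
    for c in candidates:
        if block_size % c == 0:
            return c
    return candidates[-1]


assert largest_divisor(256) == 128                 # m_block_smem-style choice
assert largest_divisor(192, (128, 64, 32)) == 64   # k_block_gmem-style choice
assert largest_divisor(40) == 8                    # falls through to the minimum
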
@@ -167,21 +180,17 @@ def _setup_attributes(self):
         else:
             smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
         smem_layout_atom_lse = cute.make_composed_layout(
-            smem_lse_swizzle,
-            0,
-            cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
+            smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
         )
         self.smem_layout_lse = cute.tile_to_shape(
             smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
         )
 
         # O partial shared memory layout (simple layout for pipeline stages)
         self.smem_layout_o = cute.make_ordered_layout(
-            (self.m_block_size, self.k_block_size, self.stages),
-            order=(1, 0, 2)
+            (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
         )
 
-
     @cute.jit
     def __call__(
         self,
@@ -200,38 +209,63 @@ def __call__(
             raise TypeError("O partial tensor must match dtype_partial")
         if const_expr(not (mO.element_type == self.dtype)):
             raise TypeError("O tensor must match dtype")
-        if const_expr(not mLSE_partial.element_type in [Float32]):
+        if const_expr(mLSE_partial.element_type not in [Float32]):
             raise TypeError("LSE partial tensor must be Float32")
-        if const_expr(mLSE is not None and not mLSE.element_type in [Float32]):
+        if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
             raise TypeError("LSE tensor must be Float32")
 
         # Shape validation - input tensors are in user format, need to be converted to kernel format
         if const_expr(len(mO_partial.shape) not in [4, 5]):
-            raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)")
+            raise ValueError(
+                "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
+            )
         if const_expr(len(mLSE_partial.shape) not in [3, 4]):
-            raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)")
+            raise ValueError(
+                "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
+            )
         if const_expr(len(mO.shape) not in [3, 4]):
-            raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)")
+            raise ValueError(
+                "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
+            )
         if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
-            raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)")
+            raise ValueError(
+                "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
+            )
 
         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
-        mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)]
+        new_stride = lambda t: (
+            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            t.stride[-1],
+        )
+        mO_partial, mO = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            for t in (mO_partial, mO)
+        ]
         # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
         # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
-        O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+        O_partial_layout_transpose = (
+            [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+        )
         # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
-        mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose))
+        mO_partial = cute.make_tensor(
+            mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
+        )
         O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
         mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
         # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
         # or (num_splits, total_q, h) -> (total_q, num_splits, h)
         LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
-        mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose))
+        mLSE_partial = cute.make_tensor(
+            mLSE_partial.iterator,
+            cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
+        )
         # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
         LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
-        mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None
+        mLSE = (
+            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
+            if mLSE is not None
+            else None
+        )
 
         # Determine if we have variable length sequences
         varlen = const_expr(cu_seqlens is not None or seqused is not None)
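
The `*_layout_transpose` index lists that got wrapped above are mode permutations: `cute.select(layout, mode=[...])` picks the layout modes in the listed order, which is exactly what the inline comments in the diff describe. A tiny hypothetical sketch of that reordering on dimension names only (not CuTe code):

# Sketch only: apply a mode-order list to a tuple of dimension names.
def permute(dims, order):
    return tuple(dims[i] for i in order)


dims = ("num_splits", "batch", "seqlen", "nheads", "headdim")
# [2, 4, 0, 3, 1]: (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
assert permute(dims, [2, 4, 0, 3, 1]) == (
    "seqlen", "headdim", "num_splits", "nheads", "batch"
)
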
@@ -243,9 +277,7 @@ class SharedStorage:
             sLSE: cute.struct.Align[
                 cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
             ]
-            sMaxValidSplit: cute.struct.Align[
-                cute.struct.MemRange[Int32, self.m_block_size], 128
-            ]
+            sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
             sO: cute.struct.Align[
                 cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
             ]
@@ -255,7 +287,11 @@ class SharedStorage:
         # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
         seqlen = mO_partial.shape[0]
         num_head = mO_partial.shape[3]
-        batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
+        batch_size = (
+            mO_partial.shape[4]
+            if const_expr(cu_seqlens is None)
+            else Int32(cu_seqlens.shape[0] - 1)
+        )
 
         # Create FastDivmodDivisor objects for efficient division
         seqlen_divmod = FastDivmodDivisor(seqlen)
@@ -330,22 +366,26 @@ def kernel(
 
         # Handle semaphore reset
         if const_expr(semaphore_to_reset is not None):
-            if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and
-                k_block == cute.arch.grid_dim()[1] - 1 and
-                batch_idx == cute.arch.grid_dim()[2] - 1):
+            if (
+                tidx == 0
+                and m_block == cute.arch.grid_dim()[0] - 1
+                and k_block == cute.arch.grid_dim()[1] - 1
+                and batch_idx == cute.arch.grid_dim()[2] - 1
+            ):
                 semaphore_to_reset[0] = 0
 
         # Get number of splits
         num_splits = (
-            num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None)
+            num_splits_dynamic_ptr[batch_idx]
+            if const_expr(num_splits_dynamic_ptr is not None)
             else mLSE_partial.shape[1]
         )
         # Handle variable length sequences using SeqlenInfo
         seqlen_info = SeqlenInfo.create(
             batch_idx=batch_idx,
             seqlen_static=mO_partial.shape[0],
             cu_seqlens=cu_seqlens,
-            seqused=seqused
+            seqused=seqused,
         )
         seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
 
@@ -354,8 +394,9 @@ def kernel(
         max_idx = seqlen * num_head
 
         # Early exit for single split if dynamic
-        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx):
-
+        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
+            const_expr(not varlen) or m_block * self.m_block_size < max_idx
+        ):
             # ===============================
             # Step 1: Load LSE_partial from gmem to shared memory
             # ===============================
@@ -390,7 +431,11 @@ def kernel(
                 for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                     si = tLSEcLSE[0, s, 0][0]  # Get split coordinate
                     if si < num_splits:
-                        cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m])
+                        cute.copy(
+                            gmem_thr_copy_LSE,
+                            mLSE_partial_cur_copy[None, si],
+                            tLSEsLSE[None, s, m],
+                        )
                     else:
                         tLSEsLSE[None, s, m].fill(-Float32.inf)
             # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
@@ -424,7 +469,9 @@ def kernel(
                 else:
                     tOhidx[m] = idx // seqlen
                     tOmidx[m] = idx - tOhidx[m] * seqlen
-                    tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint()
+                    tOrOptr[m] = utils.elem_pointer_i64(
+                        mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
+                    ).toint()
                 if idx >= max_idx:
                     tOhidx[m] = -1
 
@@ -483,7 +530,9 @@ def kernel(
                 # Find max LSE value across splits
                 threads_per_col = const_expr(self.smem_threads_per_col_lse)
                 lse_max = utils.warp_reduce(
-                    ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+                    ts2rrLSE[None, None, m]
+                    .load()
+                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
                     op=cute.arch.fmax,
                     width=threads_per_col,
                 )
@@ -496,7 +545,9 @@ def kernel(
                 # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
                 max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
                 # Compute exp scales and sum
-                lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max  # In case all local LSEs are -inf
+                lse_max_cur = (
+                    0.0 if lse_max == -Float32.inf else lse_max
+                )  # In case all local LSEs are -inf
                 LOG2_E = math.log2(math.e)
                 lse_sum_cur = 0.0
                 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
@@ -506,7 +557,9 @@ def kernel(
                 lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
                 lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
                 # Normalize scales
-                inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+                inv_sum = (
+                    0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+                )
                 ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
             # Store the scales exp(lse - lse_logsum) back to smem
             cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
@@ -584,7 +637,10 @@ def kernel(
             # Accumulate scaled partial results
             for m in cutlass.range(num_rows, unroll_full=True):
                 if tOhidx[m] >= 0 and scale[m] > 0.0:
-                    tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32))
+                    tOrO[None, m, None].store(
+                        tOrO[None, m, None].load()
+                        + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
+                    )
 
             # ===============================
             # Step 7: Write final O to gmem
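
The reformatted hunks above are the core of the split-KV combine: per output row, reduce the per-split LSEs to a max, renormalize to get per-split scales, and accumulate the scale-weighted partial outputs. A hypothetical, scalar Python restatement of that math (the kernel uses exp2 with LOG2_E and warp reductions; plain math.exp and lists here for clarity, and the helper name is ours):

import math


def combine_splits(lse_per_split, o_per_split):
    """Combine per-split (LSE, partial output) pairs for one output row."""
    lse_max = max(lse_per_split)
    lse_max_cur = 0.0 if lse_max == -math.inf else lse_max  # all splits empty
    exps = [math.exp(lse - lse_max_cur) for lse in lse_per_split]
    lse_sum = sum(exps)
    inv_sum = 0.0 if (lse_sum == 0.0 or lse_sum != lse_sum) else 1.0 / lse_sum
    scales = [e * inv_sum for e in exps]  # exp(lse_i - lse_total)
    lse_total = math.log(lse_sum) + lse_max if lse_sum > 0.0 else -math.inf
    o_final = [
        sum(s * o[d] for s, o in zip(scales, o_per_split))
        for d in range(len(o_per_split[0]))
    ]
    return lse_total, o_final


lse, o = combine_splits([1.0, 2.0, -math.inf], [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]])
assert abs(lse - math.log(math.exp(1.0) + math.exp(2.0))) < 1e-12
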
@@ -605,7 +661,9 @@ def kernel(
             # Write final results
             for m in cutlass.range(num_rows, unroll_full=True):
                 if tOhidx[m] >= 0:
-                    mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,))
+                    mO_cur_copy = cute.tiled_divide(
+                        mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
+                    )
                     for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                         k_idx = tOcO[0, 0, k][1] // elems_per_store
                         if const_expr(self.is_even_k) or tOpO[k]:
@@ -631,7 +689,9 @@ def load_O_partial(
        o_gmem_ptr = cute.make_ptr(
            tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
        )
-        mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)))
+        mO_partial_cur = cute.make_tensor(
+            o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
+        )
        mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
        for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
            k_idx = tOcO[0, 0, k][1] // elems_per_load
@@ -640,5 +700,5 @@ def load_O_partial(
                    gmem_tiled_copy_O_partial,
                    # mO_partial_cur_copy[None, k_idx, split],
                    utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
-                    tOsO_partial_cur[None, m, k]
+                    tOsO_partial_cur[None, m, k],
                )

flash_attn/cute/hopper_helpers.py

Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 import cutlass.cute as cute
 from cutlass import Int32, Float32, Boolean, const_expr
 from cutlass.cute.nvgpu import warpgroup
-from cutlass._mlir.dialects import llvm
 from cutlass.cutlass_dsl import Numeric, dsl_user_op
 from cutlass.utils import LayoutEnum
 import cutlass.utils.hopper_helpers as sm90_utils_og

flash_attn/cute/pack_gqa.py

Lines changed: 0 additions & 2 deletions

@@ -1,7 +1,5 @@
 # Copyright (c) 2025, Tri Dao.
 
-import math
-import operator
 
 import cutlass
 import cutlass.cute as cute
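
The hopper_helpers.py and pack_gqa.py hunks are pure unused-import removals, the kind of issue ruff reports via its pyflakes-derived F401 rule. A hedged sketch of reproducing this class of cleanup locally (assumes a ruff binary on PATH and a flash-attention checkout as the working directory):

import subprocess

# Sketch only: drop unused imports (F401) and re-run the formatter on one of
# the files touched by this commit. Exit codes are intentionally not checked.
target = "flash_attn/cute/pack_gqa.py"
subprocess.run(["ruff", "check", "--select", "F401", "--fix", target])
subprocess.run(["ruff", "format", target])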
