
Commit abaa878

drisspg and cursoragent authored
[CUTE]Bump to Cutedsl (#2216)
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 912c6c4 commit abaa878

11 files changed: +66 −123 lines


flash_attn/cute/cute_dsl_utils.py

Lines changed: 18 additions & 0 deletions
@@ -124,6 +124,24 @@ def cute_compile_patched(*args, **kwargs):
     return output


+def assume_strides_aligned(t):
+    """Assume all strides except the last are divisible by 128 bits.
+
+    Python int strides (e.g., stride=0 from GQA expand) are kept as-is
+    since they're static and don't need alignment assumptions.
+    """
+    divby = 128 // t.element_type.width
+    strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1])
+    return (*strides, t.stride[-1])
+
+
+def assume_tensor_aligned(t):
+    """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None."""
+    if t is None:
+        return None
+    return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t)))
+
+
 def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
     """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
     tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
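
A minimal sketch of how these new helpers are meant to be called, mirroring the kernel entry points updated later in this commit; the wrapper name hint_aligned below is illustrative and not part of the change:

from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned

def hint_aligned(*tensors):
    # Illustrative wrapper: rebuild each (possibly None) global tensor so the
    # CuTe DSL compiler may assume every stride except the innermost one is
    # divisible by 128 bits. Python-int strides (e.g. 0 from a GQA expand)
    # and None inputs pass through unchanged.
    return [assume_tensor_aligned(t) for t in tensors]

# Typical call site inside a kernel's __call__:
#   mQ, mK, mV, mO = hint_aligned(mQ, mK, mV, mO)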

flash_attn/cute/flash_bwd.py

Lines changed: 4 additions & 4 deletions
@@ -15,6 +15,7 @@
 import cutlass.utils as utils_basic

 from flash_attn.cute import ampere_helpers as sm80_utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import utils
 from flash_attn.cute.mask import AttentionMask
 from flash_attn.cute.seqlen_info import SeqlenInfoQK
@@ -383,10 +384,9 @@ def __call__(
         # Get the data type and check if it is fp16 or bf16
         self._check_type(*(t.element_type if t is not None else None
                            for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
-        new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1]), t.stride[-1])
-        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)]
+        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
+            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
+        ]
         self.varlen_q = (mCuSeqlensQ is not None)
         self._setup_attributes()
         SharedStorage = self._get_shared_storage_cls()

flash_attn/cute/flash_bwd_postprocess.py

Lines changed: 2 additions & 9 deletions
@@ -15,6 +15,7 @@
 from cutlass.utils import LayoutEnum

 from flash_attn.cute import utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import copy_utils
 from flash_attn.cute import ampere_helpers as sm80_utils
 from flash_attn.cute import hopper_helpers as sm90_utils
@@ -211,15 +212,7 @@ def __call__(
         if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
             raise TypeError("dQaccum tensor must be Float32")

-        # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
-            t.stride[-1],
-        )
-        mdQaccum, mdQ = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            for t in (mdQaccum, mdQ)
-        ]
+        mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)]

         self.tiled_mma = self._get_tiled_mma()
         self._setup_attributes()

flash_attn/cute/flash_bwd_preprocess.py

Lines changed: 2 additions & 11 deletions
@@ -12,6 +12,7 @@
 from cutlass import Float32

 from flash_attn.cute import utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import copy_utils
 from flash_attn.cute.seqlen_info import SeqlenInfoQK
 from flash_attn.cute.tile_scheduler import (
@@ -135,17 +136,7 @@ def __call__(
         if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
             raise TypeError("LSElog2 tensor must be Float32")

-        # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
-            t.stride[-1],
-        )
-        mO, mdO, mdQaccum = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            if t is not None
-            else None
-            for t in (mO, mdO, mdQaccum)
-        ]
+        mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)]

         self._setup_attributes()

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 2 additions & 23 deletions
@@ -15,6 +15,7 @@
 from cutlass.pipeline import PipelineAsync, PipelineConsumer

 from flash_attn.cute import utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import copy_utils
 from flash_attn.cute import pipeline
 from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx  # noqa
@@ -411,29 +412,7 @@ def __call__(
         assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
         assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"

-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip assume for Python ints (e.g., stride=0 from GQA expand)
-        new_stride = lambda t: (
-            *(
-                s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width)
-                for s in t.stride[:-1]
-            ),
-            t.stride[-1],
-        )
-        (
-            mdQaccum,
-            mdK,
-            mdV,
-        ) = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            if t is not None
-            else None
-            for t in (
-                mdQaccum,
-                mdK,
-                mdV,
-            )
-        ]
+        mdQaccum, mdK, mdV = [assume_tensor_aligned(t) for t in (mdQaccum, mdK, mdV)]

         # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
         QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 2 additions & 15 deletions
@@ -14,6 +14,7 @@
 from cutlass.utils import LayoutEnum

 from flash_attn.cute import hopper_helpers as sm90_utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import utils
 from flash_attn.cute import copy_utils
 from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx
@@ -350,22 +351,8 @@ def __call__(
             )
         )

-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
-        new_stride = lambda t: (
-            *(
-                cute.assume(s, divby=128 // t.element_type.width)
-                if not isinstance(s, int) or s != 0
-                else s
-                for s in t.stride[:-1]
-            ),
-            t.stride[-1],
-        )
         mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            if t is not None
-            else None
-            for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
+            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
         ]

         layout_transpose = [1, 3, 2, 0]  # (b, s, n, h) --> (s, h, n, b)

flash_attn/cute/flash_fwd.py

Lines changed: 3 additions & 31 deletions
@@ -24,6 +24,7 @@
 from quack import copy_utils as quack_copy_utils

 from flash_attn.cute import ampere_helpers as sm80_utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import hopper_helpers as sm90_utils
 from flash_attn.cute import utils
 from flash_attn.cute import copy_utils
@@ -660,21 +661,7 @@ def __call__(
         self.use_tma_O = self.arch >= 90
         self._setup_attributes()
         SharedStorage = self._get_shared_storage_cls()
-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
-        new_stride = lambda t: (
-            *(
-                cute.assume(s, divby=128 // t.element_type.width)
-                if not isinstance(s, int) or s != 0
-                else s
-                for s in t.stride[:-1]
-            ),
-            t.stride[-1],
-        )
-        mQ, mK, mV, mO = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            for t in (mQ, mK, mV, mO)
-        ]
+        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
         mQ, mK, mV, mO = [
             cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0]))
             for t in (mQ, mK, mV, mO)
@@ -1303,22 +1290,7 @@ def __call__(
             )
         )

-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints)
-        new_stride = lambda t: (
-            *(
-                cute.assume(s, divby=128 // t.element_type.width)
-                if not isinstance(s, int) or s != 0
-                else s
-                for s in t.stride[:-1]
-            ),
-            t.stride[-1],
-        )
-
-        mQ, mK, mV, mO = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            for t in (mQ, mK, mV, mO)
-        ]
+        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
         QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
         mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
         KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]

flash_attn/cute/flash_fwd_combine.py

Lines changed: 2 additions & 9 deletions
@@ -14,6 +14,7 @@
 from cutlass import Float32, Int32, const_expr

 from flash_attn.cute import utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute.seqlen_info import SeqlenInfo
 from cutlass.cute import FastDivmodDivisor

@@ -232,15 +233,7 @@ def __call__(
             "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
         )

-        # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (
-            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
-            t.stride[-1],
-        )
-        mO_partial, mO = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            for t in (mO_partial, mO)
-        ]
+        mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]
         # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
         # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
         O_partial_layout_transpose = (

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 2 additions & 10 deletions
@@ -29,6 +29,7 @@

 from flash_attn.cute.paged_kv import PagedKVManager
 import flash_attn.cute.utils as utils
+from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import copy_utils
 import flash_attn.cute.pipeline as pipeline
 from flash_attn.cute.mask import AttentionMask
@@ -297,16 +298,7 @@ def __call__(
         self.k_dtype = mK.element_type
         self.v_dtype = mV.element_type
         self.o_dtype = mO.element_type
-        # Assume all strides are divisible by 128 bits except the last stride
-        # Skip assume for Python ints (e.g., stride=0 from GQA expand)
-        new_stride = lambda t: (
-            *(s if isinstance(s, int) else cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
-            t.stride[-1],
-        )
-        mQ, mK, mV, mO = [
-            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
-            for t in (mQ, mK, mV, mO)
-        ]
+        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
         Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
         mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))
         # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table

flash_attn/cute/pipeline.py

Lines changed: 27 additions & 9 deletions
@@ -8,7 +8,7 @@
 import cutlass.cute as cute
 from cutlass import Boolean, Int32, const_expr
 from cutlass.cutlass_dsl import if_generate
-from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup
+from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
 from cutlass.pipeline import PipelineUserType, PipelineOp
 from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
 from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
@@ -150,19 +150,24 @@ def producer_acquire(
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         extra_tx_count: int = 0,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
         TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         if const_expr(extra_tx_count == 0):
-            self.sync_object_full.arrive(state.index, self.producer_mask)
+            self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
         else:
             tx_count = self.sync_object_full.tx_count + extra_tx_count
-            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count)
+            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)


 @dataclass(frozen=True)
@@ -207,10 +212,10 @@ def create(
         producer = (producer_type, producer_group)
         consumer = (consumer_type, consumer_group)

-        sync_object_full = PipelineAsync._make_sync_object(
+        sync_object_full = PipelineTmaUmmaOg._make_sync_object(
             barrier_storage.align(min_align=8), num_stages, producer, tx_count
         )
-        sync_object_empty = PipelineAsync._make_sync_object(
+        sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
             barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
         )

@@ -251,22 +256,35 @@ def producer_acquire(
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         extra_tx_count: int = 0,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
         TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         if const_expr(extra_tx_count == 0):
             if_generate(
                 self.is_leader_cta,
-                lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+                lambda: self.sync_object_full.arrive(
+                    state.index, self.producer_mask, loc=loc, ip=ip
+                ),
+                loc=loc,
+                ip=ip,
             )
         else:
             tx_count = self.sync_object_full.tx_count + extra_tx_count
             if_generate(
                 self.is_leader_cta,
-                lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count),
+                lambda: self.sync_object_full.arrive_and_expect_tx(
+                    state.index, tx_count, loc=loc, ip=ip
+                ),
+                loc=loc,
+                ip=ip,
             )
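
The pipeline change above threads keyword-only loc/ip arguments through producer_acquire into if_generate and the barrier wait/arrive calls, and builds the sync objects via PipelineTmaUmmaOg instead of PipelineAsync. Reading loc and ip as the MLIR-style source-location and insertion-point handles the CuTe DSL passes through its primitives is an assumption on our part; a generic sketch of the forwarding pattern, with stand-in names:

def dsl_primitive(x, *, loc=None, ip=None):
    # Stand-in for a DSL call such as a barrier wait/arrive that accepts
    # source-location (loc) and insertion-point (ip) metadata.
    return x

def wrapper(x, *, loc=None, ip=None):
    # Accept loc/ip as keyword-only arguments and forward them unchanged,
    # so anything emitted on the caller's behalf keeps the caller's location info.
    return dsl_primitive(x, loc=loc, ip=ip)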
