Skip to content

Commit 1cc40ae

Browse files
committed
[CUTE]Bump to Cutedsl
stack-info: PR: #2216, branch: drisspg/stack/7
1 parent 97b7ed6 commit 1cc40ae

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

flash_attn/cute/pipeline.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import cutlass.cute as cute
99
from cutlass import Boolean, Int32, const_expr
1010
from cutlass.cutlass_dsl import if_generate
11-
from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup
11+
from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
1212
from cutlass.pipeline import PipelineUserType, PipelineOp
1313
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
1414
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
@@ -150,19 +150,24 @@ def producer_acquire(
150150
state: PipelineState,
151151
try_acquire_token: Optional[Boolean] = None,
152152
extra_tx_count: int = 0,
153+
*,
154+
loc=None,
155+
ip=None,
153156
):
154157
"""
155158
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
156159
"""
157160
if_generate(
158161
try_acquire_token is None or try_acquire_token == 0,
159-
lambda: self.sync_object_empty.wait(state.index, state.phase),
162+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
163+
loc=loc,
164+
ip=ip,
160165
)
161166
if const_expr(extra_tx_count == 0):
162-
self.sync_object_full.arrive(state.index, self.producer_mask)
167+
self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
163168
else:
164169
tx_count = self.sync_object_full.tx_count + extra_tx_count
165-
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count)
170+
self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
166171

167172

168173
@dataclass(frozen=True)
@@ -207,10 +212,10 @@ def create(
207212
producer = (producer_type, producer_group)
208213
consumer = (consumer_type, consumer_group)
209214

210-
sync_object_full = PipelineAsync._make_sync_object(
215+
sync_object_full = PipelineTmaUmmaOg._make_sync_object(
211216
barrier_storage.align(min_align=8), num_stages, producer, tx_count
212217
)
213-
sync_object_empty = PipelineAsync._make_sync_object(
218+
sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
214219
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
215220
)
216221

@@ -251,22 +256,35 @@ def producer_acquire(
251256
state: PipelineState,
252257
try_acquire_token: Optional[Boolean] = None,
253258
extra_tx_count: int = 0,
259+
*,
260+
loc=None,
261+
ip=None,
254262
):
255263
"""
256264
TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
257265
"""
258266
if_generate(
259267
try_acquire_token is None or try_acquire_token == 0,
260-
lambda: self.sync_object_empty.wait(state.index, state.phase),
268+
lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
269+
loc=loc,
270+
ip=ip,
261271
)
262272
if const_expr(extra_tx_count == 0):
263273
if_generate(
264274
self.is_leader_cta,
265-
lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
275+
lambda: self.sync_object_full.arrive(
276+
state.index, self.producer_mask, loc=loc, ip=ip
277+
),
278+
loc=loc,
279+
ip=ip,
266280
)
267281
else:
268282
tx_count = self.sync_object_full.tx_count + extra_tx_count
269283
if_generate(
270284
self.is_leader_cta,
271-
lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count),
285+
lambda: self.sync_object_full.arrive_and_expect_tx(
286+
state.index, tx_count, loc=loc, ip=ip
287+
),
288+
loc=loc,
289+
ip=ip,
272290
)

flash_attn/cute/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ classifiers = [
2222
]
2323

2424
dependencies = [
25-
"nvidia-cutlass-dsl>=4.3.5,<4.4.0",
25+
"nvidia-cutlass-dsl>=4.4.0.dev0",
2626
"torch",
2727
"einops",
2828
"typing_extensions",
2929
"apache-tvm-ffi>=0.1.5,<0.2",
3030
"torch-c-dlpack-ext",
31-
"quack-kernels==0.2.4",
31+
"quack-kernels>=0.2.4",
3232
]
3333

3434
[project.optional-dependencies]

0 commit comments

Comments
 (0)