Skip to content

Commit deb1830

Browse files
committed
[Bwd,Sm100] Shorten PipelineTmaUmma create
1 parent c912a37 commit deb1830

File tree

3 files changed

+17
-118
lines changed

3 files changed

+17
-118
lines changed

flash_attn/cute/flash_bwd_sm100.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ def kernel(
957957
consumer_group=pipeline_consumer_group_compute,
958958
tx_count=self.tma_copy_bytes["LSE"],
959959
# cta_layout_vmnk=cluster_layout_vmnk,
960-
# init_wait=False,
960+
defer_sync=True,
961961
)
962962
pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
963963
barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
@@ -966,7 +966,7 @@ def kernel(
966966
consumer_group=pipeline_consumer_group_compute,
967967
tx_count=self.tma_copy_bytes["dPsum"],
968968
# cta_layout_vmnk=cluster_layout_vmnk,
969-
# init_wait=False,
969+
defer_sync=True,
970970
)
971971
pipeline_Q = pipeline.PipelineTmaUmma.create(
972972
barrier_storage=storage.Q_mbar_ptr.data_ptr(),
@@ -975,7 +975,7 @@ def kernel(
975975
consumer_group=pipeline_consumer_group,
976976
tx_count=self.tma_copy_bytes["Q"],
977977
cta_layout_vmnk=cluster_layout_vmnk,
978-
init_wait=False,
978+
defer_sync=True,
979979
)
980980
pipeline_dO = pipeline.PipelineTmaUmma.create(
981981
barrier_storage=storage.dO_mbar_ptr.data_ptr(),
@@ -984,7 +984,7 @@ def kernel(
984984
consumer_group=pipeline_consumer_group,
985985
tx_count=self.tma_copy_bytes["dO"],
986986
cta_layout_vmnk=cluster_layout_vmnk,
987-
init_wait=True,
987+
defer_sync=False,
988988
)
989989

990990
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ def mma(
10681068
)
10691069

10701070
# Smem copy atom tiling
1071-
smem_copy_atom_PdS = utils.get_smem_store_atom(
1071+
smem_copy_atom_PdS = copy_utils.get_smem_store_atom(
10721072
self.arch, self.dtype, transpose=self.SdP_swapAB
10731073
)
10741074
smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice(

flash_attn/cute/pipeline.py

Lines changed: 12 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,14 @@
44
from typing import Optional
55
from dataclasses import dataclass
66

7-
import cutlass
8-
import cutlass.cute as cute
97
from cutlass import Boolean, Int32, const_expr
108
from cutlass.cutlass_dsl import if_generate
11-
from cutlass.pipeline import PipelineState, Agent, CooperativeGroup
12-
from cutlass.pipeline import PipelineUserType, PipelineOp
9+
from cutlass.pipeline import PipelineState
10+
from cutlass.pipeline import PipelineUserType
1311
from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
1412
from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
1513

1614

17-
# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed
18-
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
19-
"""
20-
Fences the mbarrier init and syncs the threadblock or cluster
21-
"""
22-
cute.arch.mbarrier_init_fence()
23-
24-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
25-
# If not using clusters, sync the threadblock
26-
_sync(Agent.ThreadBlock)
27-
else:
28-
# If using clusters, sync the cluster
29-
_sync(Agent.ThreadBlockCluster)
30-
31-
32-
def _sync(group: Agent):
33-
"""
34-
Syncs all threads within an agent.
35-
"""
36-
if group is Agent.Thread:
37-
raise NotImplementedError("Error: Not supported.")
38-
elif group is Agent.ThreadBlock:
39-
cute.arch.sync_threads()
40-
elif group is Agent.ThreadBlockCluster:
41-
cute.arch.cluster_arrive_relaxed()
42-
cute.arch.cluster_wait()
43-
else:
44-
assert False, (
45-
"Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
46-
)
47-
48-
4915
class PipelineStateSimple:
5016
"""
5117
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
@@ -172,84 +138,17 @@ def producer_acquire(
172138

173139
@dataclass(frozen=True)
174140
class PipelineTmaUmma(PipelineTmaUmmaOg):
175-
@staticmethod
176-
def create(
177-
*,
178-
num_stages: int,
179-
producer_group: CooperativeGroup,
180-
consumer_group: CooperativeGroup,
181-
tx_count: int,
182-
barrier_storage: cute.Pointer = None,
183-
cta_layout_vmnk: Optional[cute.Layout] = None,
184-
mcast_mode_mn: tuple[int, int] = (1, 1),
185-
init_wait: cutlass.Constexpr[bool] = True,
186-
):
187-
"""
188-
This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
189-
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
190-
:type barrier_storage: cute.Pointer
191-
:param num_stages: Number of buffer stages for this pipeline
192-
:type num_stages: Int32
193-
:param producer_group: `CooperativeGroup` for the producer agent
194-
:type producer_group: CooperativeGroup
195-
:param consumer_group: `CooperativeGroup` for the consumer agent
196-
:type consumer_group: CooperativeGroup
197-
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
198-
:type tx_count: int
199-
:param cta_layout_vmnk: Layout of the cluster shape
200-
:type cta_layout_vmnk: cute.Layout | None
201-
:param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1.
202-
:type mcast_mode_mn: tuple[int, int]
203-
"""
204-
if not isinstance(barrier_storage, cute.Pointer):
205-
raise ValueError(
206-
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
207-
)
208-
209-
producer_type = PipelineOp.TmaLoad
210-
consumer_type = PipelineOp.TCGen05Mma
211-
212-
producer = (producer_type, producer_group)
213-
consumer = (consumer_type, consumer_group)
214-
215-
sync_object_full = PipelineTmaUmmaOg._make_sync_object(
216-
barrier_storage.align(min_align=8), num_stages, producer, tx_count
217-
)
218-
sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
219-
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
220-
)
221-
222-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
223-
# No mcast mask if not using clusters
224-
producer_mask = None
225-
# All threadblocks are leaders if not using clusters
226-
is_leader_cta = True
227-
else:
228-
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
229-
cta_layout_vmnk, mcast_mode_mn
230-
)
231-
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
232-
233-
cta_group = (
234-
cute.nvgpu.tcgen05.CtaGroup.ONE
235-
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
236-
else cute.nvgpu.tcgen05.CtaGroup.TWO
237-
)
238-
239-
consumer_mask = producer_mask
240-
241-
if const_expr(init_wait):
242-
pipeline_init_wait(cta_layout_vmnk)
141+
"""
142+
Overrides producer_acquire to accept an additional extra_tx_count parameter.
143+
"""
243144

244-
return PipelineTmaUmma(
245-
sync_object_full,
246-
sync_object_empty,
247-
num_stages,
248-
producer_mask,
249-
consumer_mask,
250-
is_leader_cta,
251-
cta_group,
252-
)
145+
@staticmethod
146+
def create(*args, **kwargs):
147+
obj = PipelineTmaUmmaOg.create(*args, **kwargs)
148+
# Can't assign to __class__ directly (the commented-out line below) since the dataclass is frozen; object.__setattr__ bypasses the frozen check
149+
# obj.__class__ = PipelineTmaUmma
150+
object.__setattr__(obj, "__class__", PipelineTmaUmma)
151+
return obj
253152

254153
def producer_acquire(
255154
self,

0 commit comments

Comments
 (0)