|
4 | 4 | from typing import Optional |
5 | 5 | from dataclasses import dataclass |
6 | 6 |
|
7 | | -import cutlass |
8 | | -import cutlass.cute as cute |
9 | 7 | from cutlass import Boolean, Int32, const_expr |
10 | 8 | from cutlass.cutlass_dsl import if_generate |
11 | | -from cutlass.pipeline import PipelineState, Agent, CooperativeGroup |
12 | | -from cutlass.pipeline import PipelineUserType, PipelineOp |
| 9 | +from cutlass.pipeline import PipelineState |
| 10 | +from cutlass.pipeline import PipelineUserType |
13 | 11 | from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg |
14 | 12 | from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg |
15 | 13 |
|
16 | 14 |
|
17 | | -# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed |
18 | | -def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): |
19 | | - """ |
20 | | - Fences the mbarrier init and syncs the threadblock or cluster |
21 | | - """ |
22 | | - cute.arch.mbarrier_init_fence() |
23 | | - |
24 | | - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: |
25 | | - # If not using clusters, sync the threadblock |
26 | | - _sync(Agent.ThreadBlock) |
27 | | - else: |
28 | | - # If using clusters, sync the cluster |
29 | | - _sync(Agent.ThreadBlockCluster) |
30 | | - |
31 | | - |
32 | | -def _sync(group: Agent): |
33 | | - """ |
34 | | - Syncs all threads within an agent. |
35 | | - """ |
36 | | - if group is Agent.Thread: |
37 | | - raise NotImplementedError("Error: Not supported.") |
38 | | - elif group is Agent.ThreadBlock: |
39 | | - cute.arch.sync_threads() |
40 | | - elif group is Agent.ThreadBlockCluster: |
41 | | - cute.arch.cluster_arrive_relaxed() |
42 | | - cute.arch.cluster_wait() |
43 | | - else: |
44 | | - assert False, ( |
45 | | - "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." |
46 | | - ) |
47 | | - |
48 | | - |
49 | 15 | class PipelineStateSimple: |
50 | 16 | """ |
51 | 17 | Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. |
@@ -172,84 +138,17 @@ def producer_acquire( |
172 | 138 |
|
173 | 139 | @dataclass(frozen=True) |
174 | 140 | class PipelineTmaUmma(PipelineTmaUmmaOg): |
175 | | - @staticmethod |
176 | | - def create( |
177 | | - *, |
178 | | - num_stages: int, |
179 | | - producer_group: CooperativeGroup, |
180 | | - consumer_group: CooperativeGroup, |
181 | | - tx_count: int, |
182 | | - barrier_storage: cute.Pointer = None, |
183 | | - cta_layout_vmnk: Optional[cute.Layout] = None, |
184 | | - mcast_mode_mn: tuple[int, int] = (1, 1), |
185 | | - init_wait: cutlass.Constexpr[bool] = True, |
186 | | - ): |
187 | | - """ |
188 | | - This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. |
189 | | - :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers |
190 | | - :type barrier_storage: cute.Pointer |
191 | | - :param num_stages: Number of buffer stages for this pipeline |
192 | | - :type num_stages: Int32 |
193 | | - :param producer_group: `CooperativeGroup` for the producer agent |
194 | | - :type producer_group: CooperativeGroup |
195 | | - :param consumer_group: `CooperativeGroup` for the consumer agent |
196 | | - :type consumer_group: CooperativeGroup |
197 | | - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage |
198 | | - :type tx_count: int |
199 | | - :param cta_layout_vmnk: Layout of the cluster shape |
200 | | - :type cta_layout_vmnk: cute.Layout | None |
201 | | - :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. |
202 | | - :type mcast_mode_mn: tuple[int, int] |
203 | | - """ |
204 | | - if not isinstance(barrier_storage, cute.Pointer): |
205 | | - raise ValueError( |
206 | | - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" |
207 | | - ) |
208 | | - |
209 | | - producer_type = PipelineOp.TmaLoad |
210 | | - consumer_type = PipelineOp.TCGen05Mma |
211 | | - |
212 | | - producer = (producer_type, producer_group) |
213 | | - consumer = (consumer_type, consumer_group) |
214 | | - |
215 | | - sync_object_full = PipelineTmaUmmaOg._make_sync_object( |
216 | | - barrier_storage.align(min_align=8), num_stages, producer, tx_count |
217 | | - ) |
218 | | - sync_object_empty = PipelineTmaUmmaOg._make_sync_object( |
219 | | - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer |
220 | | - ) |
221 | | - |
222 | | - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: |
223 | | - # No mcast mask if not using clusters |
224 | | - producer_mask = None |
225 | | - # All threadblocks are leaders if not using clusters |
226 | | - is_leader_cta = True |
227 | | - else: |
228 | | - producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask( |
229 | | - cta_layout_vmnk, mcast_mode_mn |
230 | | - ) |
231 | | - is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) |
232 | | - |
233 | | - cta_group = ( |
234 | | - cute.nvgpu.tcgen05.CtaGroup.ONE |
235 | | - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 |
236 | | - else cute.nvgpu.tcgen05.CtaGroup.TWO |
237 | | - ) |
238 | | - |
239 | | - consumer_mask = producer_mask |
240 | | - |
241 | | - if const_expr(init_wait): |
242 | | - pipeline_init_wait(cta_layout_vmnk) |
| 141 | + """ |
| 142 | + Override producer_acquire to accept an additional extra_tx_count parameter. |
| 143 | + """ |
243 | 144 |
|
244 | | - return PipelineTmaUmma( |
245 | | - sync_object_full, |
246 | | - sync_object_empty, |
247 | | - num_stages, |
248 | | - producer_mask, |
249 | | - consumer_mask, |
250 | | - is_leader_cta, |
251 | | - cta_group, |
252 | | - ) |
| 145 | + @staticmethod |
| 146 | + def create(*args, **kwargs): |
| 147 | + obj = PipelineTmaUmmaOg.create(*args, **kwargs) |
| 148 | + # Re-tag the object built by the base-class factory as this subclass so the overrides defined here apply. |
| 149 | + # A plain `obj.__class__ = PipelineTmaUmma` is rejected because the dataclass is frozen, so bypass its __setattr__ via object.__setattr__. |
| 150 | + object.__setattr__(obj, "__class__", PipelineTmaUmma) |
| 151 | + return obj |
253 | 152 |
|
254 | 153 | def producer_acquire( |
255 | 154 | self, |
|
0 commit comments