|
8 | 8 | import cutlass.cute as cute |
9 | 9 | from cutlass import Boolean, Int32, const_expr |
10 | 10 | from cutlass.cutlass_dsl import if_generate |
11 | | -from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup |
| 11 | +from cutlass.pipeline import PipelineState, Agent, CooperativeGroup |
12 | 12 | from cutlass.pipeline import PipelineUserType, PipelineOp |
13 | 13 | from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg |
14 | 14 | from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg |
@@ -150,19 +150,24 @@ def producer_acquire( |
150 | 150 | state: PipelineState, |
151 | 151 | try_acquire_token: Optional[Boolean] = None, |
152 | 152 | extra_tx_count: int = 0, |
| 153 | + *, |
| 154 | + loc=None, |
| 155 | + ip=None, |
153 | 156 | ): |
154 | 157 | """ |
155 | 158 | TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. |
156 | 159 | """ |
157 | 160 | if_generate( |
158 | 161 | try_acquire_token is None or try_acquire_token == 0, |
159 | | - lambda: self.sync_object_empty.wait(state.index, state.phase), |
| 162 | + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), |
| 163 | + loc=loc, |
| 164 | + ip=ip, |
160 | 165 | ) |
161 | 166 | if const_expr(extra_tx_count == 0): |
162 | | - self.sync_object_full.arrive(state.index, self.producer_mask) |
| 167 | + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) |
163 | 168 | else: |
164 | 169 | tx_count = self.sync_object_full.tx_count + extra_tx_count |
165 | | - self.sync_object_full.arrive_and_expect_tx(state.index, tx_count) |
| 170 | + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) |
166 | 171 |
|
167 | 172 |
|
168 | 173 | @dataclass(frozen=True) |
@@ -207,10 +212,10 @@ def create( |
207 | 212 | producer = (producer_type, producer_group) |
208 | 213 | consumer = (consumer_type, consumer_group) |
209 | 214 |
|
210 | | - sync_object_full = PipelineAsync._make_sync_object( |
| 215 | + sync_object_full = PipelineTmaUmmaOg._make_sync_object( |
211 | 216 | barrier_storage.align(min_align=8), num_stages, producer, tx_count |
212 | 217 | ) |
213 | | - sync_object_empty = PipelineAsync._make_sync_object( |
| 218 | + sync_object_empty = PipelineTmaUmmaOg._make_sync_object( |
214 | 219 | barrier_storage.align(min_align=8) + num_stages, num_stages, consumer |
215 | 220 | ) |
216 | 221 |
|
@@ -251,22 +256,35 @@ def producer_acquire( |
251 | 256 | state: PipelineState, |
252 | 257 | try_acquire_token: Optional[Boolean] = None, |
253 | 258 | extra_tx_count: int = 0, |
| 259 | + *, |
| 260 | + loc=None, |
| 261 | + ip=None, |
254 | 262 | ): |
255 | 263 | """ |
256 | 264 | TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. |
257 | 265 | """ |
258 | 266 | if_generate( |
259 | 267 | try_acquire_token is None or try_acquire_token == 0, |
260 | | - lambda: self.sync_object_empty.wait(state.index, state.phase), |
| 268 | + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), |
| 269 | + loc=loc, |
| 270 | + ip=ip, |
261 | 271 | ) |
262 | 272 | if const_expr(extra_tx_count == 0): |
263 | 273 | if_generate( |
264 | 274 | self.is_leader_cta, |
265 | | - lambda: self.sync_object_full.arrive(state.index, self.producer_mask), |
| 275 | + lambda: self.sync_object_full.arrive( |
| 276 | + state.index, self.producer_mask, loc=loc, ip=ip |
| 277 | + ), |
| 278 | + loc=loc, |
| 279 | + ip=ip, |
266 | 280 | ) |
267 | 281 | else: |
268 | 282 | tx_count = self.sync_object_full.tx_count + extra_tx_count |
269 | 283 | if_generate( |
270 | 284 | self.is_leader_cta, |
271 | | - lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count), |
| 285 | + lambda: self.sync_object_full.arrive_and_expect_tx( |
| 286 | + state.index, tx_count, loc=loc, ip=ip |
| 287 | + ), |
| 288 | + loc=loc, |
| 289 | + ip=ip, |
272 | 290 | ) |
0 commit comments