@@ -418,23 +418,20 @@ def initial_work_tile_info(self, *, loc=None, ip=None):
418418 @cute .jit
419419 def prefetch_next_work (self , * , loc = None , ip = None ):
420420 params = self .params
421- next_tile_idx = 0
422- if cute .arch .lane_idx () == 0 :
423- next_tile_idx = cute .arch .grid_dim ()[0 ] + utils .atomic_add_i32 (
424- 1 , params .tile_count_semaphore ,
425- )
426- next_tile_idx = cute .arch .shuffle_sync (next_tile_idx , 0 )
427- work_info = self .get_current_work (next_tile_idx )
428421 self ._scheduler_pipeline .producer_acquire (
429422 self ._producer_pipeline_state ,
430423 self ._scheduler_pipeline .producer_try_acquire (self ._producer_pipeline_state ),
431424 )
432- block , head_idx , batch_idx , _ = work_info .tile_idx
433425 if cute .arch .lane_idx () == 0 :
426+ next_tile_idx = cute .arch .grid_dim ()[0 ] + utils .atomic_add_i32 (
427+ 1 , params .tile_count_semaphore ,
428+ )
429+ work_info = self .get_current_work (next_tile_idx )
430+ block , head_idx , batch_idx , _ = work_info .tile_idx
434431 self ._work_info [0 ] = block
435432 self ._work_info [1 ] = head_idx
436433 self ._work_info [2 ] = batch_idx
437- self ._scheduler_pipeline .producer_commit (self ._producer_pipeline_state )
434+ self ._scheduler_pipeline .producer_commit (self ._producer_pipeline_state )
438435 self ._producer_pipeline_state .advance ()
439436
440437 @cute .jit
0 commit comments