Skip to content

Commit e46c30b

Browse files
committed
tune dynamic scheduler prefetch
1 parent 6540d7b commit e46c30b

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,6 @@ def __call__(
584584
mCuSeqlensQ is None and
585585
mSeqUsedQ is None
586586
)
587-
print("Dynamic persistent is ", self.dynamic_persistent)
588587
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
589588
TileScheduler = SingleTileVarlenScheduler
590589
elif const_expr(not self.dynamic_persistent):
@@ -717,7 +716,6 @@ class SharedStorage:
717716
if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None):
718717
raise NotImplementedError("Block sparsity + paged KV not supported on SM100")
719718

720-
print("smem = ", self.shared_storage.size_in_bytes())
721719
# Launch the kernel synchronously
722720
self.kernel(
723721
mQ,
@@ -911,7 +909,7 @@ def kernel(
911909
assert tile_count_semaphore is not None
912910
work_info = storage.work_info.get_tensor((4, ))
913911
sched_pipeline_producer_group = pipeline.CooperativeGroup(
914-
pipeline.Agent.Thread, 32
912+
pipeline.Agent.Thread, 1
915913
)
916914
sched_pipeline_consumer_group = pipeline.CooperativeGroup(
917915
pipeline.Agent.Thread, 32 * self.num_non_empty_warps

flash_attn/cute/tile_scheduler.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,23 +418,20 @@ def initial_work_tile_info(self, *, loc=None, ip=None):
418418
@cute.jit
419419
def prefetch_next_work(self, *, loc=None, ip=None):
420420
params = self.params
421-
next_tile_idx = 0
422-
if cute.arch.lane_idx() == 0:
423-
next_tile_idx = cute.arch.grid_dim()[0] + utils.atomic_add_i32(
424-
1, params.tile_count_semaphore,
425-
)
426-
next_tile_idx = cute.arch.shuffle_sync(next_tile_idx, 0)
427-
work_info = self.get_current_work(next_tile_idx)
428421
self._scheduler_pipeline.producer_acquire(
429422
self._producer_pipeline_state,
430423
self._scheduler_pipeline.producer_try_acquire(self._producer_pipeline_state),
431424
)
432-
block, head_idx, batch_idx, _ = work_info.tile_idx
433425
if cute.arch.lane_idx() == 0:
426+
next_tile_idx = cute.arch.grid_dim()[0] + utils.atomic_add_i32(
427+
1, params.tile_count_semaphore,
428+
)
429+
work_info = self.get_current_work(next_tile_idx)
430+
block, head_idx, batch_idx, _ = work_info.tile_idx
434431
self._work_info[0] = block
435432
self._work_info[1] = head_idx
436433
self._work_info[2] = batch_idx
437-
self._scheduler_pipeline.producer_commit(self._producer_pipeline_state)
434+
self._scheduler_pipeline.producer_commit(self._producer_pipeline_state)
438435
self._producer_pipeline_state.advance()
439436

440437
@cute.jit

0 commit comments

Comments
 (0)