Skip to content

Commit c912a37

Browse files
committed
[Bwd,Sm90] Use quack.copy_utils
1 parent b9148ce commit c912a37

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

flash_attn/cute/block_sparse_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,8 +1122,7 @@ def _load_q_do_block_sm90(
11221122
else:
11231123
pipeline_Q.producer_acquire(producer_state_Q)
11241124
load_Q(m_block, producer_state=producer_state_Q)
1125-
with cute.arch.elect_one():
1126-
load_LSE(m_block, producer_state=producer_state_Q)
1125+
load_LSE(m_block, producer_state=producer_state_Q)
11271126

11281127
producer_state_dO_cur = (
11291128
producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
@@ -1134,8 +1133,7 @@ def _load_q_do_block_sm90(
11341133
else:
11351134
pipeline_dO.producer_acquire(producer_state_dO_cur)
11361135
load_dO(m_block, producer_state=producer_state_dO_cur)
1137-
with cute.arch.elect_one():
1138-
load_dPsum(m_block, producer_state=producer_state_dO_cur)
1136+
load_dPsum(m_block, producer_state=producer_state_dO_cur)
11391137

11401138
producer_state_Q.advance()
11411139
producer_state_dO.advance()

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
from cutlass import Float32, Int32, Boolean, const_expr
1313
from cutlass.utils import LayoutEnum
1414

15+
from quack import copy_utils
1516
from quack import layout_utils
1617
from quack import sm90_utils
1718
from quack.sm90_utils import gemm_zero_init, gemm_w_idx
1819

1920
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
2021
from flash_attn.cute import utils
21-
from flash_attn.cute import copy_utils
2222
from flash_attn.cute.mask import AttentionMask
2323
from flash_attn.cute.seqlen_info import SeqlenInfoQK
2424
from flash_attn.cute.block_info import BlockInfo
@@ -825,8 +825,7 @@ def load(
825825
)
826826
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
827827
load_Q(first_m_block, producer_state=producer_state_Q)
828-
with cute.arch.elect_one():
829-
load_LSE(first_m_block, producer_state=producer_state_Q)
828+
load_LSE(first_m_block, producer_state=producer_state_Q)
830829
producer_state_dO_cur = (
831830
producer_state_dO
832831
if const_expr(self.Q_stage != self.dO_stage)
@@ -837,25 +836,22 @@ def load(
837836
)
838837
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
839838
load_dO(first_m_block, producer_state=producer_state_dO_cur)
840-
with cute.arch.elect_one():
841-
load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
839+
load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
842840
producer_state_Q.advance()
843841
producer_state_dO.advance()
844842

845843
for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
846844
pipeline_Q.producer_acquire(producer_state_Q)
847845
load_Q(m_block, producer_state=producer_state_Q)
848-
with cute.arch.elect_one():
849-
load_LSE(m_block, producer_state=producer_state_Q)
846+
load_LSE(m_block, producer_state=producer_state_Q)
850847
producer_state_dO_cur = (
851848
producer_state_dO
852849
if const_expr(self.Q_stage != self.dO_stage)
853850
else producer_state_Q
854851
)
855852
pipeline_dO.producer_acquire(producer_state_dO_cur)
856853
load_dO(m_block, producer_state=producer_state_dO_cur)
857-
with cute.arch.elect_one():
858-
load_dPsum(m_block, producer_state=producer_state_dO_cur)
854+
load_dPsum(m_block, producer_state=producer_state_dO_cur)
859855
producer_state_Q.advance()
860856
producer_state_dO.advance()
861857
else:

0 commit comments

Comments
 (0)