1212from cutlass import Float32 , Int32 , Boolean , const_expr
1313from cutlass .utils import LayoutEnum
1414
15+ from quack import copy_utils
1516from quack import layout_utils
1617from quack import sm90_utils
1718from quack .sm90_utils import gemm_zero_init , gemm_w_idx
1819
1920from flash_attn .cute .cute_dsl_utils import assume_tensor_aligned
2021from flash_attn .cute import utils
21- from flash_attn .cute import copy_utils
2222from flash_attn .cute .mask import AttentionMask
2323from flash_attn .cute .seqlen_info import SeqlenInfoQK
2424from flash_attn .cute .block_info import BlockInfo
@@ -825,8 +825,7 @@ def load(
825825 )
826826 load_K (tma_bar_ptr = pipeline_Q .producer_get_barrier (producer_state_Q ))
827827 load_Q (first_m_block , producer_state = producer_state_Q )
828- with cute .arch .elect_one ():
829- load_LSE (first_m_block , producer_state = producer_state_Q )
828+ load_LSE (first_m_block , producer_state = producer_state_Q )
830829 producer_state_dO_cur = (
831830 producer_state_dO
832831 if const_expr (self .Q_stage != self .dO_stage )
@@ -837,25 +836,22 @@ def load(
837836 )
838837 load_V (tma_bar_ptr = pipeline_dO .producer_get_barrier (producer_state_dO_cur ))
839838 load_dO (first_m_block , producer_state = producer_state_dO_cur )
840- with cute .arch .elect_one ():
841- load_dPsum (first_m_block , producer_state = producer_state_dO_cur )
839+ load_dPsum (first_m_block , producer_state = producer_state_dO_cur )
842840 producer_state_Q .advance ()
843841 producer_state_dO .advance ()
844842
845843 for m_block in cutlass .range (m_block_min + 1 , m_block_max , unroll = 1 ):
846844 pipeline_Q .producer_acquire (producer_state_Q )
847845 load_Q (m_block , producer_state = producer_state_Q )
848- with cute .arch .elect_one ():
849- load_LSE (m_block , producer_state = producer_state_Q )
846+ load_LSE (m_block , producer_state = producer_state_Q )
850847 producer_state_dO_cur = (
851848 producer_state_dO
852849 if const_expr (self .Q_stage != self .dO_stage )
853850 else producer_state_Q
854851 )
855852 pipeline_dO .producer_acquire (producer_state_dO_cur )
856853 load_dO (m_block , producer_state = producer_state_dO_cur )
857- with cute .arch .elect_one ():
858- load_dPsum (m_block , producer_state = producer_state_dO_cur )
854+ load_dPsum (m_block , producer_state = producer_state_dO_cur )
859855 producer_state_Q .advance ()
860856 producer_state_dO .advance ()
861857 else :
0 commit comments