@@ -455,6 +455,14 @@ def cp_reduce_bf16_add(dst_addr, src_addr, size, *, loc=None, ip=None):
455455 """Async bulk copy-reduce (add) from shared memory to global memory.
456456
457457 PTX instruction: cp.reduce.async.bulk.global.shared::cta.bulk_group.add.bf16.noftz [dstMem], [srcMem], size;
458+
459+ Performs an asynchronous bulk reduction operation that atomically adds bf16 values
460+ from shared memory (CTA scope) to global memory. The operation is part of a bulk group.
461+
462+ Args:
463+ dst_addr: Destination global memory address (pointer)
464+ src_addr: Source shared memory address (pointer)
465+ size: Size in bytes of the data to reduce (must be a multiple of 4)
458466 """
459467 llvm .inline_asm (
460468 None ,
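As a usage note (not part of the diff): a hedged sketch of how this helper appears to be called in the epilogue later in this change. The destination-pointer construction is elided in the hunk below, so `dst_ptr` is an assumed name used only for illustration.

```python
# Inside the epilogue, once per valid output row of a subtile (sketch only):
cp_reduce_bf16_add(
    dst_ptr,                                      # global-memory destination for this row (assumed name)
    src_smem_tensor[((None, m_row),)].iterator,   # shared-memory source row
    epi_tile[0].shape * c_dtype_bytes,            # row size in bytes
)
```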
@@ -1994,14 +2002,15 @@ def kernel(
                 reg_out_ptrs = cute.make_tensor(
                     smem_out_ptrs_ptr, (8,)
                 ).load()
-                for iter in cutlass.range_constexpr(8):
-                    m_row = warp_idx * 8 + iter
+                # 8 elements per warp for a total of 32 elements per subtile
+                for idx_in_warp in cutlass.range_constexpr(8):
+                    m_row = warp_idx * 8 + idx_in_warp
                     m_idx = (
                         cur_tile_dim1_offset
                         + subtile_idx * epi_tile[1].shape
                         + m_row
                     )
-                    out_ptr = reg_out_ptrs[iter]
+                    out_ptr = reg_out_ptrs[idx_in_warp]
                     if m_idx < tile_sched_params.masked_m[l_idx]:
                         src_ptr = cute.make_ptr(
                             self.c_dtype,
@@ -2021,40 +2030,43 @@ def kernel(
                             src_smem_tensor[((None, m_row),)].iterator,
                             epi_tile[0].shape * c_dtype_bytes,
                         )
+                # Fence and barrier to make sure shared memory store is visible to TMA store
+                c_pipeline.producer_commit()
+                c_pipeline.producer_acquire()
             else:
                 if warp_idx == self.epilog_warp_id[0]:
                     cute.copy(
                         tma_atom_c,
                         bSG_sC[(None, c_buffer)],
                         bSG_gC[(None, subtile_idx)],
                     )
-                # Fence and barrier to make sure shared memory store is visible to TMA store
-                with cute.arch.elect_one():
-                    c_pipeline.producer_commit()
+                    # Fence and barrier to make sure shared memory store is visible to TMA store
+                    c_pipeline.producer_commit()
 
-                if cutlass.const_expr(tile_sched_params.dst_signals is not None):
-                    dsm_counter = (dsm_counter + 1).to(Uint8)
-                    will_write_signals = (
-                        read_byte(dsm_pending_packed, dsm_pending_idx)
-                        == dsm_counter
-                    )
+                    if cutlass.const_expr(
+                        tile_sched_params.dst_signals is not None
+                    ):
+                        dsm_counter = (dsm_counter + 1).to(Uint8)
+                        will_write_signals = (
+                            read_byte(dsm_pending_packed, dsm_pending_idx)
+                            == dsm_counter
+                        )
 
-                    if will_write_signals:
-                        # The original c_pipeline.producer_acquire()
-                        # := PipelineTmaStore.producer_acquire()
-                        # := TmaStoreFence.wait()
-                        # := cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
-                        cute.arch.cp_async_bulk_wait_group(
-                            self.num_c_stage - 1,
-                            # Change `read` from True to False to also wait writes
-                            read=False,
-                        )
-                    else:
-                        c_pipeline.producer_acquire()
+                        if will_write_signals:
+                            # The original c_pipeline.producer_acquire()
+                            # := PipelineTmaStore.producer_acquire()
+                            # := TmaStoreFence.wait()
+                            # := cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
+                            cute.arch.cp_async_bulk_wait_group(
+                                self.num_c_stage - 1,
+                                # Change `read` from True to False to also wait writes
+                                read=False,
+                            )
+                        else:
+                            c_pipeline.producer_acquire()
 
-                else:
-                    with cute.arch.elect_one():
-                        c_pipeline.producer_acquire()
+                    else:
+                        c_pipeline.producer_acquire()
 
             cute.arch.barrier(
                 barrier_id=self.epilog_sync_bar_id,
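The comment chain in this hunk compresses a subtle point. A brief hedged note, assuming the usual PTX semantics of cp.async.bulk.wait_group and the mapping of the `read` flag suggested by the diff's own comments (not re-derived from the CUTLASS DSL source):

```python
# With read=True the wait corresponds to PTX `cp.async.bulk.wait_group.read N`,
# which only waits until the *reads* of the outstanding bulk operations are done
# (the shared-memory staging buffers may then be reused).
# With read=False it corresponds to `cp.async.bulk.wait_group N`, which waits until
# the bulk operations have fully completed, including their *writes* to global
# memory; that is what the dst_signals path needs before signals are published.
cute.arch.cp_async_bulk_wait_group(self.num_c_stage - 1, read=False)
```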
@@ -3007,7 +3019,7 @@ def __call__(
             ),
         )
         row_size = self._n if self._c_major == "m" else self._m
-        if cutlass.const_expr(self._num_ranks > 0):
+        if cutlass.const_expr(self._is_combine_fusion):
             topk_weights_tensor = cute.make_tensor(
                 topk_weights_ptr,
                 layout=cute.make_ordered_layout((self._l, row_size), order=(1, 0)),
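A short note on the layout built here, assuming CuTe's usual make_ordered_layout semantics (the mode with the lower order value gets the smaller stride); this is an illustration, not code from the PR:

```python
# shape (l, row_size) with order (1, 0): the row dimension is contiguous,
# consecutive groups are row_size apart, i.e. an ordinary row-major matrix.
layout = cute.make_ordered_layout((self._l, row_size), order=(1, 0))
# Under that assumption this is equivalent to:
#   cute.make_layout((self._l, row_size), stride=(row_size, 1))
```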
@@ -3188,8 +3200,8 @@ def get_cute_pointers(
     if not enable_dst_signals:
         dst_signals_data_ptr = None
 
-    # When num_ranks == 0, combine-related pointers should be None
-    if num_ranks == 0:
+    # When combine fusion is disabled, combine-related pointers should be None
+    if not is_combine_fusion:
         topk_weights_data_ptr = None
         idx_src_info_data_ptr = None
         rank_src_info_data_ptr = None
@@ -3555,27 +3567,6 @@ def grouped_gemm_nt_masked(
     - The result is written to c_tensor.
     """
 
-    if is_combine_fusion:
-        required = {
-            "topk_weights": topk_weights,
-            "idx_src_info": idx_src_info,
-            "rank_src_info": rank_src_info,
-            "out_ptrs": out_ptrs,
-            "barrier_flag_local": barrier_flag_local,
-            "barrier_flag_multicast": barrier_flag_multicast,
-        }
-        missing = [name for name, val in required.items() if val is None]
-        if missing:
-            raise ValueError(
-                "is_combine_fusion=True requires non-None values for: "
-                f"{', '.join(missing)}"
-            )
-        if num_ranks <= 0:
-            raise ValueError("is_combine_fusion=True requires num_ranks > 0")
-        # For combine fusion, we always swap the AB input tensors
-        if not is_swap_ab:
-            raise ValueError("is_combine_fusion=True requires is_swap_ab=True")
-
     if is_swap_ab:
         a_torch, sfa_torch = rhs
         b_torch, sfb_torch = lhs
@@ -3607,6 +3598,29 @@ def grouped_gemm_nt_masked(
     if major == 11 and minor == 0:
         raise ValueError("SM110 is not supported for cute-dsl backend.")
 
+    if is_combine_fusion:
+        required = {
+            "topk_weights": topk_weights,
+            "idx_src_info": idx_src_info,
+            "rank_src_info": rank_src_info,
+            "out_ptrs": out_ptrs,
+            "barrier_flag_local": barrier_flag_local,
+            "barrier_flag_multicast": barrier_flag_multicast,
+        }
+        missing = [name for name, val in required.items() if val is None]
+        if missing:
+            raise ValueError(
+                "is_combine_fusion=True requires non-None values for: "
+                f"{', '.join(missing)}"
+            )
+        if num_ranks <= 0:
+            raise ValueError("is_combine_fusion=True requires num_ranks > 0")
+        # For combine fusion, we always swap the AB input tensors
+        if not is_swap_ab:
+            raise ValueError("is_combine_fusion=True requires is_swap_ab=True")
+        if mma_tiler_mn != (128, 128):
+            raise ValueError("is_combine_fusion=True requires mma_tiler_mn=(128, 128)")
+
     return get_cute_dsl_compiled_masked_gemm_kernel(
         m=m,
         n=n,
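To make the new requirements concrete, a hypothetical call-site sketch. It assumes the names checked above are keyword parameters of grouped_gemm_nt_masked; every argument not mentioned by the checks (problem sizes, dtypes, the output tensor, and so on) is omitted here, so this is a sketch rather than a runnable call.

```python
grouped_gemm_nt_masked(
    lhs=(b_torch, sfb_torch),    # combine fusion requires is_swap_ab=True, so lhs carries B
    rhs=(a_torch, sfa_torch),
    is_swap_ab=True,
    is_combine_fusion=True,
    num_ranks=num_ranks,                          # must be > 0
    topk_weights=topk_weights,                    # all six combine tensors must be non-None
    idx_src_info=idx_src_info,
    rank_src_info=rank_src_info,
    out_ptrs=out_ptrs,
    barrier_flag_local=barrier_flag_local,
    barrier_flag_multicast=barrier_flag_multicast,
    mma_tiler_mn=(128, 128),                      # the only tiler accepted with combine fusion
    # ... remaining arguments omitted
)
```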