Skip to content

Commit 0ec504a

Browse files
committed
Address review comments
1 parent e8f3994 commit 0ec504a

2 files changed

Lines changed: 70 additions & 54 deletions

File tree

flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py

Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,14 @@ def cp_reduce_bf16_add(dst_addr, src_addr, size, *, loc=None, ip=None):
455455
"""Async bulk copy-reduce (add) from shared memory to global memory.
456456
457457
PTX instruction: cp.reduce.async.bulk.global.shared::cta.bulk_group.add.bf16.noftz [dstMem], [srcMem], size;
458+
459+
Performs an asynchronous bulk reduction operation that atomically adds bf16 values
460+
from shared memory (CTA scope) to global memory. The operation is part of a bulk group.
461+
462+
Args:
463+
dst_addr: Destination global memory address (pointer)
464+
src_addr: Source shared memory address (pointer)
465+
size: Size in bytes of the data to reduce (must be a multiple of 4)
458466
"""
459467
llvm.inline_asm(
460468
None,
@@ -1994,14 +2002,15 @@ def kernel(
19942002
reg_out_ptrs = cute.make_tensor(
19952003
smem_out_ptrs_ptr, (8,)
19962004
).load()
1997-
for iter in cutlass.range_constexpr(8):
1998-
m_row = warp_idx * 8 + iter
2005+
# 8 elements per warp for a total of 32 elements per subtile
2006+
for idx_in_warp in cutlass.range_constexpr(8):
2007+
m_row = warp_idx * 8 + idx_in_warp
19992008
m_idx = (
20002009
cur_tile_dim1_offset
20012010
+ subtile_idx * epi_tile[1].shape
20022011
+ m_row
20032012
)
2004-
out_ptr = reg_out_ptrs[iter]
2013+
out_ptr = reg_out_ptrs[idx_in_warp]
20052014
if m_idx < tile_sched_params.masked_m[l_idx]:
20062015
src_ptr = cute.make_ptr(
20072016
self.c_dtype,
@@ -2021,40 +2030,43 @@ def kernel(
20212030
src_smem_tensor[((None, m_row),)].iterator,
20222031
epi_tile[0].shape * c_dtype_bytes,
20232032
)
2033+
# Fence and barrier to make sure shared memory store is visible to TMA store
2034+
c_pipeline.producer_commit()
2035+
c_pipeline.producer_acquire()
20242036
else:
20252037
if warp_idx == self.epilog_warp_id[0]:
20262038
cute.copy(
20272039
tma_atom_c,
20282040
bSG_sC[(None, c_buffer)],
20292041
bSG_gC[(None, subtile_idx)],
20302042
)
2031-
# Fence and barrier to make sure shared memory store is visible to TMA store
2032-
with cute.arch.elect_one():
2033-
c_pipeline.producer_commit()
2043+
# Fence and barrier to make sure shared memory store is visible to TMA store
2044+
c_pipeline.producer_commit()
20342045

2035-
if cutlass.const_expr(tile_sched_params.dst_signals is not None):
2036-
dsm_counter = (dsm_counter + 1).to(Uint8)
2037-
will_write_signals = (
2038-
read_byte(dsm_pending_packed, dsm_pending_idx)
2039-
== dsm_counter
2040-
)
2046+
if cutlass.const_expr(
2047+
tile_sched_params.dst_signals is not None
2048+
):
2049+
dsm_counter = (dsm_counter + 1).to(Uint8)
2050+
will_write_signals = (
2051+
read_byte(dsm_pending_packed, dsm_pending_idx)
2052+
== dsm_counter
2053+
)
20412054

2042-
if will_write_signals:
2043-
# The original c_pipeline.producer_acquire()
2044-
# := PipelineTmaStore.producer_acquire()
2045-
# := TmaStoreFence.wait()
2046-
# := cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
2047-
cute.arch.cp_async_bulk_wait_group(
2048-
self.num_c_stage - 1,
2049-
# Change `read` from True to False to also wait for writes
2050-
read=False,
2051-
)
2052-
else:
2053-
c_pipeline.producer_acquire()
2055+
if will_write_signals:
2056+
# The original c_pipeline.producer_acquire()
2057+
# := PipelineTmaStore.producer_acquire()
2058+
# := TmaStoreFence.wait()
2059+
# := cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
2060+
cute.arch.cp_async_bulk_wait_group(
2061+
self.num_c_stage - 1,
2062+
# Change `read` from True to False to also wait for writes
2063+
read=False,
2064+
)
2065+
else:
2066+
c_pipeline.producer_acquire()
20542067

2055-
else:
2056-
with cute.arch.elect_one():
2057-
c_pipeline.producer_acquire()
2068+
else:
2069+
c_pipeline.producer_acquire()
20582070

20592071
cute.arch.barrier(
20602072
barrier_id=self.epilog_sync_bar_id,
@@ -3007,7 +3019,7 @@ def __call__(
30073019
),
30083020
)
30093021
row_size = self._n if self._c_major == "m" else self._m
3010-
if cutlass.const_expr(self._num_ranks > 0):
3022+
if cutlass.const_expr(self._is_combine_fusion):
30113023
topk_weights_tensor = cute.make_tensor(
30123024
topk_weights_ptr,
30133025
layout=cute.make_ordered_layout((self._l, row_size), order=(1, 0)),
@@ -3188,8 +3200,8 @@ def get_cute_pointers(
31883200
if not enable_dst_signals:
31893201
dst_signals_data_ptr = None
31903202

3191-
# When num_ranks == 0, combine-related pointers should be None
3192-
if num_ranks == 0:
3203+
# When combine fusion is disabled, combine-related pointers should be None
3204+
if not is_combine_fusion:
31933205
topk_weights_data_ptr = None
31943206
idx_src_info_data_ptr = None
31953207
rank_src_info_data_ptr = None
@@ -3555,27 +3567,6 @@ def grouped_gemm_nt_masked(
35553567
- The result is written to c_tensor.
35563568
"""
35573569

3558-
if is_combine_fusion:
3559-
required = {
3560-
"topk_weights": topk_weights,
3561-
"idx_src_info": idx_src_info,
3562-
"rank_src_info": rank_src_info,
3563-
"out_ptrs": out_ptrs,
3564-
"barrier_flag_local": barrier_flag_local,
3565-
"barrier_flag_multicast": barrier_flag_multicast,
3566-
}
3567-
missing = [name for name, val in required.items() if val is None]
3568-
if missing:
3569-
raise ValueError(
3570-
"is_combine_fusion=True requires non-None values for: "
3571-
f"{', '.join(missing)}"
3572-
)
3573-
if num_ranks <= 0:
3574-
raise ValueError("is_combine_fusion=True requires num_ranks > 0")
3575-
# For combine fusion, we always swap the AB input tensors
3576-
if not is_swap_ab:
3577-
raise ValueError("is_combine_fusion=True requires is_swap_ab=True")
3578-
35793570
if is_swap_ab:
35803571
a_torch, sfa_torch = rhs
35813572
b_torch, sfb_torch = lhs
@@ -3607,6 +3598,29 @@ def grouped_gemm_nt_masked(
36073598
if major == 11 and minor == 0:
36083599
raise ValueError("SM110 is not supported for cute-dsl backend.")
36093600

3601+
if is_combine_fusion:
3602+
required = {
3603+
"topk_weights": topk_weights,
3604+
"idx_src_info": idx_src_info,
3605+
"rank_src_info": rank_src_info,
3606+
"out_ptrs": out_ptrs,
3607+
"barrier_flag_local": barrier_flag_local,
3608+
"barrier_flag_multicast": barrier_flag_multicast,
3609+
}
3610+
missing = [name for name, val in required.items() if val is None]
3611+
if missing:
3612+
raise ValueError(
3613+
"is_combine_fusion=True requires non-None values for: "
3614+
f"{', '.join(missing)}"
3615+
)
3616+
if num_ranks <= 0:
3617+
raise ValueError("is_combine_fusion=True requires num_ranks > 0")
3618+
# For combine fusion, we always swap the AB input tensors
3619+
if not is_swap_ab:
3620+
raise ValueError("is_combine_fusion=True requires is_swap_ab=True")
3621+
if mma_tiler_mn != (128, 128):
3622+
raise ValueError("is_combine_fusion=True requires mma_tiler_mn=(128, 128)")
3623+
36103624
return get_cute_dsl_compiled_masked_gemm_kernel(
36113625
m=m,
36123626
n=n,

tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def test_blockscaled_gemm_python_interface(
288288
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
289289
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
290290
torch.distributed.init_process_group(backend="nccl", device_id=device)
291-
291+
c_dtype = "bfloat16"
292292
for BATCH_SIZE in [16, 64, 128, 256]:
293293
if torch.distributed.get_rank() == 0:
294294
print(f"\n{'=' * 60}")
@@ -300,15 +300,17 @@ def test_blockscaled_gemm_python_interface(
300300
ab_dtype="float4_e2m1fn",
301301
sf_dtype="float8_e4m3fn",
302302
sf_vec_size=16,
303-
c_dtype="bfloat16",
303+
c_dtype=c_dtype,
304304
a_major="k",
305305
b_major="k",
306306
c_major="n",
307307
fuse_alpha=True,
308308
alpha_dtype="float32",
309309
mma_tiler_mn=(128, 128),
310310
cluster_shape_mn=(1, 1),
311-
tolerance=10000,
311+
tolerance=10000
312+
if c_dtype == "bfloat16"
313+
else 1e-01, # Rely on the relative tolerance with bfloat16 accumulation
312314
iterations=1,
313315
)
314316
# WAR for https://github.com/pytorch/pytorch/issues/162429

0 commit comments

Comments
 (0)