feat: Add CuTe DSL grouped-gemm + combine fusion support #2944
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review; this behavior is configurable in the settings.
📝 Walkthrough: Added AB-swap and optional combine-fusion to the SM100 block-scaled persistent GEMM, introducing new kernel knobs. Changes span the block‑scaled persistent GEMM (kernel, scheduler, host, DSL, epilogue, test).
Sequence Diagram
sequenceDiagram
participant Host as Host/Launcher
participant Rank as Rank Process
participant GPU as GPU Kernel
participant SymmMem as SymmetricMemory
participant Barrier as Barrier/Flags
Host ->> Rank: rendezvous & allocate symmetric outputs/flags
Rank ->> Barrier: initial global barrier
Rank ->> GPU: launch kernel (A,B,topk,idx,rank_info,out_ptrs,barrier_flags,is_combine_fusion,is_swap_ab)
GPU ->> SymmMem: prefetch topk_weights / out_ptrs into SMEM (S2R)
GPU ->> GPU: compute GEMM tiles (scheduler honours is_swap_ab)
GPU ->> GPU: scale accumulators by topk weights (and alpha)
GPU ->> SymmMem: cp.reduce.async.bulk...add (bulk SMEM → GMEM) using out_ptrs/idx/rank_info
GPU ->> Barrier: bulk-wait + set/release barrier flags
GPU -->> Rank: kernel completes
Rank ->> Host: copy & validate symmetric output
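The "set/release barrier flags" step in the diagram is a cross-rank spin-lock barrier. A toy host-side emulation of the idea, under the assumption that each rank owns one flag slot (the function name and flag layout are illustrative, not the kernel's on-device code):

```python
import numpy as np

def spin_barrier(flags: np.ndarray, my_rank: int, num_ranks: int, epoch: int) -> None:
    """Toy emulation of a flag-based spin barrier across ranks."""
    flags[my_rank] = epoch                          # signal arrival (kernel: set/release flag)
    while any(flags[r] < epoch for r in range(num_ranks)):
        pass                                        # spin until every rank has signalled this epoch
```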
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🧹 Nitpick comments (7)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (6)
52-62: Use `flashinfer.utils.is_sm100a_supported()` for architecture check. Per coding guidelines, tests should use flashinfer utility functions for GPU capability checks. Replace the manual capability check with the standardized helper.
Proposed fix
```diff
+from flashinfer.utils import is_sm100a_supported
+
 def test_blockscaled_gemm_python_interface( ... ):
     if not is_cute_dsl_available():
         print("Skipping: Please `pip install nvidia-cutlass-dsl`")
         return
     torch.manual_seed(42)
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-    major, minor = torch.cuda.get_device_capability(device)
-
-    if not (major == 10 and minor == 0):
+    if not is_sm100a_supported(device):
         print("Skipping: Cute-dsl backend is only supported on SM100.")
         return
```

As per coding guidelines: Use flashinfer.utils functions (`is_sm100a_supported()`) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 52 - 62, Replace the manual SM version check using torch.cuda.get_device_capability with the flashinfer helper: import and call is_sm100a_supported() instead of evaluating major/minor; in the block after setting device (and after is_cute_dsl_available()), replace the if not (major == 10 and minor == 0): ... return with if not is_sm100a_supported(): print("Skipping: Cute-dsl backend is only supported on SM100.") return so the test uses the standardized GPU capability check.
111-155: Prefix unused unpacked variables with underscore. The variables `a_tensor`, `b_tensor`, `sfa_tensor`, and `sfb_tensor` are unpacked but never used. Prefix them with `_` to indicate intentional non-use and silence linter warnings.

Proposed fix

```diff
-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
     ...
-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 111 - 155, The unpacked variables a_tensor, b_tensor, sfa_tensor and sfb_tensor from calls to cutlass_torch.cute_tensor_like and create_scale_factor_tensor are never used; rename them with a leading underscore (e.g., _a_tensor, _b_tensor, _sfa_tensor, _sfb_tensor) to indicate intentional non-use and silence the linter. Update the destructuring at the two cute_tensor_like calls and the two create_scale_factor_tensor calls (references to a_tensor, b_tensor, sfa_tensor, sfb_tensor) to the underscored names, leaving the used variables (a_torch, b_torch, sfa_ref, sfa_torch, sfb_ref, sfb_torch) unchanged.
30-33: Consider documenting the purpose of global handle lists. The global lists `c_torch_handle_list` and `barrier_flag_local_handle_list` prevent garbage collection of symmetric memory handles. Add a comment explaining this is necessary due to the PyTorch issue referenced at line 314.

Suggested documentation

```diff
-# WAR for https://github.com/pytorch/pytorch/issues/162429
+# WAR for https://github.com/pytorch/pytorch/issues/162429
+# Keep symmetric memory handles alive to prevent premature GC during test execution
 c_torch_handle_list = []
 barrier_flag_local_handle_list = []
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 30 - 33, Add a brief comment above the global lists c_torch_handle_list and barrier_flag_local_handle_list explaining they are intentionally kept as module-level globals to prevent garbage collection of symmetric memory handles (workaround for PyTorch issue `#162429`) and reference the issue number (and the related note at line 314) so future readers know why these handles must persist; update the comment near the definitions of c_torch_handle_list and barrier_flag_local_handle_list to clearly state their purpose and link to the PyTorch issue.
64-65: Rename ambiguous variable `l` to `batch_size` or `num_batches`. The variable name `l` is flagged as ambiguous (E741). In the context of batched GEMM, this represents the batch dimension, so a clearer name like `batch_size` or `num_batches` would improve readability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 64 - 65, The variable name "l" in the assignment "l, m = lm" is ambiguous; rename it to a clear batch dimension name (e.g., batch_size or num_batches) wherever it's defined and subsequently used in this test (search for "l" references in this file, especially around the test function and any loops or shape constructs) and update "l, m = lm" to "batch_size, m = lm" (or "num_batches, m = lm") and replace all uses of "l" with the new name (including any variable unpacking, shape assertions, and calls to the GEMM helpers) to preserve behavior.
214-216: Document the purpose of the jitter injection. The sleep injection on rank 1 is presumably to test race condition handling, but this should be documented. The hardcoded 1-second sleep could also cause test flakiness in CI environments with different timing characteristics.
Suggested improvement
```diff
     if my_rank == 1:
-        # Inject jitter to trigger potential race conditions
+        # Inject jitter on rank 1 to stress-test the distributed barrier
+        # synchronization. This ensures slower ranks are handled correctly
+        # by the spin-lock barrier mechanism.
         torch.cuda._sleep(1000000000)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 214 - 216, Add a clear inline comment above the torch.cuda._sleep call explaining that the sleep simulates jitter to exercise potential GPU-side race conditions on the process where my_rank == 1, and replace the hardcoded long 1s sleep with a small, bounded, randomized jitter (e.g., a few micro- to milliseconds) or controlled via an environment variable (e.g., TEST_JITTER_US) so CI timing won’t be brittle; reference the current check (my_rank == 1) and the call site (torch.cuda._sleep) when making the change.
287-315: Document the test execution pattern more clearly. The test uses `torchrun` for distributed execution and `os._exit(0)` as a workaround. Consider adding a more detailed docstring explaining:
- Why this isn't a standard pytest test
- The expected invocation command
- What the workaround addresses
Suggested documentation improvement
if __name__ == "__main__": + # This test requires distributed execution via torchrun. + # Standard pytest discovery is not used because multi-GPU tests + # need explicit process group initialization. + # Usage: torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.distributed.init_process_group(backend="nccl", device_id=device) ... - # WAR for https://github.com/pytorch/pytorch/issues/162429 + # Workaround for https://github.com/pytorch/pytorch/issues/162429 + # Using os._exit(0) to avoid hanging during process group cleanup. os._exit(0)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 287 - 315, Add a concise module-level docstring above the if __name__ == "__main__": block that explains this file is not a normal pytest test, that it must be run under torchrun/torch.distributed (e.g., torchrun --nproc_per_node=<N> python <file>) because it initializes a distributed process group via torch.distributed.init_process_group and sets CUDA device per LOCAL_RANK, and document that os._exit(0) is an explicit workaround for the linked PyTorch issue to force clean termination after running test_blockscaled_gemm_python_interface loops; mention expected invocation, required env var LOCAL_RANK, and why pytest discovery is avoided.

flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py (1)
2001-2007: Rename `iter` to avoid shadowing Python builtin. The variable `iter` shadows Python's builtin `iter()` function. Consider renaming to `iter_idx` or `row_iter` for clarity.

Proposed fix

```diff
-    for iter in cutlass.range_constexpr(8):
-        m_row = warp_idx * 8 + iter
+    for row_iter in cutlass.range_constexpr(8):
+        m_row = warp_idx * 8 + row_iter
         m_idx = (
             cur_tile_dim1_offset + subtile_idx * epi_tile[1].shape + m_row
         )
-        out_ptr = reg_out_ptrs[iter]
+        out_ptr = reg_out_ptrs[row_iter]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py` around lines 2001 - 2007, The loop variable named "iter" shadows the Python builtin iter(); rename it (e.g., to "iter_idx" or "row_iter") in the for loop "for iter in cutlass.range_constexpr(8):" and update all uses inside that loop (references to iter -> iter_idx) including where m_row is computed and any downstream uses (m_row, m_idx calculations that use iter). Ensure the new name is used consistently within the surrounding scope (same block where warp_idx, cur_tile_dim1_offset, subtile_idx, and epi_tile are referenced) to avoid shadowing and maintain readability.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 279-284: The assert uses an extremely large scalar "tolerance"
(10000) in the torch.testing.assert_close call which can mask errors; replace
the hardcoded tolerance with dtype-aware tolerances (e.g., set atol/rtol based
on input dtype: tight values for torch.float32, looser for torch.bfloat16/half)
and/or compute tolerance relative to the magnitude of the reference tensors,
document why a nonstandard tolerance is needed, and apply the same change to the
other assert_close blocks referenced (the call at torch.testing.assert_close and
the similar checks around lines 297-313) so comparisons use appropriate
per-dtype atol/rtol instead of the global 10000 value.
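A minimal sketch of the dtype-aware tolerance selection this comment asks for (the helper name and the specific atol/rtol values are illustrative assumptions, not the test's actual values):

```python
import torch

def assert_close_by_dtype(actual: torch.Tensor, expected: torch.Tensor) -> None:
    # Pick tolerances from the output dtype instead of one large scalar tolerance.
    tolerances = {
        torch.float32: dict(atol=1e-3, rtol=1e-3),
        torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
        torch.float16: dict(atol=1e-2, rtol=1e-2),
    }
    tol = tolerances.get(actual.dtype, dict(atol=1e-2, rtol=1e-2))
    torch.testing.assert_close(actual, expected.to(actual.dtype), **tol)
```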
---
Nitpick comments:
In `@flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py`:
- Around line 2001-2007: The loop variable named "iter" shadows the Python
builtin iter(); rename it (e.g., to "iter_idx" or "row_iter") in the for loop
"for iter in cutlass.range_constexpr(8):" and update all uses inside that loop
(references to iter -> iter_idx) including where m_row is computed and any
downstream uses (m_row, m_idx calculations that use iter). Ensure the new name
is used consistently within the surrounding scope (same block where warp_idx,
cur_tile_dim1_offset, subtile_idx, and epi_tile are referenced) to avoid
shadowing and maintain readability.
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 52-62: Replace the manual SM version check using
torch.cuda.get_device_capability with the flashinfer helper: import and call
is_sm100a_supported() instead of evaluating major/minor; in the block after
setting device (and after is_cute_dsl_available()), replace the if not (major ==
10 and minor == 0): ... return with if not is_sm100a_supported():
print("Skipping: Cute-dsl backend is only supported on SM100.") return so the
test uses the standardized GPU capability check.
- Around line 111-155: The unpacked variables a_tensor, b_tensor, sfa_tensor and
sfb_tensor from calls to cutlass_torch.cute_tensor_like and
create_scale_factor_tensor are never used; rename them with a leading underscore
(e.g., _a_tensor, _b_tensor, _sfa_tensor, _sfb_tensor) to indicate intentional
non-use and silence the linter. Update the destructuring at the two
cute_tensor_like calls and the two create_scale_factor_tensor calls (references
to a_tensor, b_tensor, sfa_tensor, sfb_tensor) to the underscored names, leaving
the used variables (a_torch, b_torch, sfa_ref, sfa_torch, sfb_ref, sfb_torch)
unchanged.
- Around line 30-33: Add a brief comment above the global lists
c_torch_handle_list and barrier_flag_local_handle_list explaining they are
intentionally kept as module-level globals to prevent garbage collection of
symmetric memory handles (workaround for PyTorch issue `#162429`) and reference
the issue number (and the related note at line 314) so future readers know why
these handles must persist; update the comment near the definitions of
c_torch_handle_list and barrier_flag_local_handle_list to clearly state their
purpose and link to the PyTorch issue.
- Around line 64-65: The variable name "l" in the assignment "l, m = lm" is
ambiguous; rename it to a clear batch dimension name (e.g., batch_size or
num_batches) wherever it's defined and subsequently used in this test (search
for "l" references in this file, especially around the test function and any
loops or shape constructs) and update "l, m = lm" to "batch_size, m = lm" (or
"num_batches, m = lm") and replace all uses of "l" with the new name (including
any variable unpacking, shape assertions, and calls to the GEMM helpers) to
preserve behavior.
- Around line 214-216: Add a clear inline comment above the torch.cuda._sleep
call explaining that the sleep simulates jitter to exercise potential GPU-side
race conditions on the process where my_rank == 1, and replace the hardcoded
long 1s sleep with a small, bounded, randomized jitter (e.g., a few micro- to
milliseconds) or controlled via an environment variable (e.g., TEST_JITTER_US)
so CI timing won’t be brittle; reference the current check (my_rank == 1) and
the call site (torch.cuda._sleep) when making the change.
- Around line 287-315: Add a concise module-level docstring above the if
__name__ == "__main__": block that explains this file is not a normal pytest
test, that it must be run under torchrun/torch.distributed (e.g., torchrun
--nproc_per_node=<N> python <file>) because it initializes a distributed process
group via torch.distributed.init_process_group and sets CUDA device per
LOCAL_RANK, and document that os._exit(0) is an explicit workaround for the
linked PyTorch issue to force clean termination after running
test_blockscaled_gemm_python_interface loops; mention expected invocation,
required env var LOCAL_RANK, and why pytest discovery is avoided.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 73b0c9b6-36d3-4812-8376-fa249cd45e5e
📥 Commits
Reviewing files that changed from the base of the PR and between 5a906be and 96145665808980e65d24dd8f81764dc72824d627.
📒 Files selected for processing (2)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
- tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
Code Review
This pull request introduces 'combine fusion' and tensor swapping support to the Blackwell grouped GEMM kernel. Key additions include PTX wrappers for asynchronous bulk reduction (bf16 and f32), a fused epilogue that performs atomic reductions to global memory, and support for distributed execution parameters like top-k weights and barrier flags. A new multi-GPU test suite is also provided. Feedback focuses on a critical synchronization mismatch in the epilogue that could regress the non-fusion path and suggests improving documentation consistency for the new PTX wrapper functions.
Force-pushed from 9614566 to e8f3994.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py`:
- Around line 3010-3056: The current branch builds combine-related tensors
whenever cutlass.const_expr(self._num_ranks > 0) is true, which can dereference
None pointers in non-fusion runs; change the guard to require both num_ranks > 0
and that combine fusion is enabled (the is_combine_fusion flag) before
constructing topk_weights_tensor, idx_src_info_tensor, rank_src_info_tensor,
out_ptrs_tensor, barrier_flag_local_tensor and barrier_flag_multicast_tensor;
update the condition at the block starting with
cutlass.const_expr(self._num_ranks > 0) (and the analogous block at lines noted
3558-3578) to something like cutlass.const_expr(self._num_ranks > 0 and
self.is_combine_fusion) so tensors are only created when combine fusion is
active.
- Around line 1997-1999: The loop in grouped_gemm_masked_blackwell.py uses the
name `iter` which shadows Python's builtin iter() and fails lint; change the
loop variable in the for over cutlass.range_constexpr(8) (currently "for iter in
cutlass.range_constexpr(8):") to a non-builtin name such as `i`, `idx`, or
`offset` and update all uses inside that block (e.g., the computation of `m_row`
and `m_idx`) to the new variable name so logic remains identical (look for
`warp_idx * 8 + iter`, `m_row`, and `m_idx` references).
- Around line 1875-1892: The prefetch loop can read out-of-bounds because
prefetch_num_m may be 256 while only 128 threads run and topk_weights is indexed
before validating prefetch_m_idx; fix by capping prefetch_num_m to the actual
executing lane count (warp/epilogue threads) and by reordering/adding bounds
checks: compute prefetch_m_idx only when tidx < min(prefetch_num_m, 128), then
check prefetch_m_idx < tile_sched_params.masked_m[prefetch_l_idx] (and any
column bounds for topk_weights) before reading topk_weights[prefetch_l_idx,
prefetch_m_idx]; update smem_topk and smem_out_ptrs only inside that guarded
branch and set smem_out_ptrs[tidx] = Int64(0) otherwise. Apply these changes
around the symbols prefetch_num_m, prefetch_m_idx, tidx, topk_weights,
tile_sched_params.masked_m, smem_topk, and smem_out_ptrs.
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 58-62: Replace the direct torch.cuda.get_device_capability(device)
check with the repository helpers: import
flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, call get_compute_capability() (or directly
is_sm100a_supported()) to decide support, and if unsupported call
pytest.skip("Skipping: Cute-dsl backend is only supported on SM100.") instead of
printing and return; update the conditional around the existing test logic
(replacing torch.cuda.get_device_capability usage) to use is_sm100a_supported()
for gating.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 50b3ee75-55e2-44e1-a495-69b92a908e6c
📥 Commits
Reviewing files that changed from the base of the PR and between 96145665808980e65d24dd8f81764dc72824d627 and e8f3994.
📒 Files selected for processing (2)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
- tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
```python
    major, minor = torch.cuda.get_device_capability(device)

    if not (major == 10 and minor == 0):
        print("Skipping: Cute-dsl backend is only supported on SM100.")
        return
```
Use flashinfer.utils architecture helpers for test gating.
Line 58/60 uses torch.cuda.get_device_capability directly. Please switch to flashinfer.utils.get_compute_capability() / is_sm100a_supported() for skip logic to match repository test conventions.
As per coding guidelines: tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
58 - 62, Replace the direct torch.cuda.get_device_capability(device) check with
the repository helpers: import flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, call get_compute_capability() (or directly
is_sm100a_supported()) to decide support, and if unsupported call
pytest.skip("Skipping: Cute-dsl backend is only supported on SM100.") instead of
printing and return; update the conditional around the existing test logic
(replacing torch.cuda.get_device_capability usage) to use is_sm100a_supported()
for gating.
Actionable comments posted: 1
🧹 Nitpick comments (4)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (4)
150-155: Prefix unused tuple elements with underscore. Same issue with `sfa_tensor` and `sfb_tensor`.

```diff
-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 150 - 155, The tuple unpacking from create_scale_factor_tensor currently binds unused elements to sfa_tensor and sfb_tensor; change the unpacking to prefix unused tuple elements with an underscore (e.g., use sfa_ref, _sfa_tensor, sfa_torch and sfb_ref, _sfb_tensor, sfb_torch) so unused variables are clearly marked and linters won’t complain; update occurrences where sfa_tensor or sfb_tensor are not used to the new underscore-prefixed names.
111-122: Prefix unused tuple elements with underscore. Static analysis flags `a_tensor` and `b_tensor` as unused. Prefix with underscore to indicate intentional discard.

```diff
-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 111 - 122, The tuples returned by cutlass_torch.cute_tensor_like are assigned to a_tensor and b_tensor but those variables are unused; rename the unused tuple elements to start with an underscore (e.g., change a_tensor to _a_tensor and b_tensor to _b_tensor) in the two assignments that call cutlass_torch.cute_tensor_like so static analysis recognizes them as intentionally discarded while keeping the used a_torch and b_torch names unchanged.
291-315: Consider testing multiple output dtypes. Currently `c_dtype` is hardcoded to `"bfloat16"`. The tolerance logic at lines 311-313 suggests float32 was intended to be tested too. Consider parameterizing:

```diff
-    c_dtype = "bfloat16"
-    for BATCH_SIZE in [16, 64, 128, 256]:
+    for c_dtype in ["bfloat16", "float32"]:
+        for BATCH_SIZE in [16, 64, 128, 256]:
```

Or add a comment explaining why only bf16 is tested in CI.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 291 - 315, The test hardcodes c_dtype="bfloat16" but the tolerance branch in the call to test_blockscaled_gemm_python_interface (the tolerance=... if c_dtype == "bfloat16" else 1e-01 expression) implies float32 should also be tested; update the test to either loop over multiple output dtypes (e.g., iterate c_dtype in ["bfloat16","float32"] and call test_blockscaled_gemm_python_interface for each) or, if only bf16 is intended, add a clarifying comment next to the c_dtype assignment explaining why float32 is excluded from CI; adjust the tolerance logic accordingly so it matches the chosen approach.
64-65: Consider renaming ambiguous variable `l` to `num_groups` or `batch_size`. The variable `l` can be visually confused with `1` in some fonts. While it's a common convention in GEMM literature, using a more descriptive name improves readability.

```diff
-    l, m = lm
+    num_groups, m = lm
     k, n = kn
```

Then update all references to `l` throughout the function.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 64 - 65, The variable named `l` (from the unpacking line `l, m = lm`) is ambiguous—rename it to a descriptive identifier such as `num_groups` (or `batch_size`) by changing the unpacking to `num_groups, m = lm` and updating every usage of `l` throughout the function in this file (including any arithmetic, indexing, or logging) to `num_groups` so all references remain consistent; ensure `k, n = kn` and other variables are unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Line 191: The code calls torch.cuda.get_device_properties("cuda") but you
already derive an explicit device from LOCAL_RANK; change that call to use the
device variable instead (i.e., torch.cuda.get_device_properties(device)) so the
retrieved properties (used to set num_sms) come from the intended GPU; update
any references around num_sms to ensure they use the same device variable.
---
Nitpick comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 150-155: The tuple unpacking from create_scale_factor_tensor
currently binds unused elements to sfa_tensor and sfb_tensor; change the
unpacking to prefix unused tuple elements with an underscore (e.g., use sfa_ref,
_sfa_tensor, sfa_torch and sfb_ref, _sfb_tensor, sfb_torch) so unused variables
are clearly marked and linters won’t complain; update occurrences where
sfa_tensor or sfb_tensor are not used to the new underscore-prefixed names.
- Around line 111-122: The tuples returned by cutlass_torch.cute_tensor_like are
assigned to a_tensor and b_tensor but those variables are unused; rename the
unused tuple elements to start with an underscore (e.g., change a_tensor to
_a_tensor and b_tensor to _b_tensor) in the two assignments that call
cutlass_torch.cute_tensor_like so static analysis recognizes them as
intentionally discarded while keeping the used a_torch and b_torch names
unchanged.
- Around line 291-315: The test hardcodes c_dtype="bfloat16" but the tolerance
branch in the call to test_blockscaled_gemm_python_interface (the tolerance=...
if c_dtype == "bfloat16" else 1e-01 expression) implies float32 should also be
tested; update the test to either loop over multiple output dtypes (e.g.,
iterate c_dtype in ["bfloat16","float32"] and call
test_blockscaled_gemm_python_interface for each) or, if only bf16 is intended,
add a clarifying comment next to the c_dtype assignment explaining why float32
is excluded from CI; adjust the tolerance logic accordingly so it matches the
chosen approach.
- Around line 64-65: The variable named `l` (from the unpacking line `l, m =
lm`) is ambiguous—rename it to a descriptive identifier such as `num_groups` (or
`batch_size`) by changing the unpacking to `num_groups, m = lm` and updating
every usage of `l` throughout the function in this file (including any
arithmetic, indexing, or logging) to `num_groups` so all references remain
consistent; ensure `k, n = kn` and other variables are unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 529141f8-a8b0-4f00-9a29-7102105a9cfd
📥 Commits
Reviewing files that changed from the base of the PR and between e8f3994 and 75f858d580d549c24b862c3328c701ab4201cb99.
📒 Files selected for processing (2)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
- tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
```python
        c_torch_handle.buffer_ptrs, dtype=torch.int64, device=device
    )

    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
```
Use the device variable instead of hardcoded "cuda".
Line 191 uses "cuda" which defaults to the current device, but device was explicitly set from LOCAL_RANK. Use the explicit device for consistency.
```diff
-    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` at line 191,
The code calls torch.cuda.get_device_properties("cuda") but you already derive
an explicit device from LOCAL_RANK; change that call to use the device variable
instead (i.e., torch.cuda.get_device_properties(device)) so the retrieved
properties (used to set num_sms) come from the intended GPU; update any
references around num_sms to ensure they use the same device variable.
Force-pushed from 75f858d to 0ec504a.
🧹 Nitpick comments (7)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (7)
150-155: Prefix unused unpacked variables with underscore. `sfa_tensor` and `sfb_tensor` are never used.

```diff
-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 150 - 155, The unpacking of create_scale_factor_tensor returns three values but the middle outputs sfa_tensor and sfb_tensor are unused; rename them to _sfa_tensor and _sfb_tensor (or _ for each) in the assignments that call create_scale_factor_tensor so unused variables are prefixed with an underscore while keeping sfa_ref, sfa_torch, sfb_ref, and sfb_torch unchanged (update the two statements that assign sfa_ref, sfa_tensor, sfa_torch and sfb_ref, sfb_tensor, sfb_torch).
316-317: Document the `os._exit(0)` workaround inline. The comment references a PyTorch issue but doesn't explain what happens without this workaround. Consider adding a brief explanation for future maintainers.

```diff
     # WAR for https://github.com/pytorch/pytorch/issues/162429
+    # Prevents hang during atexit cleanup of symmetric memory handles
     os._exit(0)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 316 - 317, Add a short inline comment above the os._exit(0) workaround explaining what breaks without it and referencing the linked PyTorch issue: state that tests hang or cause worker processes to leak/segfault (or the observed failure mode in your environment) when the issue occurs, and note that os._exit(0) forcefully terminates the process to avoid the problem; include the issue number (162429) and a one-line note to remove the workaround once the PyTorch bug is fixed.
111-122: Prefix unused unpacked variables with underscore. Static analysis indicates `a_tensor` and `b_tensor` are never used. Prefix with `_` to indicate intentional discard.

```diff
-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 111 - 122, The variables a_tensor and b_tensor returned from cutlass_torch.cute_tensor_like are unused; rename them to indicate intentional discard by prefixing with an underscore (e.g., _a_tensor and _b_tensor or just _), leaving a_torch and b_torch unchanged so the call and types remain the same; update the two unpacking statements around cutlass_torch.cute_tensor_like to use the underscored names to satisfy static analysis.
52-62: Consider using `pytest.skip()` for proper test framework integration. The `print()` + `return` pattern doesn't integrate with pytest's skip reporting. When tests are skipped, pytest won't show them in its summary. This also applies to lines 82-85 and 89-92.

```diff
+import pytest
+
 def test_blockscaled_gemm_python_interface(...):
     if not is_cute_dsl_available():
-        print("Skipping: Please `pip install nvidia-cutlass-dsl`")
-        return
+        pytest.skip("Please `pip install nvidia-cutlass-dsl`")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 52 - 62, Replace the print(...) + return skip pattern with pytest.skip(...) to integrate with pytest reporting: where the guard uses is_cute_dsl_available() (first block) and the SM100 capability checks (the blocks around the code at lines 82-85 and 89-92), call pytest.skip("message") instead of printing and returning, and ensure pytest is imported at top of the test module; update the messages to match the originals (e.g., "Please `pip install nvidia-cutlass-dsl`" and "Cute-dsl backend is only supported on SM100.").
5-5: Docstring usage instruction conflicts with PR test plan. The docstring says `torchrun --nproc_per_node=4`, but the PR test plan specifies `mpirun -np 4 pytest`. Consider aligning these or documenting both options if both are valid.

```diff
-USAGE: torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
+USAGE: torchrun --nproc_per_node=4 tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
+   or: mpirun -np 4 pytest tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` at line 5, Update the conflicting usage instruction in the module docstring: the current line "USAGE: torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" conflicts with the PR test plan "mpirun -np 4 pytest", so either replace it to match the PR test plan or document both valid ways (e.g., show both "torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" and "mpirun -np 4 pytest tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py") and clarify which is preferred; edit the docstring at the top of tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py to make the usage consistent with the PR test plan or include both options.
64-65: Rename ambiguous variable `l` to improve readability. Static analysis flags `l` as ambiguous (E741) since it resembles `1`. Consider renaming to `batch_size` or `num_experts` to clarify intent.

```diff
-    l, m = lm
+    num_experts, m = lm
     k, n = kn
```

This would require updating all subsequent uses of `l` in the function.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 64 - 65, Rename the ambiguous variable `l` in the tuple unpacking "l, m = lm" to a clearer name such as `batch_size` or `num_experts` and update every subsequent occurrence of `l` in this test file/function to use the new name; keep the unpacking ("batch_size, m = lm") and the other tuple unpacking ("k, n = kn") intact and run tests to ensure no remaining references to the old `l` remain.
287-317: Consider aligning with pytest pattern used in other multi-GPU tests. The file is documented for `torchrun --nproc_per_node=4` invocation (line 5), which is clear and correct. However, other multi-GPU tests in the repository (e.g., `tests/comm/test_allreduce_unified_api.py`) use pytest-discoverable test functions with distributed fixtures, allowing them to work with standard pytest invocation. For consistency and better CI integration, consider refactoring this test to follow the pytest pattern used elsewhere, or document the specific invocation method in the PR test plan if torchrun is the preferred approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines 287 - 317, The current __main__ torchrun-driven script should be converted to a pytest-discoverable test: remove the if __name__ == "__main__" block that manually reads LOCAL_RANK, sets device, calls torch.distributed.init_process_group and os._exit, and instead create a pytest test function (e.g., def test_multi_gpu_blockscaled_gemm(distributed_rank_fixture):) that follows the pattern from tests/comm/test_allreduce_unified_api.py; use pytest.mark.parametrize for BATCH_SIZE values, use the distributed fixture to set the CUDA device and ensure the process group is initialized, then call test_blockscaled_gemm_python_interface(...) with the same parameters (preserve ab_dtype, sf_dtype, c_dtype logic and tolerance selection), and remove the os._exit workaround. This will make the test discoverable by pytest and compatible with CI distributed fixtures.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 150-155: The unpacking of create_scale_factor_tensor returns three
values but the middle outputs sfa_tensor and sfb_tensor are unused; rename them
to _sfa_tensor and _sfb_tensor (or _ for each) in the assignments that call
create_scale_factor_tensor so unused variables are prefixed with an underscore
while keeping sfa_ref, sfa_torch, sfb_ref, and sfb_torch unchanged (update the
two statements that assign sfa_ref, sfa_tensor, sfa_torch and sfb_ref,
sfb_tensor, sfb_torch).
- Around line 316-317: Add a short inline comment above the os._exit(0)
workaround explaining what breaks without it and referencing the linked PyTorch
issue: state that tests hang or cause worker processes to leak/segfault (or the
observed failure mode in your environment) when the issue occurs, and note that
os._exit(0) forcefully terminates the process to avoid the problem; include the
issue number (162429) and a one-line note to remove the workaround once the
PyTorch bug is fixed.
- Around line 111-122: The variables a_tensor and b_tensor returned from
cutlass_torch.cute_tensor_like are unused; rename them to indicate intentional
discard by prefixing with an underscore (e.g., _a_tensor and _b_tensor or just
_), leaving a_torch and b_torch unchanged so the call and types remain the same;
update the two unpacking statements around cutlass_torch.cute_tensor_like to use
the underscored names to satisfy static analysis.
- Around line 52-62: Replace the print(...) + return skip pattern with
pytest.skip(...) to integrate with pytest reporting: where the guard uses
is_cute_dsl_available() (first block) and the SM100 capability checks (the
blocks around the code at lines 82-85 and 89-92), call pytest.skip("message")
instead of printing and returning, and ensure pytest is imported at top of the
test module; update the messages to match the originals (e.g., "Please `pip
install nvidia-cutlass-dsl`" and "Cute-dsl backend is only supported on
SM100.").
- Line 5: Update the conflicting usage instruction in the module docstring: the
current line "USAGE: torchrun --nproc_per_node=4
test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" conflicts with the PR test
plan "mpirun -np 4 pytest", so either replace it to match the PR test plan or
document both valid ways (e.g., show both "torchrun --nproc_per_node=4
test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" and "mpirun -np 4 pytest
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py") and clarify
which is preferred; edit the docstring at the top of
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py to make the usage
consistent with the PR test plan or include both options.
- Around line 64-65: Rename the ambiguous variable `l` in the tuple unpacking
"l, m = lm" to a clearer name such as `batch_size` or `num_experts` and update
every subsequent occurrence of `l` in this test file/function to use the new
name; keep the unpacking ("batch_size, m = lm") and the other tuple unpacking
("k, n = kn") intact and run tests to ensure no remaining references to the old
`l` remain.
- Around line 287-317: The current __main__ torchrun-driven script should be
converted to a pytest-discoverable test: remove the if __name__ == "__main__"
block that manually reads LOCAL_RANK, sets device, calls
torch.distributed.init_process_group and os._exit, and instead create a pytest
test function (e.g., def
test_multi_gpu_blockscaled_gemm(distributed_rank_fixture):) that follows the
pattern from tests/comm/test_allreduce_unified_api.py; use
pytest.mark.parametrize for BATCH_SIZE values, use the distributed fixture to
set the CUDA device and ensure the process group is initialized, then call
test_blockscaled_gemm_python_interface(...) with the same parameters (preserve
ab_dtype, sf_dtype, c_dtype logic and tolerance selection), and remove the
os._exit workaround. This will make the test discoverable by pytest and
compatible with CI distributed fixtures.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 985de003-29cf-4fef-a980-053dc3b7b59a
📥 Commits
Reviewing files that changed from the base of the PR and between 75f858d580d549c24b862c3328c701ab4201cb99 and 0ec504a.
📒 Files selected for processing (2)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
- tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
/bot run
Hi, may I ask a question? Why is "is_swap_ab" needed? The comment says it's "needed for combine fusion where output is M-major." The MMA result in tensor memory is M-major, but I think it can be loaded to RF and then stored to shared memory via "stmatrix.xxx.trans". Then the result in shared memory is N-major and doesn't need "swap_ab". Thanks!
Good question. With Blackwell subtiling, which pipelines the memory copy out of TMEM to REG/SMEM/GMEM, the subtile shape is <128,32> for an MMA tile of <128,128>. That means we can only write 32 elements per token per subtile, which is not efficient for memory writes. With swap_ab, we can write 128 contiguous elements per token per subtile (256B for bf16).
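For concreteness, a quick back-of-envelope for the byte counts mentioned above (the mapping of tokens to subtile dimensions is paraphrased from this comment; this is an illustration, not kernel code):

```python
bytes_per_elt = 2                    # bf16
elts_per_token_without_swap = 32     # a <128,32> subtile exposes 32 elements per token
elts_per_token_with_swap_ab = 128    # with swap_ab a subtile exposes 128 contiguous elements per token

print(elts_per_token_without_swap * bytes_per_elt)   # 64 bytes per contiguous write
print(elts_per_token_with_swap_ab * bytes_per_elt)   # 256 bytes per contiguous write
```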
/bot run
Force-pushed from 31a1831 to 14ddb4e.
/bot run
Summary
Enables sgl-project/sglang#21877
- `cp.reduce.async.bulk` PTX instructions (`bf16` and `f32` variants); a reference sketch of the fused combine semantics follows this list
- `MaskedScheduler` with `is_swap_ab` support so the scheduler correctly computes tile coordinates when A/B inputs are swapped (needed for combine fusion where output is M-major)
- `barrier_flag_local` / `barrier_flag_multicast` parameters for cross-rank synchronization via spin-lock barriers
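For reference, the combine fusion amounts to a weighted scatter-add of each expert's GEMM output into the destination rows of the combined output. A rough PyTorch sketch of the semantics (tensor names and shapes are assumptions for illustration, not the kernel's API); the kernel performs the same accumulation with `cp.reduce.async.bulk` adds into the symmetric output buffers, synchronized by the barrier flags:

```python
import torch

def combine_reference(gemm_out, topk_weights, dst_row_idx, out):
    # gemm_out:     [num_experts, m, n]  per-expert GEMM results
    # topk_weights: [num_experts, m]     routing weight for each (expert, token)
    # dst_row_idx:  [num_experts, m]     destination row in the combined output
    # out:          [num_tokens, n]      combined output, reduce-add target
    num_experts, m, _ = gemm_out.shape
    for e in range(num_experts):
        for i in range(m):
            out[dst_row_idx[e, i]] += topk_weights[e, i] * gemm_out[e, i]
    return out
```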
Key changes

- `flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py`: New `cp_reduce_bf16_add` / `cp_reduce_f32_add` PTX wrappers, a custom `make_fused_smem_layout_epi` for the combine path, `is_swap_ab` logic in the scheduler, and new parameters (`topk_weights`, `idx_src_info`, `rank_src_info`, `out_ptrs`, barrier flags) threaded through the full stack
- `tests/comm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`: New multi-GPU test that validates the fused gemm+combine against a reference implementation using `mpirun`

Test plan
- `mpirun -np 4 pytest tests/comm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` on SM100+ hardware

Summary by CodeRabbit
New Features
Tests
Bug Fixes