
feat: Add CuTe DSL grouped-gemm + combine fusion support #2944

Open

nvcastet wants to merge 3 commits into flashinfer-ai:main from nvcastet:cleanup_add_moe_combine_v2

Conversation

@nvcastet

@nvcastet nvcastet commented Apr 1, 2026

Summary

Enables sgl-project/sglang#21877

  • Add a fused grouped-GEMM + combine operation to the CuTe DSL Blackwell masked GEMM kernel, enabling the GEMM epilogue to perform a weighted scatter-reduce (combine) directly into multi-rank output buffers using cp.reduce.async.bulk PTX instructions (bf16 and f32 variants); a reference sketch of the combine semantics follows this list
  • Extend MaskedScheduler with is_swap_ab support so the scheduler correctly computes tile coordinates when A/B inputs are swapped (needed for combine fusion where output is M-major)
  • Add barrier_flag_local/barrier_flag_multicast parameters for cross-rank synchronization via spin-lock barriers
  • Add multi-GPU test covering the fused grouped-gemm + combine path
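
A plain-PyTorch sketch of the combine semantics the fused epilogue implements (reference only: the parameter names come from this PR, but the shapes and the encoding of idx_src_info / rank_src_info are illustrative assumptions, not the kernel's actual layout):

def combine_reference(gemm_out, topk_weights, idx_src_info, rank_src_info, out_buffers):
    # gemm_out:      [num_experts, max_m, n] masked grouped-GEMM output
    # topk_weights:  [num_experts, max_m] per-token routing weight
    # idx_src_info:  [num_experts, max_m] destination row in the target rank's output
    # rank_src_info: [num_experts, max_m] destination rank (assumed -1 for padded rows)
    # out_buffers:   per-rank output tensors [tokens, n] backed by symmetric memory
    num_experts, max_m, _ = gemm_out.shape
    for e in range(num_experts):
        for t in range(max_m):
            dst_rank = int(rank_src_info[e, t])
            if dst_rank < 0:  # masked / padded row, nothing to combine
                continue
            dst_row = int(idx_src_info[e, t])
            # Weighted scatter-reduce; the kernel does this in the epilogue with
            # cp.reduce.async.bulk instead of a host-side read-modify-write.
            out_buffers[dst_rank][dst_row] += topk_weights[e, t] * gemm_out[e, t]
    return out_buffers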

Key changes

  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py: New cp_reduce_bf16_add/cp_reduce_f32_add PTX wrappers, custom make_fused_smem_layout_epi for the combine path, is_swap_ab logic in scheduler, and new parameters (topk_weights, idx_src_info, rank_src_info, out_ptrs, barrier flags) threaded through the full stack
  • tests/comm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py: New multi-GPU test that validates the fused gemm+combine against a reference implementation using mpirun

Test plan

  • mpirun -np 4 pytest tests/comm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py on SM100+ hardware
  • Verify existing grouped-gemm tests still pass (non-fusion path unchanged)

Summary by CodeRabbit

  • New Features

    • Block-scaled GEMM combine-fusion: fused weight prefetch, per-tile async reduction stores, and optional AB swap.
    • Host API extended with optional top-k weights, routing info, output-pointer inputs, rank/count, and barrier-flag controls for distributed fusion runs.
  • Tests

    • Added multi-GPU integration test validating the combine-fusion path, routing, and distributed synchronization.
  • Bug Fixes

    • Persistent scheduler sizing improved by capping active clusters to available tiles; stage computation updated for fusion prefetch buffers.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Added AB-swap and optional combine-fusion to the SM100 block-scaled persistent GEMM: new kernel knobs (num_ranks, is_combine_fusion, is_swap_ab), scheduler/grid remapping, fused‑SMEM epilogue with async SMEM→GMEM reduce and distributed barrier wiring, extended DSL/host signatures, and a new multi‑GPU test.

Changes

Block-scaled persistent GEMM (kernel, scheduler, host, DSL, epilogue, test). All changes are in flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py unless noted otherwise.

  • Data / params: Added is_swap_ab to MaskedSchedulerParams and capped the persistent cluster count by the total number of tile clusters.
  • Scheduler / grid: Remapped persistent work/tile coordinate linearization when is_swap_ab is true; _compute_grid now carries is_swap_ab.
  • SMEM layout / epilogue helpers: Added make_fused_smem_layout_epi and the inline-asm async reducers cp_reduce_bf16_add / cp_reduce_f32_add.
  • Kernel config / constructor: Sm100BlockScaledPersistentDenseGemmKernel.__init__ accepts num_ranks, is_combine_fusion, is_swap_ab.
  • Stage / layout sizing: _compute_stages accounts for the extra SMEM prefetch buffers when is_combine_fusion=True.
  • Kernel implementation: When is_combine_fusion is set, allocate extra SMEM prefetch buffers (out_ptrs, topk_weights), disable the TMA C-store path (tma_atom_c=None), S2R-load topk_weights, scale accumulators by the top-k weight (and alpha), and perform an async bulk SMEM→GMEM reduction-add into the computed outputs with the new reducers; TMA completion is replaced with an explicit distributed barrier/spin-lock (barrier_flag_local, barrier_flag_multicast).
  • Host / DSL wiring: Extended the DSL compiled-kernel factory, wrapper constructors/calls, and the returned tensor_api to accept topk_weights, idx_src_info, rank_src_info, out_ptrs, num_ranks, barrier_flag_local, barrier_flag_multicast, is_combine_fusion, is_swap_ab; the grouped_gemm_nt_masked signature is updated, supports AB swap (swaps lhs/rhs), and validates combine-fusion inputs.
  • Tests (tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py): New multi-GPU distributed test exercising combine fusion + AB swap: builds routing/top-k metadata, allocates symmetric outputs and barrier flags, runs the kernel across ranks with is_combine_fusion=True and is_swap_ab=True, and validates against a masked, weighted einsum reference.
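
To make the scheduler change concrete, a plain-Python sketch of the tile-coordinate remapping is shown below. The real scheduler is written in the CuTe DSL, and the exact linearization order is an assumption made for illustration.

def tile_coords(linear_tile_id, num_m_tiles, num_n_tiles, is_swap_ab):
    # Persistent scheduler: map a linear tile index to (m_tile, n_tile).
    if is_swap_ab:
        # With A/B swapped the output is M-major, so the fast-moving index
        # walks the M tiles to keep the combine epilogue's writes contiguous.
        m_tile = linear_tile_id % num_m_tiles
        n_tile = linear_tile_id // num_m_tiles
    else:
        n_tile = linear_tile_id % num_n_tiles
        m_tile = linear_tile_id // num_n_tiles
    return m_tile, n_tile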

Sequence Diagram

sequenceDiagram
    participant Host as Host/Launcher
    participant Rank as Rank Process
    participant GPU as GPU Kernel
    participant SymmMem as SymmetricMemory
    participant Barrier as Barrier/Flags

    Host ->> Rank: rendezvous & allocate symmetric outputs/flags
    Rank ->> Barrier: initial global barrier
    Rank ->> GPU: launch kernel (A,B,topk,idx,rank_info,out_ptrs,barrier_flags,is_combine_fusion,is_swap_ab)
    GPU ->> SymmMem: prefetch topk_weights / out_ptrs into SMEM (S2R)
    GPU ->> GPU: compute GEMM tiles (scheduler honours is_swap_ab)
    GPU ->> GPU: scale accumulators by topk weights (and alpha)
    GPU ->> SymmMem: cp.reduce.async.bulk...add (bulk SMEM → GMEM) using out_ptrs/idx/rank_info
    GPU ->> Barrier: bulk-wait + set/release barrier flags
    GPU -->> Rank: kernel completes
    Rank ->> Host: copy & validate symmetric output
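
A hedged host-side sketch of the launch flow in the diagram above (allocation and pointer-table setup only). The symmetric-memory calls mirror the pattern used in the new test, but torch.distributed._symmetric_memory is an internal API whose signatures vary across PyTorch versions, so treat the exact calls, sizes, and flag counts below as assumptions; the kernel invocation itself is omitted.

import os
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)

tokens, n = 1024, 4096  # illustrative sizes
# Output rows and barrier flags live in symmetric memory so any rank can
# reduce into them and spin on the flags.
out = symm_mem.empty((tokens, n), dtype=torch.bfloat16, device=device)
flags = symm_mem.empty((1024,), dtype=torch.int32, device=device)
out_hdl = symm_mem.rendezvous(out, dist.group.WORLD)
flags_hdl = symm_mem.rendezvous(flags, dist.group.WORLD)

# Table of every rank's output base pointer; the fused epilogue indexes it with
# rank_src_info to pick the destination buffer for cp.reduce.async.bulk.
out_ptrs = torch.tensor(out_hdl.buffer_ptrs, dtype=torch.int64, device=device)

dist.barrier()  # initial global barrier before launching the fused kernel
# The kernel launch would then pass out_ptrs, the flag buffers as
# barrier_flag_local / barrier_flag_multicast, plus topk_weights, idx_src_info,
# rank_src_info, num_ranks, is_combine_fusion=True, is_swap_ab=True.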

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

run-ci, cute-dsl

Suggested reviewers

  • dhiraj113
  • yzh119
  • bkryu
  • nv-yunzheq

Poem

🐰 I hopped through ranks and shared‑mem light,
Prefetched weights and fused them tight,
Async reduces hummed a cunning tune,
Flags blinked softly beneath the moon,
GEMM done — carrots for the night! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 44.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): The title accurately describes the main feature being added: CuTe DSL grouped-gemm + combine fusion support.
  • Description check (✅ Passed): The description is mostly complete, with a clear summary, key changes, and test plan. However, the test file path in the description differs from the raw summary (tests/comm vs tests/gemm), and the PR template checklist items are not marked as complete.
  • Linked Issues check (✅ Passed): Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (7)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (6)

52-62: Use flashinfer.utils.is_sm100a_supported() for architecture check.

Per coding guidelines, tests should use flashinfer utility functions for GPU capability checks. Replace the manual capability check with the standardized helper.

Proposed fix
+from flashinfer.utils import is_sm100a_supported
+
 def test_blockscaled_gemm_python_interface(
     ...
 ):
     if not is_cute_dsl_available():
         print("Skipping: Please `pip install nvidia-cutlass-dsl`")
         return
 
     torch.manual_seed(42)
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-    major, minor = torch.cuda.get_device_capability(device)
-
-    if not (major == 10 and minor == 0):
+    if not is_sm100a_supported(device):
         print("Skipping: Cute-dsl backend is only supported on SM100.")
         return

As per coding guidelines: Use flashinfer.utils functions (is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
52 - 62, Replace the manual SM version check using
torch.cuda.get_device_capability with the flashinfer helper: import and call
is_sm100a_supported() instead of evaluating major/minor; in the block after
setting device (and after is_cute_dsl_available()), replace the if not (major ==
10 and minor == 0): ... return with if not is_sm100a_supported():
print("Skipping: Cute-dsl backend is only supported on SM100.") return so the
test uses the standardized GPU capability check.

111-155: Prefix unused unpacked variables with underscore.

The variables a_tensor, b_tensor, sfa_tensor, and sfb_tensor are unpacked but never used. Prefix them with _ to indicate intentional non-use and silence linter warnings.

Proposed fix
-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
     ...
-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
111 - 155, The unpacked variables a_tensor, b_tensor, sfa_tensor and sfb_tensor
from calls to cutlass_torch.cute_tensor_like and create_scale_factor_tensor are
never used; rename them with a leading underscore (e.g., _a_tensor, _b_tensor,
_sfa_tensor, _sfb_tensor) to indicate intentional non-use and silence the
linter. Update the destructuring at the two cute_tensor_like calls and the two
create_scale_factor_tensor calls (references to a_tensor, b_tensor, sfa_tensor,
sfb_tensor) to the underscored names, leaving the used variables (a_torch,
b_torch, sfa_ref, sfa_torch, sfb_ref, sfb_torch) unchanged.

30-33: Consider documenting the purpose of global handle lists.

The global lists c_torch_handle_list and barrier_flag_local_handle_list prevent garbage collection of symmetric memory handles. Add a comment explaining this is necessary due to the PyTorch issue referenced at line 314.

Suggested documentation
-# WAR for https://github.com/pytorch/pytorch/issues/162429
+# WAR for https://github.com/pytorch/pytorch/issues/162429
+# Keep symmetric memory handles alive to prevent premature GC during test execution
 c_torch_handle_list = []
 barrier_flag_local_handle_list = []
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
30 - 33, Add a brief comment above the global lists c_torch_handle_list and
barrier_flag_local_handle_list explaining they are intentionally kept as
module-level globals to prevent garbage collection of symmetric memory handles
(workaround for PyTorch issue `#162429`) and reference the issue number (and the
related note at line 314) so future readers know why these handles must persist;
update the comment near the definitions of c_torch_handle_list and
barrier_flag_local_handle_list to clearly state their purpose and link to the
PyTorch issue.

64-65: Rename ambiguous variable l to batch_size or num_batches.

The variable name l is flagged as ambiguous (E741). In the context of batched GEMM, this represents the batch dimension, so a clearer name like batch_size or num_batches would improve readability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
64 - 65, The variable name "l" in the assignment "l, m = lm" is ambiguous;
rename it to a clear batch dimension name (e.g., batch_size or num_batches)
wherever it's defined and subsequently used in this test (search for "l"
references in this file, especially around the test function and any loops or
shape constructs) and update "l, m = lm" to "batch_size, m = lm" (or
"num_batches, m = lm") and replace all uses of "l" with the new name (including
any variable unpacking, shape assertions, and calls to the GEMM helpers) to
preserve behavior.

214-216: Document the purpose of the jitter injection.

The sleep injection on rank 1 is presumably to test race condition handling, but this should be documented. The hardcoded 1-second sleep could also cause test flakiness in CI environments with different timing characteristics.

Suggested improvement
         if my_rank == 1:
-            # Inject jitter to trigger potential race conditions
+            # Inject jitter on rank 1 to stress-test the distributed barrier
+            # synchronization. This ensures slower ranks are handled correctly
+            # by the spin-lock barrier mechanism.
             torch.cuda._sleep(1000000000)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
214 - 216, Add a clear inline comment above the torch.cuda._sleep call
explaining that the sleep simulates jitter to exercise potential GPU-side race
conditions on the process where my_rank == 1, and replace the hardcoded long 1s
sleep with a small, bounded, randomized jitter (e.g., a few micro- to
milliseconds) or controlled via an environment variable (e.g., TEST_JITTER_US)
so CI timing won’t be brittle; reference the current check (my_rank == 1) and
the call site (torch.cuda._sleep) when making the change.

287-315: Document the test execution pattern more clearly.

The test uses torchrun for distributed execution and os._exit(0) as a workaround. Consider adding a more detailed docstring explaining:

  1. Why this isn't a standard pytest test
  2. The expected invocation command
  3. What the workaround addresses
Suggested documentation improvement
 if __name__ == "__main__":
+    # This test requires distributed execution via torchrun.
+    # Standard pytest discovery is not used because multi-GPU tests
+    # need explicit process group initialization.
+    # Usage: torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.distributed.init_process_group(backend="nccl", device_id=device)
     ...
-    # WAR for https://github.com/pytorch/pytorch/issues/162429
+    # Workaround for https://github.com/pytorch/pytorch/issues/162429
+    # Using os._exit(0) to avoid hanging during process group cleanup.
     os._exit(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
287 - 315, Add a concise module-level docstring above the if __name__ ==
"__main__": block that explains this file is not a normal pytest test, that it
must be run under torchrun/torch.distributed (e.g., torchrun
--nproc_per_node=<N> python <file>) because it initializes a distributed process
group via torch.distributed.init_process_group and sets CUDA device per
LOCAL_RANK, and document that os._exit(0) is an explicit workaround for the
linked PyTorch issue to force clean termination after running
test_blockscaled_gemm_python_interface loops; mention expected invocation,
required env var LOCAL_RANK, and why pytest discovery is avoided.
flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py (1)

2001-2007: Rename iter to avoid shadowing Python builtin.

The variable iter shadows Python's builtin iter() function. Consider renaming to iter_idx or row_iter for clarity.

Proposed fix
-                            for iter in cutlass.range_constexpr(8):
-                                m_row = warp_idx * 8 + iter
+                            for row_iter in cutlass.range_constexpr(8):
+                                m_row = warp_idx * 8 + row_iter
                                 m_idx = (
                                     cur_tile_dim1_offset
                                     + subtile_idx * epi_tile[1].shape
                                     + m_row
                                 )
-                                out_ptr = reg_out_ptrs[iter]
+                                out_ptr = reg_out_ptrs[row_iter]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py` around lines 2001 -
2007, The loop variable named "iter" shadows the Python builtin iter(); rename
it (e.g., to "iter_idx" or "row_iter") in the for loop "for iter in
cutlass.range_constexpr(8):" and update all uses inside that loop (references to
iter -> iter_idx) including where m_row is computed and any downstream uses
(m_row, m_idx calculations that use iter). Ensure the new name is used
consistently within the surrounding scope (same block where warp_idx,
cur_tile_dim1_offset, subtile_idx, and epi_tile are referenced) to avoid
shadowing and maintain readability.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 279-284: The assert uses an extremely large scalar "tolerance"
(10000) in the torch.testing.assert_close call which can mask errors; replace
the hardcoded tolerance with dtype-aware tolerances (e.g., set atol/rtol based
on input dtype: tight values for torch.float32, looser for torch.bfloat16/half)
and/or compute tolerance relative to the magnitude of the reference tensors,
document why a nonstandard tolerance is needed, and apply the same change to the
other assert_close blocks referenced (the call at torch.testing.assert_close and
the similar checks around lines 297-313) so comparisons use appropriate
per-dtype atol/rtol instead of the global 10000 value.
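
A sketch of the dtype-aware tolerance pattern this comment asks for (the specific atol/rtol values below are illustrative choices, not taken from the PR):

import torch

def assert_close_by_dtype(actual, expected):
    # Looser bounds for low-precision outputs, tight ones for float32.
    atol, rtol = {
        torch.bfloat16: (1e-1, 1.6e-2),
        torch.float16: (1e-2, 1e-3),
        torch.float32: (1e-5, 1.3e-6),
    }[actual.dtype]
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)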


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 73b0c9b6-36d3-4812-8376-fa249cd45e5e

📥 Commits

Reviewing files that changed from the base of the PR and between 5a906be and 96145665808980e65d24dd8f81764dc72824d627.

📒 Files selected for processing (2)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
  • tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces 'combine fusion' and tensor swapping support to the Blackwell grouped GEMM kernel. Key additions include PTX wrappers for asynchronous bulk reduction (bf16 and f32), a fused epilogue that performs atomic reductions to global memory, and support for distributed execution parameters like top-k weights and barrier flags. A new multi-GPU test suite is also provided. Feedback focuses on a critical synchronization mismatch in the epilogue that could regress the non-fusion path and suggests improving documentation consistency for the new PTX wrapper functions.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py`:
- Around line 3010-3056: The current branch builds combine-related tensors
whenever cutlass.const_expr(self._num_ranks > 0) is true, which can dereference
None pointers in non-fusion runs; change the guard to require both num_ranks > 0
and that combine fusion is enabled (the is_combine_fusion flag) before
constructing topk_weights_tensor, idx_src_info_tensor, rank_src_info_tensor,
out_ptrs_tensor, barrier_flag_local_tensor and barrier_flag_multicast_tensor;
update the condition at the block starting with
cutlass.const_expr(self._num_ranks > 0) (and the analogous block at lines noted
3558-3578) to something like cutlass.const_expr(self._num_ranks > 0 and
self.is_combine_fusion) so tensors are only created when combine fusion is
active.
- Around line 1997-1999: The loop in grouped_gemm_masked_blackwell.py uses the
name `iter` which shadows Python's builtin iter() and fails lint; change the
loop variable in the for over cutlass.range_constexpr(8) (currently "for iter in
cutlass.range_constexpr(8):") to a non-builtin name such as `i`, `idx`, or
`offset` and update all uses inside that block (e.g., the computation of `m_row`
and `m_idx`) to the new variable name so logic remains identical (look for
`warp_idx * 8 + iter`, `m_row`, and `m_idx` references).
- Around line 1875-1892: The prefetch loop can read out-of-bounds because
prefetch_num_m may be 256 while only 128 threads run and topk_weights is indexed
before validating prefetch_m_idx; fix by capping prefetch_num_m to the actual
executing lane count (warp/epilogue threads) and by reordering/adding bounds
checks: compute prefetch_m_idx only when tidx < min(prefetch_num_m, 128), then
check prefetch_m_idx < tile_sched_params.masked_m[prefetch_l_idx] (and any
column bounds for topk_weights) before reading topk_weights[prefetch_l_idx,
prefetch_m_idx]; update smem_topk and smem_out_ptrs only inside that guarded
branch and set smem_out_ptrs[tidx] = Int64(0) otherwise. Apply these changes
around the symbols prefetch_num_m, prefetch_m_idx, tidx, topk_weights,
tile_sched_params.masked_m, smem_topk, and smem_out_ptrs.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Around line 58-62: Replace the direct torch.cuda.get_device_capability(device)
check with the repository helpers: import
flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, call get_compute_capability() (or directly
is_sm100a_supported()) to decide support, and if unsupported call
pytest.skip("Skipping: Cute-dsl backend is only supported on SM100.") instead of
printing and return; update the conditional around the existing test logic
(replacing torch.cuda.get_device_capability usage) to use is_sm100a_supported()
for gating.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 50b3ee75-55e2-44e1-a495-69b92a908e6c

📥 Commits

Reviewing files that changed from the base of the PR and between 96145665808980e65d24dd8f81764dc72824d627 and e8f3994.

📒 Files selected for processing (2)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
  • tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py

Comment on lines +58 to +62
major, minor = torch.cuda.get_device_capability(device)

if not (major == 10 and minor == 0):
print("Skipping: Cute-dsl backend is only supported on SM100.")
return


⚠️ Potential issue | 🟡 Minor

Use flashinfer.utils architecture helpers for test gating.

Line 58/60 uses torch.cuda.get_device_capability directly. Please switch to flashinfer.utils.get_compute_capability() / is_sm100a_supported() for skip logic to match repository test conventions.

As per coding guidelines: tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
58 - 62, Replace the direct torch.cuda.get_device_capability(device) check with
the repository helpers: import flashinfer.utils.get_compute_capability and
flashinfer.utils.is_sm100a_supported, call get_compute_capability() (or directly
is_sm100a_supported()) to decide support, and if unsupported call
pytest.skip("Skipping: Cute-dsl backend is only supported on SM100.") instead of
printing and return; update the conditional around the existing test logic
(replacing torch.cuda.get_device_capability usage) to use is_sm100a_supported()
for gating.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (4)

150-155: Prefix unused tuple elements with underscore.

Same issue with sfa_tensor and sfb_tensor.

-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
150 - 155, The tuple unpacking from create_scale_factor_tensor currently binds
unused elements to sfa_tensor and sfb_tensor; change the unpacking to prefix
unused tuple elements with an underscore (e.g., use sfa_ref, _sfa_tensor,
sfa_torch and sfb_ref, _sfb_tensor, sfb_torch) so unused variables are clearly
marked and linters won’t complain; update occurrences where sfa_tensor or
sfb_tensor are not used to the new underscore-prefixed names.

111-122: Prefix unused tuple elements with underscore.

Static analysis flags a_tensor and b_tensor as unused. Prefix with underscore to indicate intentional discard.

-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
111 - 122, The tuples returned by cutlass_torch.cute_tensor_like are assigned to
a_tensor and b_tensor but those variables are unused; rename the unused tuple
elements to start with an underscore (e.g., change a_tensor to _a_tensor and
b_tensor to _b_tensor) in the two assignments that call
cutlass_torch.cute_tensor_like so static analysis recognizes them as
intentionally discarded while keeping the used a_torch and b_torch names
unchanged.

291-315: Consider testing multiple output dtypes.

Currently c_dtype is hardcoded to "bfloat16". The tolerance logic at lines 311-313 suggests float32 was intended to be tested too. Consider parameterizing:

-    c_dtype = "bfloat16"
-    for BATCH_SIZE in [16, 64, 128, 256]:
+    for c_dtype in ["bfloat16", "float32"]:
+      for BATCH_SIZE in [16, 64, 128, 256]:

Or add a comment explaining why only bf16 is tested in CI.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
291 - 315, The test hardcodes c_dtype="bfloat16" but the tolerance branch in the
call to test_blockscaled_gemm_python_interface (the tolerance=... if c_dtype ==
"bfloat16" else 1e-01 expression) implies float32 should also be tested; update
the test to either loop over multiple output dtypes (e.g., iterate c_dtype in
["bfloat16","float32"] and call test_blockscaled_gemm_python_interface for each)
or, if only bf16 is intended, add a clarifying comment next to the c_dtype
assignment explaining why float32 is excluded from CI; adjust the tolerance
logic accordingly so it matches the chosen approach.

64-65: Consider renaming ambiguous variable l to num_groups or batch_size.

The variable l can be visually confused with 1 in some fonts. While it's a common convention in GEMM literature, using a more descriptive name improves readability.

-    l, m = lm
+    num_groups, m = lm
     k, n = kn

Then update all references to l throughout the function.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
64 - 65, The variable named `l` (from the unpacking line `l, m = lm`) is
ambiguous—rename it to a descriptive identifier such as `num_groups` (or
`batch_size`) by changing the unpacking to `num_groups, m = lm` and updating
every usage of `l` throughout the function in this file (including any
arithmetic, indexing, or logging) to `num_groups` so all references remain
consistent; ensure `k, n = kn` and other variables are unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py`:
- Line 191: The code calls torch.cuda.get_device_properties("cuda") but you
already derive an explicit device from LOCAL_RANK; change that call to use the
device variable instead (i.e., torch.cuda.get_device_properties(device)) so the
retrieved properties (used to set num_sms) come from the intended GPU; update
any references around num_sms to ensure they use the same device variable.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 529141f8-a8b0-4f00-9a29-7102105a9cfd

📥 Commits

Reviewing files that changed from the base of the PR and between e8f3994 and 75f858d580d549c24b862c3328c701ab4201cb99.

📒 Files selected for processing (2)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
  • tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py

Comment on line 191 (tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py):

c_torch_handle.buffer_ptrs, dtype=torch.int64, device=device
)

num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count


⚠️ Potential issue | 🟡 Minor

Use the device variable instead of hardcoded "cuda".

Line 191 uses "cuda" which defaults to the current device, but device was explicitly set from LOCAL_RANK. Use the explicit device for consistency.

-    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` at line 191,
The code calls torch.cuda.get_device_properties("cuda") but you already derive
an explicit device from LOCAL_RANK; change that call to use the device variable
instead (i.e., torch.cuda.get_device_properties(device)) so the retrieved
properties (used to set num_sms) come from the intended GPU; update any
references around num_sms to ensure they use the same device variable.

@nvcastet nvcastet force-pushed the cleanup_add_moe_combine_v2 branch from 75f858d to 0ec504a on April 8, 2026 at 19:00
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (7)
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py (7)

150-155: Prefix unused unpacked variables with underscore.

sfa_tensor and sfb_tensor are never used.

-    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
+    sfa_ref, _sfa_tensor, sfa_torch = create_scale_factor_tensor(
         l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
+    sfb_ref, _sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
150 - 155, The unpacking of create_scale_factor_tensor returns three values but
the middle outputs sfa_tensor and sfb_tensor are unused; rename them to
_sfa_tensor and _sfb_tensor (or _ for each) in the assignments that call
create_scale_factor_tensor so unused variables are prefixed with an underscore
while keeping sfa_ref, sfa_torch, sfb_ref, and sfb_torch unchanged (update the
two statements that assign sfa_ref, sfa_tensor, sfa_torch and sfb_ref,
sfb_tensor, sfb_torch).

316-317: Document the os._exit(0) workaround inline.

The comment references a PyTorch issue but doesn't explain what happens without this workaround. Consider adding a brief explanation for future maintainers.

     # WAR for https://github.com/pytorch/pytorch/issues/162429
+    # Prevents hang during atexit cleanup of symmetric memory handles
     os._exit(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
316 - 317, Add a short inline comment above the os._exit(0) workaround
explaining what breaks without it and referencing the linked PyTorch issue:
state that tests hang or cause worker processes to leak/segfault (or the
observed failure mode in your environment) when the issue occurs, and note that
os._exit(0) forcefully terminates the process to avoid the problem; include the
issue number (162429) and a one-line note to remove the workaround once the
PyTorch bug is fixed.

111-122: Prefix unused unpacked variables with underscore.

Static analysis indicates a_tensor and b_tensor are never used. Prefix with _ to indicate intentional discard.

-    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+    _a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+    _b_tensor, b_torch = cutlass_torch.cute_tensor_like(
         b_ref,
         get_cutlass_dtype(ab_dtype),
         is_dynamic_layout=True,
         assumed_align=16,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
111 - 122, The variables a_tensor and b_tensor returned from
cutlass_torch.cute_tensor_like are unused; rename them to indicate intentional
discard by prefixing with an underscore (e.g., _a_tensor and _b_tensor or just
_), leaving a_torch and b_torch unchanged so the call and types remain the same;
update the two unpacking statements around cutlass_torch.cute_tensor_like to use
the underscored names to satisfy static analysis.

52-62: Consider using pytest.skip() for proper test framework integration.

The print() + return pattern doesn't integrate with pytest's skip reporting. When tests are skipped, pytest won't show them in its summary. This also applies to lines 82-85 and 89-92.

+import pytest
+
 def test_blockscaled_gemm_python_interface(...):
     if not is_cute_dsl_available():
-        print("Skipping: Please `pip install nvidia-cutlass-dsl`")
-        return
+        pytest.skip("Please `pip install nvidia-cutlass-dsl`")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
52 - 62, Replace the print(...) + return skip pattern with pytest.skip(...) to
integrate with pytest reporting: where the guard uses is_cute_dsl_available()
(first block) and the SM100 capability checks (the blocks around the code at
lines 82-85 and 89-92), call pytest.skip("message") instead of printing and
returning, and ensure pytest is imported at top of the test module; update the
messages to match the originals (e.g., "Please `pip install nvidia-cutlass-dsl`"
and "Cute-dsl backend is only supported on SM100.").

5-5: Docstring usage instruction conflicts with PR test plan.

The docstring says torchrun --nproc_per_node=4, but the PR test plan specifies mpirun -np 4 pytest. Consider aligning these or documenting both options if both are valid.

-USAGE: torchrun --nproc_per_node=4 test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
+USAGE: torchrun --nproc_per_node=4 tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
+   or: mpirun -np 4 pytest tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` at line 5,
Update the conflicting usage instruction in the module docstring: the current
line "USAGE: torchrun --nproc_per_node=4
test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" conflicts with the PR test
plan "mpirun -np 4 pytest", so either replace it to match the PR test plan or
document both valid ways (e.g., show both "torchrun --nproc_per_node=4
test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py" and "mpirun -np 4 pytest
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py") and clarify
which is preferred; edit the docstring at the top of
tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py to make the usage
consistent with the PR test plan or include both options.

64-65: Rename ambiguous variable l to improve readability.

Static analysis flags l as ambiguous (E741) since it resembles 1. Consider renaming to batch_size or num_experts to clarify intent.

-    l, m = lm
+    num_experts, m = lm
     k, n = kn

This would require updating all subsequent uses of l in the function.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
64 - 65, Rename the ambiguous variable `l` in the tuple unpacking "l, m = lm" to
a clearer name such as `batch_size` or `num_experts` and update every subsequent
occurrence of `l` in this test file/function to use the new name; keep the
unpacking ("batch_size, m = lm") and the other tuple unpacking ("k, n = kn")
intact and run tests to ensure no remaining references to the old `l` remain.

287-317: Consider aligning with pytest pattern used in other multi-GPU tests.

The file is documented for torchrun --nproc_per_node=4 invocation (line 5), which is clear and correct. However, other multi-GPU tests in the repository (e.g., tests/comm/test_allreduce_unified_api.py) use pytest-discoverable test functions with distributed fixtures, allowing them to work with standard pytest invocation. For consistency and better CI integration, consider refactoring this test to follow the pytest pattern used elsewhere, or document the specific invocation method in the PR test plan if torchrun is the preferred approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py` around lines
287 - 317, The current __main__ torchrun-driven script should be converted to a
pytest-discoverable test: remove the if __name__ == "__main__" block that
manually reads LOCAL_RANK, sets device, calls
torch.distributed.init_process_group and os._exit, and instead create a pytest
test function (e.g., def
test_multi_gpu_blockscaled_gemm(distributed_rank_fixture):) that follows the
pattern from tests/comm/test_allreduce_unified_api.py; use
pytest.mark.parametrize for BATCH_SIZE values, use the distributed fixture to
set the CUDA device and ensure the process group is initialized, then call
test_blockscaled_gemm_python_interface(...) with the same parameters (preserve
ab_dtype, sf_dtype, c_dtype logic and tolerance selection), and remove the
os._exit workaround. This will make the test discoverable by pytest and
compatible with CI distributed fixtures.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 985de003-29cf-4fef-a980-053dc3b7b59a

📥 Commits

Reviewing files that changed from the base of the PR and between 75f858d580d549c24b862c3328c701ab4201cb99 and 0ec504a.

📒 Files selected for processing (2)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
  • tests/gemm/test_multi_gpu_cute_dsl_blockscaled_gemm_fusion.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py

@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !533 has been created, and the CI pipeline #48185130 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx self-assigned this Apr 17, 2026
@gyhintel

Hi, may I ask a question? Why is "is_swap_ab" needed? The comment says it's "needed for combine fusion where output is M-major." The MMA result in tensor memory is M-major, but I think it can be loaded to RF and then stored to shared memory via "stmatrix.xxx.trans"; the result in shared memory would then be N-major, so "swap_ab" wouldn't be needed. Thanks!

@nvcastet
Author

> Hi, may I ask a question? Why is "is_swap_ab" needed? The comment says it's "needed for combine fusion where output is M-major." The MMA result in tensor memory is M-major, but I think it can be loaded to RF and then stored to shared memory via "stmatrix.xxx.trans"; the result in shared memory would then be N-major, so "swap_ab" wouldn't be needed. Thanks!

Good question. With Blackwell sub-tiling pipelining the memory copy out of TMEM to REG/SMEM/GMEM, the subtile shape is <128, 32> for an MMA tile of <128, 128>, which means we can only write 32 elements per token per subtile; that is not efficient for memory writes. With swap_ab we can write 128 contiguous elements per token per subtile (256 B for bf16).
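
For reference, the arithmetic behind that answer (bf16 is 2 bytes per element; the tile and subtile sizes come from the comment above):

bytes_per_elt = 2                      # bf16
elts_per_token_no_swap = 32            # <128, 32> TMEM subtile of a <128, 128> MMA tile
elts_per_token_swap_ab = 128           # output is M-major, a full subtile row per token
print(elts_per_token_no_swap * bytes_per_elt,   # 64 B per token per subtile
      elts_per_token_swap_ab * bytes_per_elt)   # 256 B per token per subtile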

Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx aleozlx removed the run-ci label May 8, 2026
@aleozlx
Collaborator

aleozlx commented May 8, 2026

/bot run

@aleozlx aleozlx added the run-ci label May 8, 2026
@flashinfer-bot
Collaborator

GitLab MR !533 has been updated with latest changes, and the CI pipeline #50622730 is currently running. I'll report back once the pipeline job completes.

@nvcastet nvcastet force-pushed the cleanup_add_moe_combine_v2 branch from 31a1831 to 14ddb4e on May 11, 2026 at 21:57
@aleozlx
Collaborator

aleozlx commented May 11, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !533 has been updated with latest changes, and the CI pipeline #50968196 is currently running. I'll report back once the pipeline job completes.
