
feat: add specialized gemm kernel for sm121#3283

Open
nv-yunzheq wants to merge 2 commits into flashinfer-ai:main from nv-yunzheq:specialized-gemm-update

Conversation

@nv-yunzheq
Collaborator

@nv-yunzheq nv-yunzheq commented May 11, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • SM121-specialized FP4 and FP8 kernels for faster matrix ops on supported NVIDIA GPUs.
    • Runtime routing with environment control to enable/disable specialized kernels.
  • Improvements

    • More robust autotuning cache/key behavior for clearer config disambiguation.
    • Improved backend auto-detection resilience and packaging updates for CUDA assets.
  • Tests

    • Added CUDA/SM121-gated tests validating routing correctness and performance.
  • Tools

    • New CUDA benchmarking script to measure routing speedups and correctness.


@coderabbitai
Contributor

coderabbitai Bot commented May 11, 2026

📝 Walkthrough

Walkthrough

Adds SM121-specialized FP4 (mm_fp4) and FP8 (bmm_fp8) GEMM routing, multi-backend kernels (CUDA/CUTE-DSL/cuTile), env toggle, autotuner cache-key changes, AOT registration, tests, and a benchmarking tool.

Changes

SM121-Specialized GEMM Routing and Kernels

  • Environment & Configuration (flashinfer/env.py, flashinfer/utils.py, pyproject.toml): Adds FLASHINFER_SPECIALIZED_KERNEL_DISABLE with a cached reader/reset, expands backend probe exception handling to include RuntimeError, and registers JSON/CUDA package-data for specialized-kernel packages.
  • Autotuner Cache Updates (flashinfer/autotuner.py): Introduces helper file-key builders that include cache-key extras for persisted autotune configs, with a legacy 3-field fallback; updates save/load to use these keys.
  • Specialized Kernel Public API (flashinfer/gemm/specialized_kernels/__init__.py): Exports problem predicates and runners (is_mm_fp4_sm121_specialized_problem, run_mm_fp4_sm121_specialized, is_bmm_fp8_sm121_specialized_problem, run_bmm_fp8_sm121_specialized) and the CUDA module generator.
  • MM FP4 SM121 Router & CUTE DSL Kernel (flashinfer/gemm/specialized_kernels/mm_fp4_sm121/*): Adds the mm_fp4 SM121 router with a workload LUT, selection predicate, alpha handling, lazy-loaded CUTE-DSL NVFP4 kernels with multiple tiling/dispatch paths, and workload JSON.
  • BMM FP8 SM121 Router & Multi-Backend Kernels (flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/*): Adds the bmm_fp8 SM121 router with a workload LUT, split-K heuristic and workspace caching, a native CUDA binding+kernel (split-K/reduce and small-shape fallbacks), CUTE-DSL implementations, and a cuTile backend with swizzled/persistent dispatch; includes workload JSON and cached loaders.
  • GEMM Router Integration and AOT Compilation (flashinfer/gemm/gemm_base.py, flashinfer/aot.py): Integrates specialized runners into mm_fp4 and bmm_fp8 with early dispatch when applicable; extends fp8_gemm_sm100 to accept extra_runners, extra_runners_first, and custom_op; refactors FP8 algo cache keying; appends the SM121 CUDA module generator to AOT JIT when available.
  • Tests & Benchmarks (tests/gemm/test_specialized_gemm_routing.py, benchmarks/bench_specialized_gemm_routing.py): Adds SM121-gated tests comparing specialized vs baseline vs routing-disabled outputs with cosine-similarity assertions; adds a benchmark script recording median GPU ms, routing status, speedup/improvement, and correctness with optional autotune and CUDA graph/CUPTI flags.

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant is_bmm as is_bmm_fp8_sm121_specialized_problem
  participant select as _select_impl
  participant runSpecial as run_bmm_fp8_sm121_specialized
  participant cuda as cuda_kernel
  participant cute as cute_dsl_kernel
  participant cutile as cutile_kernel
  Caller->>is_bmm: check(A,B,scales,dtype,out)
  is_bmm->>select: validate shape/dtype/device and LUT
  select-->>is_bmm: impl
  is_bmm-->>Caller: specialized available
  Caller->>runSpecial: execute(A,B,scales,out)
  runSpecial->>select: get impl
  alt impl=="cuda"
    select-->>runSpecial: cuda
    runSpecial->>cuda: run(A,B,scales,out,workspace)
    cuda->>cuda: compute_splits & launch_fp8_gemm
    cuda-->>runSpecial: out (BF16)
  else impl=="cute_dsl"
    select-->>runSpecial: cute_dsl
    runSpecial->>cute: run(A,B,scales,out)
    cute->>cute: dispatch kernel variant
    cute-->>runSpecial: out (BF16)
  else impl=="cutile"
    select-->>runSpecial: cutile
    runSpecial->>cutile: run(A,B,scales,out)
    cutile->>cutile: build_dispatch & launch
    cutile-->>runSpecial: out (BF16)
  end
  runSpecial-->>Caller: output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

cute-dsl

Suggested reviewers

  • sricketts
  • dhiraj113
  • aleozlx
  • yzh119
  • cyx-6
  • samuellees
  • bkryu

Poem

🐰 I hopped through kernels, tiny to grand,
SM121 roads make matrix math stand,
FP4 and FP8 in a speedy race,
CUTE and cuTile find their place,
Benchmarks cheer — a fast, fluffy chase!

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check (⚠️ Warning): The PR description contains only the repository template, with unchecked checklist items and no actual implementation details, rationale, or test results. Resolution: fill in the Description section explaining what was implemented and why, complete the checklist items, and add any relevant notes for reviewers.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 5.95%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
  • Title check (✅ Passed): The title 'feat: add specialized gemm kernel for sm121' clearly and concisely describes the main change: the addition of specialized GEMM kernels for the SM121 architecture.
  • Linked Issues check (✅ Passed): Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Skipped because no linked issues were found for this pull request.




@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !658 has been created, and the CI pipeline #50886446 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

6090-6108: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid eagerly constructing the generic b12x runner for specialized SM121 problems.

When specialized_problem is true in tuning mode, this still instantiates _b12x_gemm_fp4_runner() before the specialized runner is appended. That reintroduces the CuTe DSL/CUTLASS import path that _b12x_gemm_fp4_requirement() intentionally bypasses, so autotune/profile can fail before the specialized kernel is reachable.

Suggested fix
-    runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends]
+    runners = []
+    for cur_backend in backends:
+        if specialized_problem and cur_backend == "b12x":
+            try:
+                _check_cute_dsl_availability()
+            except RuntimeError:
+                continue
+        runners.append(backend_to_runner_factory[cur_backend]())
+
     custom_op = "fp4_gemm"
     if specialized_problem:
         runners.append(_mm_fp4_sm121_specialized_runner())
         custom_op = "fp4_gemm_sm121_specialized"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/gemm_base.py` around lines 6090 - 6108, The code eagerly
constructs all runners by calling backend_to_runner_factory[...]() for every
backend, which causes _b12x_gemm_fp4_runner() to be instantiated even when
specialized_problem is true; change the construction to iterate backends and
call the factory only when appropriate (skip calling the "b12x" factory if
specialized_problem is true) so that _b12x_gemm_fp4_runner is not created
prematurely; use the existing backend_to_runner_factory mapping (and its lambda
values) and append _mm_fp4_sm121_specialized_runner() and set custom_op when
specialized_problem is true.
🧹 Nitpick comments (4)
flashinfer/utils.py (1)

1167-1173: ⚡ Quick win

Consider adding debug logging and narrowing exception scope.

The exception handling enables robust auto-backend selection by treating requirement-check failures as "backend unsuitable" rather than fatal errors. However, silently catching broad exceptions like RuntimeError and ValueError creates observability and debugging challenges:

  1. Observability gap: When auto-selection fails or produces unexpected results, there's no trace of which backends were attempted and why they were rejected.
  2. Broad exception handling: RuntimeError can indicate actual runtime failures (CUDA errors, resource exhaustion) that may deserve propagation rather than silent continuation.
  3. Masked bugs: Errors in requirement checker logic itself could be hidden rather than surfaced.
Suggested improvements

Option 1: Add debug logging

                except (ValueError, RuntimeError):
                    # In backend="auto", requirement functions are probed before
                    # compute-capability filtering. Optional backend dependency
                    # failures, such as CuTe DSL being unavailable, should only
                    # make that backend unsuitable and must not block later
                    # candidates.
+                   # Log at debug level for troubleshooting
+                   import logging
+                   logging.getLogger(__name__).debug(
+                       f"Backend '{backend}' skipped during auto-selection: {sys.exc_info()[1]}"
+                   )
                    continue

Option 2: Narrow exception handling (if specific exception types are known)

If the specialized kernels consistently raise specific exceptions for missing dependencies (e.g., ImportError, ModuleNotFoundError), consider catching those explicitly:

except (ValueError, ImportError, ModuleNotFoundError):
    # More specific exception types
    continue
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/utils.py` around lines 1167 - 1173, The except block that
currently reads "except (ValueError, RuntimeError): ... continue" silently
swallows failures during the backend="auto" requirement checks; change this to
catch only dependency/missing-module exceptions (e.g., except (ValueError,
ImportError, ModuleNotFoundError):) and add a debug-level log that records the
backend being probed and the caught exception before continuing. Locate the
try/except around the backend requirement check (the block that currently uses
"continue" on failure) and ensure you emit logger.debug(...) (using the module
logger or logging.getLogger(__name__)) with the backend identifier and exception
details, while allowing unexpected RuntimeError-like exceptions to propagate.
flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json (1)

1-191: ⚡ Quick win

Consider adding JSON schema validation.

To prevent configuration errors and ensure consistency, consider validating this file against a JSON schema during tests or at module load time. This would catch typos in field names, invalid dtype values, or missing required fields early.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json` around
lines 1 - 191, The workloads.json list lacks schema validation, which allows
typos or invalid values in fields like "b", "m", "k", "n", "out_dtype",
"backend", and "impl"; add a JSON Schema that defines required properties,
types, allowed enums for out_dtype/backend/impl, and ranges (e.g., positive
integers) and enforce it either at module load (where the file is parsed) or in
tests using a validator library (e.g., ajv for JS/TS or jsonschema for Python)
to fail fast on invalid entries.
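A stdlib-only version of that fail-fast check might look like this; the field names and allowed values are taken from the comment above and are illustrative, not the real FlashInfer schema:

```python
# Illustrative workload-entry validator; extend the allowed sets to match
# the actual schema before use.
ALLOWED_OUT_DTYPES = {"bfloat16", "float16"}
ALLOWED_IMPLS = {"cuda", "cute_dsl", "cutile"}
REQUIRED_INT_FIELDS = ("b", "m", "k", "n")


def validate_workload(entry: dict) -> None:
    for field in REQUIRED_INT_FIELDS:
        value = entry.get(field)
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{field!r} must be a positive int, got {value!r}")
    if entry.get("out_dtype") not in ALLOWED_OUT_DTYPES:
        raise ValueError(f"invalid out_dtype: {entry.get('out_dtype')!r}")
    if entry.get("impl") not in ALLOWED_IMPLS:
        raise ValueError(f"invalid impl: {entry.get('impl')!r}")
```

Running this over every entry at module load (or in a CI test) turns a silent LUT typo into an immediate, pointed error.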
flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json (1)

1-1091: ⚡ Quick win

Consider adding JSON schema validation.

Similar to the bmm_fp8_sm121 workloads file, adding schema validation would help catch configuration errors such as typos in field names, invalid dtype/backend values, or missing required fields. This could be done during tests or at module load time.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json` around lines
1 - 1091, The workloads.json lacks schema validation; add a JSON Schema (e.g.,
mm_fp4_sm121_schema.json) that requires fields
m,k,n,block_size,out_dtype,backend,use_nvfp4,impl and use_8x4_sf_layout,
constrains types and allowed values for out_dtype and backend, and then validate
the JSON at load/test time using a validator (jsonschema or similar) inside the
workload loader (e.g., load_workloads / parse_workloads) so invalid keys/typos
or missing fields are rejected during module load or in CI tests.
benchmarks/bench_specialized_gemm_routing.py (1)

26-26: ⚡ Quick win

Use the benchmark timer via the public testing entrypoint.

Line 26 imports bench_gpu_time from flashinfer.testing.utils; please switch to the flashinfer.testing entrypoint required by the benchmark guideline.

Proposed patch
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/bench_specialized_gemm_routing.py` at line 26, Replace the direct
import from flashinfer.testing.utils with the public testing entrypoint: change
the import so bench_gpu_time is imported from flashinfer.testing (i.e., use
"from flashinfer.testing import bench_gpu_time") to follow the benchmark
guideline and ensure the public API is used; update the existing import that
currently references flashinfer.testing.utils to reference flashinfer.testing
instead.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 6218-6250: The runner currently hardcodes the backend string
"cublas" in both get_valid_tactics and forward; update
_bmm_fp8_sm121_specialized_runner to accept (or capture) the caller's backend
and use that variable when calling is_bmm_fp8_sm121_specialized_problem and
run_bmm_fp8_sm121_specialized (instead of the literal "cublas") so the predicate
and execution use the same backend as the caller; apply the same change to the
analogous runner at the other location (the block around the second occurrence)
to ensure consistent backend propagation.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py`:
- Around line 127-146: The _cuda_workspace function currently keys the shared
buffers by (device, m, n, splits) causing cross-stream buffer reuse; modify it
to be stream-scoped by including the current CUDA stream in the cache key (e.g.,
torch.cuda.current_stream(device) or stream.cuda_stream) or, alternatively, skip
caching and allocate a fresh workspace when concurrent streams/splits are
possible (i.e., when splits > 1). Update the cache key usage for
_WORKSPACE_CACHE and the empty-cache behavior for _EMPTY_WORKSPACE_CACHE
accordingly so each CUDA stream gets its own tensor or a new tensor is returned
for split-K cases to avoid concurrent writes.
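The stream-scoped cache key can be modeled without CUDA; the allocator below is a stand-in (the real code allocates torch tensors and would derive `stream_id` from torch.cuda.current_stream()), but the shape of the fix is the same:

```python
# Workspace cache keyed by stream as well as shape, so two streams running
# split-K concurrently never share a scratch buffer.
_WORKSPACE_CACHE: dict = {}


def get_workspace(device, stream_id, m, n, splits, alloc=bytearray):
    required = splits * m * n if splits > 1 else 0
    key = (device, stream_id, m, n, splits)
    cached = _WORKSPACE_CACHE.get(key)
    if cached is None or len(cached) < required:
        cached = alloc(required)
        _WORKSPACE_CACHE[key] = cached
    return cached
```

Two distinct `stream_id` values yield distinct buffers, while repeated calls on the same stream reuse the cached one.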

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu`:
- Around line 122-139: Add a pre-dispatch guard that rejects shapes where K is
not a multiple of 16 to avoid kernel loads reading past row tails: before
creating the device guard / entering the batch loop (after computing splits and
workspace), check if (K % 16 != 0) and fail fast with a clear error (e.g. via
TVM_FFI_ICHECK or similar) indicating unsupported K for the fp8 kernel; keep
this check near compute_splits/required_workspace and mention launch_fp8_gemm in
the message so callers know which backend requires the constraint.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu`:
- Around line 270-280: The two-element BF16 stores using __nv_bfloat162 (writing
v01/v23 into Out_bf) can overrun when col_base == N-1 because they always write
two outputs; update the non-split-K epilogue to guard the second element the
same way as the split-K path: keep the existing check for the first element
(col_base < N) for each row (row0,row1) but add an additional condition col_base
+ 1 < N (or equivalent) before performing the __nv_bfloat162 write that writes
two adjacent outputs into Out_bf, or alternatively only write a single BF16
element when col_base == N-1 so the tail element is not overwritten (adjust the
code around v01/v23 and the reinterpret_cast to enforce this).

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/kernel.py`:
- Around line 1771-1905: The dispatch branches call launchers with weaker
divisibility checks than the launchers require; update each branch guard to
match the launcher's contract: for the SIMT1 path (symbols _SIMT1_V16_COMPILED /
_compile_simt1_v16) require K % 512 == 0 (use (K & 511) == 0) instead of only K
% 16; for both SIMT2 and SIMT8 paths (symbols _SIMT2_COMPILED / _compile_simt2
and _SIMT8_COMPILED / _compile_simt8) require N % 64 == 0 (use (N & 63) == 0)
not just N % 8; and for the tiny MMA fallback (symbols _MMA_T_COMPILED /
_compile_mma_tiny) add the missing N % 32 == 0 and K % 256 == 0 checks (use (N &
31) == 0 and (K & 255) == 0) before invoking the compiled kernel so kernels
never receive shapes that violate their tile/divisibility contracts.
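The tightened guards can be expressed as a plain-Python dispatch predicate; the variant names below stand in for the compiled-kernel symbols listed above:

```python
from typing import Optional


def pick_variant(N: int, K: int) -> Optional[str]:
    # Bitmask tests are equivalent to modulo checks for powers of two.
    if (K & 511) == 0:                      # SIMT1 path: K % 512 == 0
        return "simt1_v16"
    if (N & 63) == 0:                       # SIMT2/SIMT8 paths: N % 64 == 0
        return "simt2"
    if (N & 31) == 0 and (K & 255) == 0:    # tiny MMA fallback: N % 32, K % 256
        return "mma_tiny"
    return None                             # no launcher contract is satisfied
```

Returning None instead of falling through to a weaker check is the point: a shape that violates every launcher's contract should route to the baseline path rather than reach a specialized kernel.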

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py`:
- Around line 24-59: The kernel currently hard-codes batch 0 via
ct.load/ct.store index=(0, ...), so for A.shape[0] > 1 later batches are never
computed; update bmm_fp8_kernel (and the other kernels referenced) to read the
batch index from the kernel launch (e.g. use ct.bid(1) or an additional bid
variable instead of the literal 0), replace all occurrences of index=(0, bid_m,
k) and index=(0, k, bid_n) and the store index=(0, bid_m, bid_n) with
index=(batch_id, bid_m, k), index=(batch_id, k, bid_n), and index=(batch_id,
bid_m, bid_n) respectively, and ensure the host run()/launch code creates a
2D/3D grid that includes the batch dimension so the kernel’s batch_id maps
correctly; apply the same change to the other kernels mentioned (lines ~62-93,
96-131, 134-182, 285-327).
- Around line 12-21: The grouped swizzle in _swizzle_2d_from_bid computes bid_m
using bid % current_group_size_m which fails for a final partial M-group;
instead compute the CTA-local id within the group (local_id = bid - group_id *
num_bid_in_group) and derive bid_m = first_bid_m + (local_id %
current_group_size_m) and bid_n = local_id // current_group_size_m so the last
partial M-group maps tiles correctly; update the calculations in
_swizzle_2d_from_bid to use local_id and current_group_size_m as described.
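The corrected mapping reads as follows in plain Python; names mirror the comment above rather than the actual cuTile kernel:

```python
def swizzle_2d_from_bid(bid, num_bid_m, num_bid_n, group_size_m):
    # Map a linear CTA id to a (bid_m, bid_n) tile, grouping group_size_m
    # rows of tiles at a time. Deriving indices from the CTA-local id within
    # the group keeps the final partial M-group bijective.
    num_bid_in_group = group_size_m * num_bid_n
    group_id = bid // num_bid_in_group
    first_bid_m = group_id * group_size_m
    current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
    local_id = bid - group_id * num_bid_in_group
    bid_m = first_bid_m + (local_id % current_group_size_m)
    bid_n = local_id // current_group_size_m
    return bid_m, bid_n
```

A quick bijectivity check over a grid with a partial last group (e.g. 5 tile rows with group_size_m=2) confirms every tile is visited exactly once.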

In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py`:
- Around line 200-210: _prepare_alpha currently preserves caller device and
dtype which can yield invalid pointers when kernel expects a CUDA float32*;
change it to always return a CUDA float32 1-D tensor (device normalized via
_device_key) before exposing data_ptr. Concretely, when alpha is None, ensure
the cached value in _ALPHA_ONE_CACHE is torch.tensor([1.0], dtype=torch.float32,
device=device); when alpha is provided, move/convert it to device and dtype
torch.float32 (e.g., alpha.to(device=device, dtype=torch.float32)), then
normalize shape with unsqueeze(0) or reshape(1) as before; keep using the same
cache key (_ALPHA_ONE_CACHE) and keep function name _prepare_alpha so callers
(and cute_dsl.kernel.run) receive a CUDA float32 pointer.
- Around line 110-156: The current device checks in _select_impl only verify
tensors are CUDA but not that they all live on the same GPU as a, allowing
unsafe mixed-device launches; update the predicate to ensure b.device,
a_descale.device, b_descale.device (and out.device when out is provided) equal
a.device before accepting the specialized impl. Locate the selection logic in
function _select_impl (and the final out validation block) and add explicit
device-equality checks (e.g., compare .device or .get_device() values) for b,
a_descale, b_descale, and out to reject cases where any tensor is on a different
CUDA device than a. Ensure these checks run before returning the impl.

In `@tests/gemm/test_specialized_gemm_routing.py`:
- Around line 213-219: The test uses the builtin name "input" as a local
variable (seen around creation of the tensor and calls to to_float8 and
torch.bmm), which shadows Python's builtin and triggers lint error A001; rename
that variable (and its uses: input_fp8, input_inv_s, and the reference bmm call)
to a non-builtin name like input_tensor or inp throughout the block so
to_float8(input_tensor, ...) and torch.bmm(input_tensor, mat2) are used instead.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f1e302b2-beca-4279-b584-532ea49ba229

📥 Commits

Reviewing files that changed from the base of the PR and between 0a128d1 and 01dac0a.

📒 Files selected for processing (24)
  • benchmarks/bench_specialized_gemm_routing.py
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/env.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/gemm/specialized_kernels/__init__.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/__init__.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/__init__.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/__init__.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/kernel.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/__init__.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py
  • flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json
  • flashinfer/gemm/specialized_kernels/mm_fp4_sm121/__init__.py
  • flashinfer/gemm/specialized_kernels/mm_fp4_sm121/cute_dsl/__init__.py
  • flashinfer/gemm/specialized_kernels/mm_fp4_sm121/cute_dsl/kernel.py
  • flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py
  • flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json
  • flashinfer/utils.py
  • pyproject.toml
  • tests/gemm/test_specialized_gemm_routing.py

Comment on lines +6218 to +6250
def _bmm_fp8_sm121_specialized_runner():
    class BMMFp8Sm121SpecializedRunner(TunableRunner):
        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
        ) -> list:
            del profile
            A, B, A_scale, B_scale, out, _ = inputs
            if is_bmm_fp8_sm121_specialized_problem(
                A,
                B,
                A_scale,
                B_scale,
                out.dtype,
                out,
                "cublas",
            ):
                return [0]
            return []

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic: int = -1,
            do_preparation: bool = False,
            **kwargs,
        ):
            del tactic, do_preparation, kwargs
            A, B, A_scale, B_scale, out, _ = inputs
            return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, "cublas")

    return BMMFp8Sm121SpecializedRunner()

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Pass the selected backend through the SM121 BMM specialized runner.

specialized_problem is computed with the caller’s backend, but the extra runner re-checks and executes with hardcoded "cublas". If the specialized predicate or runtime dispatch uses that parameter to gate disabled variants or pick an implementation, autotune/runtime can diverge from the path that was actually approved here.

Suggested fix
-def _bmm_fp8_sm121_specialized_runner():
+def _bmm_fp8_sm121_specialized_runner(backend: str):
     class BMMFp8Sm121SpecializedRunner(TunableRunner):
         def get_valid_tactics(
             self,
             inputs: List[torch.Tensor],
             profile: OptimizationProfile,
@@
             if is_bmm_fp8_sm121_specialized_problem(
                 A,
                 B,
                 A_scale,
                 B_scale,
                 out.dtype,
                 out,
-                "cublas",
+                backend,
             ):
                 return [0]
             return []
@@
         ):
             del tactic, do_preparation, kwargs
             A, B, A_scale, B_scale, out, _ = inputs
-            return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, "cublas")
+            return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, backend)

     return BMMFp8Sm121SpecializedRunner()
@@
     extra_runners = None
     custom_op = "fp8_gemm"
     if specialized_problem:
-        extra_runners = [_bmm_fp8_sm121_specialized_runner()]
+        extra_runners = [_bmm_fp8_sm121_specialized_runner(backend)]
         custom_op = "fp8_gemm_sm121_specialized"

Also applies to: 6335-6358


Comment on lines +127 to +146
def _cuda_workspace(A: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
device = _device_key(A.device)
m = int(A.shape[1])
k = int(A.shape[2])
n = int(out.shape[2])
splits = _compute_splits(m, n, k)
required_numel = splits * m * n if splits > 1 else 0
if required_numel == 0:
cached = _EMPTY_WORKSPACE_CACHE.get(device)
if cached is None:
cached = torch.empty((0,), dtype=torch.float32, device=device)
_EMPTY_WORKSPACE_CACHE[device] = cached
return cached

key = (device, m, n, splits)
cached = _WORKSPACE_CACHE.get(key)
if cached is None or cached.numel() < required_numel:
cached = torch.empty((required_numel,), dtype=torch.float32, device=device)
_WORKSPACE_CACHE[key] = cached
return cached

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

# First, find the complete implementation of run_bmm_fp8_sm121_specialized
cat -n flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py

Repository: flashinfer-ai/flashinfer

Length of output: 9415


🏁 Script executed:

# Look for the CUDA binding and how workspace is passed
rg "run_bmm_fp8_sm121_specialized|_cuda_workspace" -A 10 flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/

Repository: flashinfer-ai/flashinfer

Length of output: 4132


🏁 Script executed:

# Check if there are any stream-related parameters in the function signatures
rg "def run_bmm_fp8_sm121_specialized|stream|Stream" -B 2 -A 5 flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/

Repository: flashinfer-ai/flashinfer

Length of output: 50380


🏁 Script executed:

# Look for the actual CUDA kernel binding to understand how workspace is used
fd -t f "*.py" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/ -exec grep -l "cuda_lib\|libflashinfer\|ctypes" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 578


🏁 Script executed:

cat -n flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2691


🏁 Script executed:

# Look for get_stream implementation and synchronization
rg "get_stream|cudaStreamSynchronize|cudaEventRecord" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/

Repository: flashinfer-ai/flashinfer

Length of output: 182


🏁 Script executed:

# Check if there are tests that use multiple streams
fd -t f "test" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/ -o fd -t f "test*.py" flashinfer/

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Look for how get_stream is defined in the codebase
rg "def get_stream|cudaStream_t.*get_stream" -B 2 -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 34899


🏁 Script executed:

# Get more context about get_stream behavior
cat -n csrc/tvm_ffi_utils.h | grep -A 3 "get_stream"

Repository: flashinfer-ai/flashinfer

Length of output: 243


🏁 Script executed:

# Look at the actual kernel launch with split-k to see workspace usage
rg "launch_kernel_splitk|workspace" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu -B 2 -A 8

Repository: flashinfer-ai/flashinfer

Length of output: 8654


🏁 Script executed:

# Check if there's any synchronization between split-k phases
rg "cudaStreamSynchronize|cudaEventSynchronize|cudaDeviceSynchronize" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/

Repository: flashinfer-ai/flashinfer

Length of output: 50


Stream-scoped workspace cache required for concurrent multi-stream execution.

The workspace cache key omits the stream, so concurrent operations on different CUDA streams with identical shapes (m, n, k) share the same buffer. When split-K is active (splits > 1), multiple streams write to the same workspace tensor simultaneously, corrupting partial results. Either include the stream in the cache key or allocate fresh workspace per invocation.
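A torch-free sketch of the stream-scoped keying (the buffer type and stream-id retrieval are simplified for clarity; in the real module the buffer would be a CUDA `torch.Tensor` and the id would come from `torch.cuda.current_stream(device).cuda_stream` — the function name here is illustrative, not from the PR):

```python
_WORKSPACE_CACHE = {}

def stream_scoped_workspace(device, stream_id, m, n, splits):
    # splits > 1 means split-K partial sums need splits * m * n float32 slots.
    required = splits * m * n if splits > 1 else 0
    key = (device, stream_id, m, n, splits)  # stream_id keeps streams apart
    buf = _WORKSPACE_CACHE.get(key)
    if buf is None or len(buf) < 4 * required:
        buf = bytearray(4 * required)  # 4 bytes per float32 partial sum
        _WORKSPACE_CACHE[key] = buf
    return buf
```

With the stream id in the key, two streams running the same (m, n, splits) shape never share a split-K buffer, while repeated launches on one stream still hit the cache.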

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py` around
lines 127 - 146, The _cuda_workspace function currently keys the shared buffers
by (device, m, n, splits) causing cross-stream buffer reuse; modify it to be
stream-scoped by including the current CUDA stream in the cache key (e.g.,
torch.cuda.current_stream(device) or stream.cuda_stream) or, alternatively, skip
caching and allocate a fresh workspace when concurrent streams/splits are
possible (i.e., when splits > 1). Update the cache key usage for
_WORKSPACE_CACHE and the empty-cache behavior for _EMPTY_WORKSPACE_CACHE
accordingly so each CUDA stream gets its own tensor or a new tensor is returned
for split-K cases to avoid concurrent writes.

Comment on lines +122 to +139
const int splits = compute_splits(M, N, K);
const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;
TVM_FFI_ICHECK_GE(workspace.numel(), required_workspace)
<< "workspace is too small for bmm_fp8 specialized kernel";

ffi::CUDADeviceGuard device_guard(A.device().device_id);
cudaStream_t stream = get_stream(A.device());
const int64_t A_batch_stride = static_cast<int64_t>(M) * K;
const int64_t B_batch_stride = static_cast<int64_t>(K) * N;
const int64_t O_batch_stride = static_cast<int64_t>(M) * N;
void* workspace_ptr = required_workspace > 0 ? workspace.data_ptr() : nullptr;

for (int b = 0; b < batch; ++b) {
const void* Ap = static_cast<const uint8_t*>(A.data_ptr()) + b * A_batch_stride;
const void* Bp = static_cast<const uint8_t*>(B.data_ptr()) + b * B_batch_stride;
void* Op = static_cast<uint8_t*>(out.data_ptr()) + b * O_batch_stride * sizeof(__nv_bfloat16);
launch_fp8_gemm(Ap, Bp, Op, A_scale.data_ptr(), B_scale.data_ptr(), M, N, K, workspace_ptr,
splits, stream);

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject unsupported K values before dispatching the CUDA kernels.

This binding accepts any K, but the backend only has 16-byte-chunk load paths (uint4 / cp.async_16). If K % 16 != 0, the last iteration can read past the row tail before launch_fp8_gemm ever gets a chance to recover. A cheap guard here avoids turning an unsupported shape into a device fault.

Suggested fix
   const int splits = compute_splits(M, N, K);
+  TVM_FFI_ICHECK_EQ(K % 16, 0)
+      << "bmm_fp8 SM121 specialized kernel requires K to be divisible by 16";
   const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu` around
lines 122 - 139, Add a pre-dispatch guard that rejects shapes where K is not a
multiple of 16 to avoid kernel loads reading past row tails: before creating the
device guard / entering the batch loop (after computing splits and workspace),
check if (K % 16 != 0) and fail fast with a clear error (e.g. via TVM_FFI_ICHECK
or similar) indicating unsupported K for the fp8 kernel; keep this check near
compute_splits/required_workspace and mention launch_fp8_gemm in the message so
callers know which backend requires the constraint.

Comment on lines +270 to +280
__nv_bfloat162 v01, v23;
v01.x = __float2bfloat16(acc[mf][nf][0] * scale);
v01.y = __float2bfloat16(acc[mf][nf][1] * scale);
v23.x = __float2bfloat16(acc[mf][nf][2] * scale);
v23.y = __float2bfloat16(acc[mf][nf][3] * scale);

if (row0 < M && col_base < N) {
*reinterpret_cast<__nv_bfloat162*>(&Out_bf[row0 * N + col_base]) = v01;
}
if (row1 < M && col_base < N) {
*reinterpret_cast<__nv_bfloat162*>(&Out_bf[row1 * N + col_base]) = v23;

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Guard the second BF16 write in the non-split-K epilogue.

Once col_base == N - 1, these __nv_bfloat162 stores still write two outputs and overrun the row tail by one element. The split-K epilogue above already handles col_base + 1 < N; this direct-write path needs the same tail logic.

Suggested fix
-        if (row0 < M && col_base < N) {
-          *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row0 * N + col_base]) = v01;
-        }
-        if (row1 < M && col_base < N) {
-          *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row1 * N + col_base]) = v23;
-        }
+        if (row0 < M && col_base < N) {
+          Out_bf[row0 * N + col_base] = v01.x;
+          if (col_base + 1 < N) Out_bf[row0 * N + col_base + 1] = v01.y;
+        }
+        if (row1 < M && col_base < N) {
+          Out_bf[row1 * N + col_base] = v23.x;
+          if (col_base + 1 < N) Out_bf[row1 * N + col_base + 1] = v23.y;
+        }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu` around
lines 270 - 280, The two-element BF16 stores using __nv_bfloat162 (writing
v01/v23 into Out_bf) can overrun when col_base == N-1 because they always write
two outputs; update the non-split-K epilogue to guard the second element the
same way as the split-K path: keep the existing check for the first element
(col_base < N) for each row (row0,row1) but add an additional condition col_base
+ 1 < N (or equivalent) before performing the __nv_bfloat162 write that writes
two adjacent outputs into Out_bf, or alternatively only write a single BF16
element when col_base == N-1 so the tail element is not overwritten (adjust the
code around v01/v23 and the reinterpret_cast to enforce this).

Comment on lines +1771 to +1905
# SIMT vec1+VEC=16 for M=1 (any N) or M<5 with M*N<=8192
if (K & 15) == 0 and (M == 1 or (M < 5 and M * N <= 8192)):
if _SIMT1_V16_COMPILED is None:
_SIMT1_V16_COMPILED = _compile_simt1_v16()
_SIMT1_V16_COMPILED(
_fp8_ptr(A),
_fp8_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# 2-row SIMT for M=5-15 with N>=1024 N<5120 (B reuse across 2 A rows)
if M >= 5 and M < 16 and N >= 1024 and N < 5120 and (N & 7) == 0:
if _SIMT2_COMPILED is None:
_SIMT2_COMPILED = _compile_simt2()
_SIMT2_COMPILED(
_fp8_ptr(A),
_fp8_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# 2-row SIMT also for M=16-56 with low N (TINY MMA underfills SMs at low N)
if M >= 16 and M <= 56 and N <= 2048 and (N & 7) == 0:
if _SIMT2_COMPILED is None:
_SIMT2_COMPILED = _compile_simt2()
_SIMT2_COMPILED(
_fp8_ptr(A),
_fp8_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# SIMT for M<5 (large N), or M=5-15 with N<5120 (fallback)
if M < 5 or (M < 16 and N < 5120):
if _SIMT8_COMPILED is None:
_SIMT8_COMPILED = _compile_simt8()
_SIMT8_COMPILED(
_fp8_ptr(A),
_fp8_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# Small-M MMA path.
if M >= 256 and M < 1024 and N <= 2048 and (N & (BN_S - 1)) == 0:
if _MMA_S_COMPILED is None:
_MMA_S_COMPILED = _compile_mma_small()
_MMA_S_COMPILED(
_u32_ptr(A),
_u32_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# Large MMA for M>=256 with N divisible by 128
if M >= 256 and (N & (BN_L - 1)) == 0:
if _MMA_L_COMPILED is None:
_MMA_L_COMPILED = _compile_mma_large()
_MMA_L_COMPILED(
_u32_ptr(A),
_u32_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# Med MMA for M=64-255 N>=16384 - 32x32 register tile
if M >= 64 and M < 256 and N >= 16384 and (N & 31) == 0:
if _MMA_M_COMPILED is None:
_MMA_M_COMPILED = _compile_mma_med()
_MMA_M_COMPILED(
_u32_ptr(A),
_u32_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

# Tiny MMA for M=5..255 - register-only, no barriers
if M >= 5 and M < 256:
if _MMA_T_COMPILED is None:
_MMA_T_COMPILED = _compile_mma_tiny()
_MMA_T_COMPILED(
_u32_ptr(A),
_u32_ptr(B),
_f32_ptr(A_scale),
_f32_ptr(B_scale),
_bf16_ptr(out),
M,
N,
K,
stream,
)
return

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Match each dispatch branch to the launcher's divisibility contract.

Several branches can currently call a launcher with weaker guards than that launcher declares. For example, _launch_simt1_v16 requires K % 512 == 0 but the branch only checks K % 16, _launch_simt2 / _launch_simt8 require N % 64 == 0 but only gate on N % 8, and the tiny MMA fallback never enforces _launch_mma_tiny’s N % 32 / K % 256 contract. That can route unsupported shapes into kernels that assume stricter tile boundaries.
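A sketch of the tightened eligibility predicates; the divisibility contracts (K % 512, N % 64, N % 32 with K % 256) are taken from this finding, and the helper names are illustrative rather than from the PR:

```python
def simt1_v16_eligible(M: int, N: int, K: int) -> bool:
    # _launch_simt1_v16 is stated to require K % 512 == 0, not just K % 16.
    return (K & 511) == 0 and (M == 1 or (M < 5 and M * N <= 8192))

def simt2_eligible_low_m(M: int, N: int) -> bool:
    # _launch_simt2 is stated to require N % 64 == 0, not just N % 8.
    return 5 <= M < 16 and 1024 <= N < 5120 and (N & 63) == 0

def mma_tiny_eligible(M: int, N: int, K: int) -> bool:
    # _launch_mma_tiny is stated to require N % 32 == 0 and K % 256 == 0.
    return 5 <= M < 256 and (N & 31) == 0 and (K & 255) == 0
```

Shapes rejected by these predicates would then fall through to the generic path instead of reaching a kernel whose tile assumptions they violate.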

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/kernel.py` around
lines 1771 - 1905, The dispatch branches call launchers with weaker divisibility
checks than the launchers require; update each branch guard to match the
launcher's contract: for the SIMT1 path (symbols _SIMT1_V16_COMPILED /
_compile_simt1_v16) require K % 512 == 0 (use (K & 511) == 0) instead of only K
% 16; for both SIMT2 and SIMT8 paths (symbols _SIMT2_COMPILED / _compile_simt2
and _SIMT8_COMPILED / _compile_simt8) require N % 64 == 0 (use (N & 63) == 0)
not just N % 8; and for the tiny MMA fallback (symbols _MMA_T_COMPILED /
_compile_mma_tiny) add the missing N % 32 == 0 and K % 256 == 0 checks (use (N &
31) == 0 and (K & 255) == 0) before invoking the compiled kernel so kernels
never receive shapes that violate their tile/divisibility contracts.

Comment on lines +12 to +21
def _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid):
num_bid_m = ct.cdiv(M, tm)
num_bid_n = ct.cdiv(N, tn)
num_bid_in_group = group_size_m * num_bid_n
group_id = bid // num_bid_in_group
first_bid_m = group_id * group_size_m
current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
bid_m = first_bid_m + (bid % current_group_size_m)
bid_n = (bid % num_bid_in_group) // current_group_size_m
return bid_m, bid_n

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Fix the grouped swizzle math for the final partial M-group.

bid_m needs to be derived from the CTA's local id inside the current group. Using bid % current_group_size_m only works when the tail-group height happens to divide group_size_m * num_bid_n; otherwise the last group gets duplicate/skipped tiles.

🛠️ Suggested fix
 def _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid):
     num_bid_m = ct.cdiv(M, tm)
     num_bid_n = ct.cdiv(N, tn)
     num_bid_in_group = group_size_m * num_bid_n
     group_id = bid // num_bid_in_group
     first_bid_m = group_id * group_size_m
     current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
-    bid_m = first_bid_m + (bid % current_group_size_m)
-    bid_n = (bid % num_bid_in_group) // current_group_size_m
+    local_bid = bid % num_bid_in_group
+    bid_m = first_bid_m + (local_bid % current_group_size_m)
+    bid_n = local_bid // current_group_size_m
     return bid_m, bid_n
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py` around
lines 12 - 21, The grouped swizzle in _swizzle_2d_from_bid computes bid_m using
bid % current_group_size_m which fails for a final partial M-group; instead
compute the CTA-local id within the group (local_id = bid - group_id *
num_bid_in_group) and derive bid_m = first_bid_m + (local_id %
current_group_size_m) and bid_n = local_id // current_group_size_m so the last
partial M-group maps tiles correctly; update the calculations in
_swizzle_2d_from_bid to use local_id and current_group_size_m as described.

Comment on lines +24 to +59
@ct.kernel
def bmm_fp8_kernel(
A,
B,
A_scale,
B_scale,
out,
tm: ConstInt,
tn: ConstInt,
tk: ConstInt,
num_tiles_k: ConstInt,
group_size_m: ConstInt,
):
bid = ct.bid(0)
M = A.shape[1]
N = B.shape[2]
sa = ct.load(A_scale, (0,), shape=(1,)).astype(ct.float32)
sb = ct.load(B_scale, (0,), shape=(1,)).astype(ct.float32)
scale = sa.item() * sb.item()
bid_m, bid_n = _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid)
accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32)
zero_pad = ct.PaddingMode.ZERO
for k in range(num_tiles_k):
a = ct.load(
A, index=(0, bid_m, k), shape=(1, tm, tk), padding_mode=zero_pad, latency=10
)
a = ct.reshape(a, (tm, tk))
b = ct.load(
B, index=(0, k, bid_n), shape=(1, tk, tn), padding_mode=zero_pad, latency=10
)
b = ct.reshape(b, (tk, tn))
accumulator = ct.mma(a, b, acc=accumulator)
accumulator = accumulator * scale
result = ct.astype(accumulator, ct.bfloat16)
result_3d = ct.reshape(result, (1, tm, tn))
ct.store(out, index=(0, bid_m, bid_n), tile=result_3d)

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

This backend silently computes only batch 0.

All four kernels hard-code batch index 0 in their ct.load/ct.store calls, and run() launches only a 1D grid over M/N tiles. For any A.shape[0] > 1, later batches are never produced, so the BMM result is wrong.

🛡️ Minimal safe guard until batched launch support exists
 def run(A, B, A_scale, B_scale, out):
+    if A.shape[0] != 1 or B.shape[0] != 1 or out.shape[0] != 1:
+        raise ValueError("cuTile bmm_fp8_sm121 currently supports batch size 1")
+
     M = A.shape[1]
     K = A.shape[2]
     N = B.shape[2]

Also applies to: 62-93, 96-131, 134-182, 285-327

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py` around
lines 24 - 59, The kernel currently hard-codes batch 0 via ct.load/ct.store
index=(0, ...), so for A.shape[0] > 1 later batches are never computed; update
bmm_fp8_kernel (and the other kernels referenced) to read the batch index from
the kernel launch (e.g. use ct.bid(1) or an additional bid variable instead of
the literal 0), replace all occurrences of index=(0, bid_m, k) and index=(0, k,
bid_n) and the store index=(0, bid_m, bid_n) with index=(batch_id, bid_m, k),
index=(batch_id, k, bid_n), and index=(batch_id, bid_m, bid_n) respectively, and
ensure the host run()/launch code creates a 2D/3D grid that includes the batch
dimension so the kernel’s batch_id maps correctly; apply the same change to the
other kernels mentioned (lines ~62-93, 96-131, 134-182, 285-327).

Comment on lines +110 to +156
if not (a.is_cuda and b.is_cuda and a_descale.is_cuda and b_descale.is_cuda):
return None
if torch.cuda.get_device_capability(a.device) != (12, 1):
return None
if a.dtype != torch.uint8 or b.dtype != torch.uint8:
return None
if a_descale.dtype != torch.uint8 or b_descale.dtype != torch.uint8:
return None
if a.ndim != 2 or b.ndim != 2 or a_descale.ndim != 2 or b_descale.ndim != 2:
return None

m = int(a.shape[0])
k = int(a.shape[1] * 2)
n = int(b.shape[1])
key = (
m,
k,
n,
block_size,
out_dtype,
use_8x4_sf_layout,
_normalize_backend(backend),
use_nvfp4,
)
impl = _WORKLOAD_LUT.get(key)
if impl is None or not _impl_available(impl):
return None
if tuple(b.shape) != (k // 2, n):
return None

sf_m = ((m + 127) // 128) * 128
sf_n = ((n + 127) // 128) * 128
if tuple(a_descale.shape) != (sf_m, k // block_size):
return None
if tuple(b_descale.shape) != (k // block_size, sf_n):
return None
if not a.is_contiguous() or not a_descale.is_contiguous():
return None
if not _is_column_major_view(b) or not _is_column_major_view(b_descale):
return None

if out is not None:
if out.dtype != torch.bfloat16 or tuple(out.shape) != (m, n):
return None
if not out.is_cuda or not out.is_contiguous():
return None


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject mixed-device tensors in _select_impl.

The predicate only checks is_cuda, so b, the descales, or out can still live on a different GPU than a. The specialized runner later uses a.device for stream selection and forwards every tensor's raw pointer into the same launch, which makes mixed-device inputs unsafe instead of ineligible.

Suggested fix
     if not (a.is_cuda and b.is_cuda and a_descale.is_cuda and b_descale.is_cuda):
         return None
+    devices = {a.device, b.device, a_descale.device, b_descale.device}
+    if out is not None:
+        devices.add(out.device)
+    if len(devices) != 1:
+        return None
     if torch.cuda.get_device_capability(a.device) != (12, 1):
         return None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py` around
lines 110 - 156, The current device checks in _select_impl only verify tensors
are CUDA but not that they all live on the same GPU as a, allowing unsafe
mixed-device launches; update the predicate to ensure b.device,
a_descale.device, b_descale.device (and out.device when out is provided) equal
a.device before accepting the specialized impl. Locate the selection logic in
function _select_impl (and the final out validation block) and add explicit
device-equality checks (e.g., compare .device or .get_device() values) for b,
a_descale, b_descale, and out to reject cases where any tensor is on a different
CUDA device than a. Ensure these checks run before returning the impl.

Comment on lines +200 to +210
def _prepare_alpha(alpha: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
device = _device_key(device)
if alpha is None:
cached = _ALPHA_ONE_CACHE.get(device)
if cached is None:
cached = torch.tensor([1.0], dtype=torch.float32, device=device)
_ALPHA_ONE_CACHE[device] = cached
return cached
if alpha.dim() == 0:
return alpha.unsqueeze(0)
return alpha.reshape(1)

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Normalize alpha onto CUDA float32 before exposing its pointer to the kernel.

_prepare_alpha() currently preserves the caller's device and dtype. If a user passes a CPU scalar or an fp16/fp64 tensor, cute_dsl.kernel.run() still builds a device float32* from that data_ptr(), which turns into an invalid pointer/type reinterpretation at launch time.

Suggested fix
 def _prepare_alpha(alpha: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
     device = _device_key(device)
     if alpha is None:
         cached = _ALPHA_ONE_CACHE.get(device)
         if cached is None:
             cached = torch.tensor([1.0], dtype=torch.float32, device=device)
             _ALPHA_ONE_CACHE[device] = cached
         return cached
-    if alpha.dim() == 0:
-        return alpha.unsqueeze(0)
-    return alpha.reshape(1)
+    alpha = alpha.to(device=device, dtype=torch.float32)
+    if alpha.dim() == 0:
+        alpha = alpha.unsqueeze(0)
+    else:
+        alpha = alpha.reshape(1)
+    return alpha.contiguous()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py` around
lines 200 - 210, _prepare_alpha currently preserves caller device and dtype
which can yield invalid pointers when kernel expects a CUDA float32*; change it
to always return a CUDA float32 1-D tensor (device normalized via _device_key)
before exposing data_ptr. Concretely, when alpha is None, ensure the cached
value in _ALPHA_ONE_CACHE is torch.tensor([1.0], dtype=torch.float32,
device=device); when alpha is provided, move/convert it to device and dtype
torch.float32 (e.g., alpha.to(device=device, dtype=torch.float32)), then
normalize shape with unsqueeze(0) or reshape(1) as before; keep using the same
cache key (_ALPHA_ONE_CACHE) and keep function name _prepare_alpha so callers
(and cute_dsl.kernel.run) receive a CUDA float32 pointer.

Comment on lines +213 to +219
input = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
mat2 = torch.randn((batch, n, k), device="cuda", dtype=torch.bfloat16).transpose(
-2, -1
)
mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=torch.float8_e4m3fn)
reference = torch.bmm(input, mat2)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid shadowing Python builtins in tests.

Line 213 uses input as a variable name, which triggers Ruff A001 and can fail lint-gated CI.

Proposed patch
-    input = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
-    input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
+    inp = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
+    input_fp8, input_inv_s = to_float8(inp, dtype=torch.float8_e4m3fn)
@@
-    reference = torch.bmm(input, mat2)
+    reference = torch.bmm(inp, mat2)
🧰 Tools
🪛 Ruff (0.15.12)

[error] 213-213: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gemm/test_specialized_gemm_routing.py` around lines 213 - 219, The test
uses the builtin name "input" as a local variable (seen around creation of the
tensor and calls to to_float8 and torch.bmm), which shadows Python's builtin and
triggers lint error A001; rename that variable (and its uses: input_fp8,
input_inv_s, and the reference bmm call) to a non-builtin name like input_tensor
or inp throughout the block so to_float8(input_tensor, ...) and
torch.bmm(input_tensor, mat2) are used instead.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces specialized GEMM kernels for SM121 architectures, supporting both FP4 and FP8 precisions across multiple backends including CUDA, CUTE DSL, and cuTile. The implementation includes a runtime routing mechanism that detects specific workloads and dispatches them to optimized kernels, along with an environment flag to disable this behavior for benchmarking. Key infrastructure updates include enhancements to the autotuner's caching logic for better disambiguation and the addition of comprehensive benchmarking and testing scripts. Review feedback highlighted the need for improved robustness in the benchmark script, specifically suggesting guards against division by zero and invalid logarithmic inputs in the speedup and geometric mean calculations.

disabled_cosine = cosine(reference, disabled_snapshot)
enabled_cosine = cosine(reference, enabled_snapshot)
routed_cosine = cosine(disabled_snapshot, enabled_snapshot)
speedup = disabled_ms / enabled_ms
Contributor


medium

The calculation for speedup could result in a ZeroDivisionError if enabled_ms is zero. While this may be unlikely for GPU timings, it would be safer to handle this edge case to prevent the benchmark from crashing.
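One way to address this (a hedged sketch, not code from the PR — `safe_speedup` is a hypothetical helper name) is to guard the division explicitly:

```python
def safe_speedup(disabled_ms: float, enabled_ms: float) -> float:
    """Ratio of baseline time to routed time; guards a zero denominator.

    A non-positive enabled_ms means the routed timing is degenerate, so
    report +inf (routed path "infinitely" faster) unless the baseline is
    also degenerate, in which case report a neutral 1.0.
    """
    if enabled_ms <= 0:
        return float("inf") if disabled_ms > 0 else 1.0
    return disabled_ms / enabled_ms


# Usage mirroring the benchmark line under review:
# speedup = safe_speedup(disabled_ms, enabled_ms)
```

The neutral-1.0 choice for the doubly degenerate case is one option; raising or skipping the data point would work equally well, depending on how the benchmark aggregates results.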

Comment on lines +257 to +258
def geomean(values: list[float]) -> float:
return math.exp(sum(math.log(v) for v in values) / len(values))
Contributor


medium

The geomean function may fail in two scenarios:

  1. If values contains a zero or negative number, math.log(v) will raise a ValueError. This could happen if disabled_ms is 0, leading to a speedup of 0.
  2. If values is an empty list, len(values) will be zero, causing a ZeroDivisionError.

Please consider adding checks for these cases to improve the script's robustness.
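A guarded version covering both failure modes might look like this (a sketch of one possible fix, not the PR's code; returning NaN for an empty or all-non-positive input is an assumption — the benchmark could equally raise or skip):

```python
import math


def geomean(values: list[float]) -> float:
    """Geometric mean with guards for empty lists and non-positive values.

    Non-positive entries are dropped before taking logs; if nothing
    remains, NaN is returned instead of raising.
    """
    positives = [v for v in values if v > 0]
    if not positives:
        return float("nan")
    return math.exp(sum(math.log(v) for v in positives) / len(positives))
```

Dropping non-positive speedups silently does skew the aggregate, so logging how many entries were discarded would be a reasonable addition in a real benchmark.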

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !658 has been updated with latest changes, and the CI pipeline #50936022 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/autotuner.py (1)

887-938: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Legacy file-key fallback block is unreachable.

_get_file_cache_key(cache_key) already returns the legacy 3-field key when cache_key[4] == () (see lines 1547–1548). So at line 890, file_key is identical to what _get_legacy_file_cache_key(cache_key) would produce in that branch. The block at lines 918–937 only runs when cache_key[4] == (), computes the same string again, and re-queries self._file_configs — it will never hit when line 895 missed.

In its current shape this code is dead: the "legacy config file" log path is unreachable, and no actual backward-compat lookup happens for runners that do return non-empty extras (which is presumably the case that motivated the new SM121-specialized runners). Two plausible fixes, depending on intent:

  • If the intent is to always match legacy 3-field cache files (including for newer runners that now emit extras), make the legacy lookup run when cache_key[4] != () instead, and have _get_file_cache_key always return the new 4-field format:
♻️ Option A: always use the 4-field key, fall back to legacy when extras present
     `@staticmethod`
     def _get_file_cache_key(cache_key: Tuple) -> str:
-        if cache_key[4] == ():
-            return AutoTuner._get_legacy_file_cache_key(cache_key)
         return str((cache_key[0], cache_key[1], cache_key[3], cache_key[4]))
-                # Preserve compatibility with older cache files only for
-                # runners that do not need extra key material. Reusing a
-                # legacy key when extras are non-empty can apply a tactic to a
-                # shape or dtype it was never profiled for.
-                if cache_key[4] == ():
+                # Preserve compatibility with older cache files written before
+                # the `extras` field existed. Only safe when the runner has no
+                # extras, since legacy keys cannot disambiguate dtype/etc.
+                if cache_key[4] == ():
                     legacy_file_key = AutoTuner._get_legacy_file_cache_key(cache_key)
                     if legacy_file_key in self._file_configs:
                         ...

Note: this would also require updating save_configs so newly written files always use the 4-field format (which would then need migration semantics for already-saved 3-field files).

  • If the intent is that legacy 3-field files should only be honored when the runner has no extras (and the current _get_file_cache_key behavior is correct for write compatibility), then lines 914–937 are simply redundant and should be removed:
♻️ Option B: drop the dead legacy fallback block
-                # Preserve compatibility with older cache files only for
-                # runners that do not need extra key material. Reusing a
-                # legacy key when extras are non-empty can apply a tactic to a
-                # shape or dtype it was never profiled for.
-                if cache_key[4] == ():
-                    legacy_file_key = AutoTuner._get_legacy_file_cache_key(cache_key)
-                    if legacy_file_key in self._file_configs:
-                        runner_name, tactic = self._file_configs[legacy_file_key]
-                        runner_id = next(
-                            (
-                                i
-                                for i, runner in enumerate(runners)
-                                if runner.__class__.__name__ == runner_name
-                            ),
-                            0,
-                        )
-                        log_key = (custom_op, runner_name)
-                        if log_key not in self._logged_file_hits:
-                            self._logged_file_hits.add(log_key)
-                            logger.info(
-                                f"[Autotuner]: Config cache hit for {custom_op} "
-                                f"(runner={runner_name}, source=legacy config file)"
-                            )
-                        return True, runner_id, tactic, None
-

Worth confirming the intent before picking one; the resulting behavior for cross-version cache files differs materially.
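The unreachability argument can be demonstrated with a minimal standalone sketch (hypothetical simplifications of the two key functions described above, not the actual flashinfer code): when extras are empty, the primary lookup already used the legacy key, so a miss there implies the fallback lookup must miss too.

```python
def legacy_file_key(cache_key):
    # Legacy 3-field key: (custom_op, runner_class, profile).
    return str((cache_key[0], cache_key[1], cache_key[3]))


def file_cache_key(cache_key):
    # Current behavior: fall back to the legacy key when extras are empty.
    if cache_key[4] == ():
        return legacy_file_key(cache_key)
    return str((cache_key[0], cache_key[1], cache_key[3], cache_key[4]))


def lookup(file_configs, cache_key):
    # Primary lookup.
    key = file_cache_key(cache_key)
    if key in file_configs:
        return file_configs[key]
    # "Legacy fallback": only taken when extras == (), but then
    # key == legacy_file_key(cache_key), so this second lookup can
    # never succeed if the one above missed.
    if cache_key[4] == ():
        assert file_cache_key(cache_key) == legacy_file_key(cache_key)
        return file_configs.get(legacy_file_key(cache_key))  # always None here
    return None


cache_key = ("mm_fp4", "Sm121Runner", "hash", ("profile",), ())
# Primary hit is served by the first lookup...
configs = {legacy_file_key(cache_key): ("Sm121Runner", 3)}
assert lookup(configs, cache_key) == ("Sm121Runner", 3)
# ...and a primary miss means the fallback also misses.
assert lookup({}, cache_key) is None
```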

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/autotuner.py` around lines 887 - 938, The legacy-file fallback
block is unreachable because AutoTuner._get_file_cache_key(cache_key) already
returns the 3-field legacy key when cache_key[4] == (), so the subsequent legacy
lookup using _get_legacy_file_cache_key never adds new matches; either remove
the redundant block (delete the branch that computes legacy_file_key and its
logging/return) if legacy keys should only be honored when extras are empty, or
change the condition to run the legacy lookup when cache_key[4] != () so legacy
3-field configs are matched for runners that now emit extras (and ensure
save_configs consistently writes the chosen format); update references to
file_key, cache_key, AutoTuner._get_file_cache_key,
AutoTuner._get_legacy_file_cache_key, self._file_configs and
self._logged_file_hits accordingly.
🧹 Nitpick comments (1)
flashinfer/autotuner.py (1)

1650-1651: 💤 Low value

Prefix unused unpacked names with _.

Ruff flags custom_op, profile, extras (line 1650) and runner_id (line 1651) as unused. _get_file_cache_key consumes the whole cache_key tuple directly, and only runner_class_name/tactic are read.

♻️ Suggested rename
-                custom_op, runner_class_name, _runner_hash, profile, extras = cache_key
-                runner_id, tactic, _opt_profile = cache_value
+                _custom_op, runner_class_name, _runner_hash, _profile, _extras = cache_key
+                _runner_id, tactic, _opt_profile = cache_value

Or, since the helper already takes the full tuple, drop the unpack entirely and use cache_value[1] for tactic.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/autotuner.py` around lines 1650 - 1651, The unpacking of
cache_key/cache_value exposes unused names; either prefix the unused variables
with an underscore (e.g., _custom_op, _runner_hash, _profile, _extras,
_runner_id) or remove the unpack and index the tuple directly (use cache_key and
cache_value[1] for tactic) in the loop where _get_file_cache_key is called;
update the variables in that block (runner_class_name, tactic) so only the used
names remain to satisfy Ruff.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@flashinfer/autotuner.py`:
- Around line 887-938: The legacy-file fallback block is unreachable because
AutoTuner._get_file_cache_key(cache_key) already returns the 3-field legacy key
when cache_key[4] == (), so the subsequent legacy lookup using
_get_legacy_file_cache_key never adds new matches; either remove the redundant
block (delete the branch that computes legacy_file_key and its logging/return)
if legacy keys should only be honored when extras are empty, or change the
condition to run the legacy lookup when cache_key[4] != () so legacy 3-field
configs are matched for runners that now emit extras (and ensure save_configs
consistently writes the chosen format); update references to file_key,
cache_key, AutoTuner._get_file_cache_key, AutoTuner._get_legacy_file_cache_key,
self._file_configs and self._logged_file_hits accordingly.

---

Nitpick comments:
In `@flashinfer/autotuner.py`:
- Around line 1650-1651: The unpacking of cache_key/cache_value exposes unused
names; either prefix the unused variables with an underscore (e.g., _custom_op,
_runner_hash, _profile, _extras, _runner_id) or remove the unpack and index the
tuple directly (use cache_key and cache_value[1] for tactic) in the loop where
_get_file_cache_key is called; update the variables in that block
(runner_class_name, tactic) so only the used names remain to satisfy Ruff.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a4d302ca-02c5-4218-81c7-cfdec9921290

📥 Commits

Reviewing files that changed from the base of the PR and between 01dac0a and e98dfcf.

📒 Files selected for processing (1)
  • flashinfer/autotuner.py

Member

@kahyunnam kahyunnam left a comment


lgtm pending passing CI tests and vLLM integration validation

@flashinfer-bot
Collaborator

[FAILED] Pipeline #50936022: 12/20 passed

@aleozlx
Collaborator

aleozlx commented May 11, 2026

seems clean? ready for auto-merge?

@aleozlx aleozlx removed the v0.6.12 label May 11, 2026
@aleozlx aleozlx added the v0.6.11 release blocker label for 0.6.11 label May 11, 2026
@aleozlx
Collaborator

aleozlx commented May 11, 2026

swapped labels for 0.6.11.post1


Labels

arch: DGX Spark op: gemm run-ci v0.6.11 release blocker label for 0.6.11

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants