feat: add specialized gemm kernel for sm121 #3283
Conversation
📝 Walkthrough

Adds SM121-specialized FP4 (mm_fp4) and FP8 (bmm_fp8) GEMM routing, multi-backend kernels (CUDA/CUTE-DSL/cuTile), an env toggle, autotuner cache-key changes, AOT registration, tests, and a benchmarking tool.

Changes: SM121-Specialized GEMM Routing and Kernels
Sequence Diagram(s)

sequenceDiagram
participant Caller
participant is_bmm as is_bmm_fp8_sm121_specialized_problem
participant select as _select_impl
participant runSpecial as run_bmm_fp8_sm121_specialized
participant cuda as cuda_kernel
participant cute as cute_dsl_kernel
participant cutile as cutile_kernel
Caller->>is_bmm: check(A,B,scales,dtype,out)
is_bmm->>select: validate shape/dtype/device and LUT
select-->>is_bmm: impl
is_bmm-->>Caller: specialized available
Caller->>runSpecial: execute(A,B,scales,out)
runSpecial->>select: get impl
alt impl=="cuda"
select-->>runSpecial: cuda
runSpecial->>cuda: run(A,B,scales,out,workspace)
cuda->>cuda: compute_splits & launch_fp8_gemm
cuda-->>runSpecial: out (BF16)
else impl=="cute_dsl"
select-->>runSpecial: cute_dsl
runSpecial->>cute: run(A,B,scales,out)
cute->>cute: dispatch kernel variant
cute-->>runSpecial: out (BF16)
else impl=="cutile"
select-->>runSpecial: cutile
runSpecial->>cutile: run(A,B,scales,out)
cutile->>cutile: build_dispatch & launch
cutile-->>runSpecial: out (BF16)
end
runSpecial-->>Caller: output
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (2 warnings)
/bot run
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
6090-6108: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid eagerly constructing the generic `b12x` runner for specialized SM121 problems.

When `specialized_problem` is true in tuning mode, this still instantiates `_b12x_gemm_fp4_runner()` before the specialized runner is appended. That reintroduces the CuTe DSL/CUTLASS import path that `_b12x_gemm_fp4_requirement()` intentionally bypasses, so autotune/profile can fail before the specialized kernel is reachable.

Suggested fix

```diff
-    runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends]
+    runners = []
+    for cur_backend in backends:
+        if specialized_problem and cur_backend == "b12x":
+            try:
+                _check_cute_dsl_availability()
+            except RuntimeError:
+                continue
+        runners.append(backend_to_runner_factory[cur_backend]())
+
     custom_op = "fp4_gemm"
     if specialized_problem:
         runners.append(_mm_fp4_sm121_specialized_runner())
         custom_op = "fp4_gemm_sm121_specialized"
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/gemm_base.py` around lines 6090 - 6108, The code eagerly constructs all runners by calling backend_to_runner_factory[...]() for every backend, which causes _b12x_gemm_fp4_runner() to be instantiated even when specialized_problem is true; change the construction to iterate backends and call the factory only when appropriate (skip calling the "b12x" factory if specialized_problem is true) so that _b12x_gemm_fp4_runner is not created prematurely; use the existing backend_to_runner_factory mapping (and its lambda values) and append _mm_fp4_sm121_specialized_runner() and set custom_op when specialized_problem is true.
🧹 Nitpick comments (4)
flashinfer/utils.py (1)
1167-1173: ⚡ Quick win

Consider adding debug logging and narrowing exception scope.

The exception handling enables robust auto-backend selection by treating requirement-check failures as "backend unsuitable" rather than fatal errors. However, silently catching broad exceptions like `RuntimeError` and `ValueError` creates observability and debugging challenges:

- Observability gap: When auto-selection fails or produces unexpected results, there's no trace of which backends were attempted and why they were rejected.
- Broad exception handling: `RuntimeError` can indicate actual runtime failures (CUDA errors, resource exhaustion) that may deserve propagation rather than silent continuation.
- Masked bugs: Errors in requirement checker logic itself could be hidden rather than surfaced.

Suggested improvements

Option 1: Add debug logging

```diff
         except (ValueError, RuntimeError):
             # In backend="auto", requirement functions are probed before
             # compute-capability filtering. Optional backend dependency
             # failures, such as CuTe DSL being unavailable, should only
             # make that backend unsuitable and must not block later
             # candidates.
+            # Log at debug level for troubleshooting
+            import logging
+            logging.getLogger(__name__).debug(
+                f"Backend '{backend}' skipped during auto-selection: {sys.exc_info()[1]}"
+            )
             continue
```

Option 2: Narrow exception handling (if specific exception types are known)

If the specialized kernels consistently raise specific exceptions for missing dependencies (e.g., `ImportError`, `ModuleNotFoundError`), consider catching those explicitly:

```python
except (ValueError, ImportError, ModuleNotFoundError):  # More specific exception types
    continue
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/utils.py` around lines 1167 - 1173, The except block that currently reads "except (ValueError, RuntimeError): ... continue" silently swallows failures during the backend="auto" requirement checks; change this to catch only dependency/missing-module exceptions (e.g., except (ValueError, ImportError, ModuleNotFoundError):) and add a debug-level log that records the backend being probed and the caught exception before continuing. Locate the try/except around the backend requirement check (the block that currently uses "continue" on failure) and ensure you emit logger.debug(...) (using the module logger or logging.getLogger(__name__)) with the backend identifier and exception details, while allowing unexpected RuntimeError-like exceptions to propagate.

flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json (1)
1-191: ⚡ Quick win

Consider adding JSON schema validation.
To prevent configuration errors and ensure consistency, consider validating this file against a JSON schema during tests or at module load time. This would catch typos in field names, invalid dtype values, or missing required fields early.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json` around lines 1 - 191, The workloads.json list lacks schema validation, which allows typos or invalid values in fields like "b", "m", "k", "n", "out_dtype", "backend", and "impl"; add a JSON Schema that defines required properties, types, allowed enums for out_dtype/backend/impl, and ranges (e.g., positive integers) and enforce it either at module load (where the file is parsed) or in tests using a validator library (e.g., ajv for JS/TS or jsonschema for Python) to fail fast on invalid entries.

flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json (1)
1-1091: ⚡ Quick win

Consider adding JSON schema validation.
Similar to the bmm_fp8_sm121 workloads file, adding schema validation would help catch configuration errors such as typos in field names, invalid dtype/backend values, or missing required fields. This could be done during tests or at module load time.
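A minimal sketch of what such load-time validation could look like for either workloads file, using the jsonschema package. The loader name and the exact schema below are illustrative assumptions, not the repository's code; the required field names are taken from the two review comments above.

```python
import json
from jsonschema import validate

# Hypothetical schema; extend "required" with file-specific fields such as "b",
# "block_size", "use_nvfp4", or "use_8x4_sf_layout" as appropriate.
WORKLOAD_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "required": ["m", "k", "n", "out_dtype", "backend", "impl"],
        "properties": {
            "m": {"type": "integer", "minimum": 1},
            "k": {"type": "integer", "minimum": 1},
            "n": {"type": "integer", "minimum": 1},
            "out_dtype": {"type": "string"},
            "backend": {"type": "string"},
            "impl": {"enum": ["cuda", "cute_dsl", "cutile"]},
        },
    },
}


def load_workloads(path: str) -> list:
    # Parse the workloads file and fail fast on malformed entries.
    with open(path) as f:
        workloads = json.load(f)
    validate(instance=workloads, schema=WORKLOAD_SCHEMA)  # raises ValidationError
    return workloads
```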
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json` around lines 1 - 1091, The workloads.json lacks schema validation; add a JSON Schema (e.g., mm_fp4_sm121_schema.json) that requires fields m,k,n,block_size,out_dtype,backend,use_nvfp4,impl and use_8x4_sf_layout, constrains types and allowed values for out_dtype and backend, and then validate the JSON at load/test time using a validator (jsonschema or similar) inside the workload loader (e.g., load_workloads / parse_workloads) so invalid keys/typos or missing fields are rejected during module load or in CI tests.

benchmarks/bench_specialized_gemm_routing.py (1)
26-26: ⚡ Quick win

Use the benchmark timer via the public testing entrypoint.
Line 26 imports `bench_gpu_time` from `flashinfer.testing.utils`; please switch to the `flashinfer.testing` entrypoint required by the benchmark guideline.

Proposed patch

```diff
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/bench_specialized_gemm_routing.py` at line 26, Replace the direct import from flashinfer.testing.utils with the public testing entrypoint: change the import so bench_gpu_time is imported from flashinfer.testing (i.e., use "from flashinfer.testing import bench_gpu_time") to follow the benchmark guideline and ensure the public API is used; update the existing import that currently references flashinfer.testing.utils to reference flashinfer.testing instead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f1e302b2-beca-4279-b584-532ea49ba229
📒 Files selected for processing (24)
- benchmarks/bench_specialized_gemm_routing.py
- flashinfer/aot.py
- flashinfer/autotuner.py
- flashinfer/env.py
- flashinfer/gemm/gemm_base.py
- flashinfer/gemm/specialized_kernels/__init__.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/__init__.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/__init__.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/__init__.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/kernel.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/__init__.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py
- flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/workloads.json
- flashinfer/gemm/specialized_kernels/mm_fp4_sm121/__init__.py
- flashinfer/gemm/specialized_kernels/mm_fp4_sm121/cute_dsl/__init__.py
- flashinfer/gemm/specialized_kernels/mm_fp4_sm121/cute_dsl/kernel.py
- flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py
- flashinfer/gemm/specialized_kernels/mm_fp4_sm121/workloads.json
- flashinfer/utils.py
- pyproject.toml
- tests/gemm/test_specialized_gemm_routing.py
```python
def _bmm_fp8_sm121_specialized_runner():
    class BMMFp8Sm121SpecializedRunner(TunableRunner):
        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
        ) -> list:
            del profile
            A, B, A_scale, B_scale, out, _ = inputs
            if is_bmm_fp8_sm121_specialized_problem(
                A,
                B,
                A_scale,
                B_scale,
                out.dtype,
                out,
                "cublas",
            ):
                return [0]
            return []

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic: int = -1,
            do_preparation: bool = False,
            **kwargs,
        ):
            del tactic, do_preparation, kwargs
            A, B, A_scale, B_scale, out, _ = inputs
            return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, "cublas")

    return BMMFp8Sm121SpecializedRunner()
```
Pass the selected backend through the SM121 BMM specialized runner.
specialized_problem is computed with the caller’s backend, but the extra runner re-checks and executes with hardcoded "cublas". If the specialized predicate or runtime dispatch uses that parameter to gate disabled variants or pick an implementation, autotune/runtime can diverge from the path that was actually approved here.
Suggested fix
-def _bmm_fp8_sm121_specialized_runner():
+def _bmm_fp8_sm121_specialized_runner(backend: str):
class BMMFp8Sm121SpecializedRunner(TunableRunner):
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
@@
if is_bmm_fp8_sm121_specialized_problem(
A,
B,
A_scale,
B_scale,
out.dtype,
out,
- "cublas",
+ backend,
):
return [0]
return []
@@
):
del tactic, do_preparation, kwargs
A, B, A_scale, B_scale, out, _ = inputs
- return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, "cublas")
+ return run_bmm_fp8_sm121_specialized(A, B, A_scale, B_scale, out, backend)
return BMMFp8Sm121SpecializedRunner()
@@
extra_runners = None
custom_op = "fp8_gemm"
if specialized_problem:
- extra_runners = [_bmm_fp8_sm121_specialized_runner()]
+ extra_runners = [_bmm_fp8_sm121_specialized_runner(backend)]
custom_op = "fp8_gemm_sm121_specialized"Also applies to: 6335-6358
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/gemm_base.py` around lines 6218 - 6250, The runner currently
hardcodes the backend string "cublas" in both get_valid_tactics and forward;
update _bmm_fp8_sm121_specialized_runner to accept (or capture) the caller's
backend and use that variable when calling is_bmm_fp8_sm121_specialized_problem
and run_bmm_fp8_sm121_specialized (instead of the literal "cublas") so the
predicate and execution use the same backend as the caller; apply the same
change to the analogous runner at the other location (the block around the
second occurrence) to ensure consistent backend propagation.
```python
def _cuda_workspace(A: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    device = _device_key(A.device)
    m = int(A.shape[1])
    k = int(A.shape[2])
    n = int(out.shape[2])
    splits = _compute_splits(m, n, k)
    required_numel = splits * m * n if splits > 1 else 0
    if required_numel == 0:
        cached = _EMPTY_WORKSPACE_CACHE.get(device)
        if cached is None:
            cached = torch.empty((0,), dtype=torch.float32, device=device)
            _EMPTY_WORKSPACE_CACHE[device] = cached
        return cached

    key = (device, m, n, splits)
    cached = _WORKSPACE_CACHE.get(key)
    if cached is None or cached.numel() < required_numel:
        cached = torch.empty((required_numel,), dtype=torch.float32, device=device)
        _WORKSPACE_CACHE[key] = cached
    return cached
```
🧩 Analysis chain
🏁 Script executed:
# First, find the complete implementation of run_bmm_fp8_sm121_specialized
cat -n flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py

Repository: flashinfer-ai/flashinfer
Length of output: 9415
🏁 Script executed:
# Look for the CUDA binding and how workspace is passed
rg "run_bmm_fp8_sm121_specialized|_cuda_workspace" -A 10 flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/Repository: flashinfer-ai/flashinfer
Length of output: 4132
🏁 Script executed:
# Check if there are any stream-related parameters in the function signatures
rg "def run_bmm_fp8_sm121_specialized|stream|Stream" -B 2 -A 5 flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/Repository: flashinfer-ai/flashinfer
Length of output: 50380
🏁 Script executed:
# Look for the actual CUDA kernel binding to understand how workspace is used
fd -t f "*.py" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/ -exec grep -l "cuda_lib\|libflashinfer\|ctypes" {} \;Repository: flashinfer-ai/flashinfer
Length of output: 578
🏁 Script executed:
cat -n flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu | head -80

Repository: flashinfer-ai/flashinfer
Length of output: 2691
🏁 Script executed:
# Look for get_stream implementation and synchronization
rg "get_stream|cudaStreamSynchronize|cudaEventRecord" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/Repository: flashinfer-ai/flashinfer
Length of output: 182
🏁 Script executed:
# Check if there are tests that use multiple streams
fd -t f "test" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/ -o fd -t f "test*.py" flashinfer/Repository: flashinfer-ai/flashinfer
Length of output: 192
🏁 Script executed:
# Look for how get_stream is defined in the codebase
rg "def get_stream|cudaStream_t.*get_stream" -B 2 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 34899
🏁 Script executed:
# Get more context about get_stream behavior
cat -n csrc/tvm_ffi_utils.h | grep -A 3 "get_stream"

Repository: flashinfer-ai/flashinfer
Length of output: 243
🏁 Script executed:
# Look at the actual kernel launch with split-k to see workspace usage
rg "launch_kernel_splitk|workspace" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu -B 2 -A 8Repository: flashinfer-ai/flashinfer
Length of output: 8654
🏁 Script executed:
# Check if there's any synchronization between split-k phases
rg "cudaStreamSynchronize|cudaEventSynchronize|cudaDeviceSynchronize" flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/Repository: flashinfer-ai/flashinfer
Length of output: 50
Stream-scoped workspace cache required for concurrent multi-stream execution.
The workspace cache key omits the stream, so concurrent operations on different CUDA streams with identical shapes (m, n, k) share the same buffer. When split-K is active (splits > 1), multiple streams write to the same workspace tensor simultaneously, corrupting partial results. Either include the stream in the cache key or allocate fresh workspace per invocation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/bmm_fp8_sm121.py` around
lines 127 - 146, The _cuda_workspace function currently keys the shared buffers
by (device, m, n, splits) causing cross-stream buffer reuse; modify it to be
stream-scoped by including the current CUDA stream in the cache key (e.g.,
torch.cuda.current_stream(device) or stream.cuda_stream) or, alternatively, skip
caching and allocate a fresh workspace when concurrent streams/splits are
possible (i.e., when splits > 1). Update the cache key usage for
_WORKSPACE_CACHE and the empty-cache behavior for _EMPTY_WORKSPACE_CACHE
accordingly so each CUDA stream gets its own tensor or a new tensor is returned
for split-K cases to avoid concurrent writes.
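For illustration, a minimal sketch of the stream-scoped variant, assuming the same `_WORKSPACE_CACHE` dictionary and split-K sizing described above (the helper name is hypothetical and this is not the module's actual code):

```python
import torch

_WORKSPACE_CACHE: dict = {}


def _cuda_workspace_stream_scoped(A: torch.Tensor, out: torch.Tensor, splits: int) -> torch.Tensor:
    # Key the cache by stream as well, so concurrent streams never share a
    # split-K accumulation buffer.
    device = A.device
    m, n = int(A.shape[1]), int(out.shape[2])
    required_numel = splits * m * n if splits > 1 else 0
    if required_numel == 0:
        return torch.empty((0,), dtype=torch.float32, device=device)
    stream_key = torch.cuda.current_stream(device).cuda_stream  # raw cudaStream_t handle
    key = (device, stream_key, m, n, splits)
    cached = _WORKSPACE_CACHE.get(key)
    if cached is None or cached.numel() < required_numel:
        cached = torch.empty((required_numel,), dtype=torch.float32, device=device)
        _WORKSPACE_CACHE[key] = cached
    return cached
```

The trade-off is cache growth proportional to the number of live streams; allocating a fresh tensor whenever `splits > 1` is the simpler alternative if split-K shapes are rare.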
```cpp
const int splits = compute_splits(M, N, K);
const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;
TVM_FFI_ICHECK_GE(workspace.numel(), required_workspace)
    << "workspace is too small for bmm_fp8 specialized kernel";

ffi::CUDADeviceGuard device_guard(A.device().device_id);
cudaStream_t stream = get_stream(A.device());
const int64_t A_batch_stride = static_cast<int64_t>(M) * K;
const int64_t B_batch_stride = static_cast<int64_t>(K) * N;
const int64_t O_batch_stride = static_cast<int64_t>(M) * N;
void* workspace_ptr = required_workspace > 0 ? workspace.data_ptr() : nullptr;

for (int b = 0; b < batch; ++b) {
  const void* Ap = static_cast<const uint8_t*>(A.data_ptr()) + b * A_batch_stride;
  const void* Bp = static_cast<const uint8_t*>(B.data_ptr()) + b * B_batch_stride;
  void* Op = static_cast<uint8_t*>(out.data_ptr()) + b * O_batch_stride * sizeof(__nv_bfloat16);
  launch_fp8_gemm(Ap, Bp, Op, A_scale.data_ptr(), B_scale.data_ptr(), M, N, K, workspace_ptr,
                  splits, stream);
```
Reject unsupported K values before dispatching the CUDA kernels.
This binding accepts any K, but the backend only has 16-byte-chunk load paths (uint4 / cp.async_16). If K % 16 != 0, the last iteration can read past the row tail before launch_fp8_gemm ever gets a chance to recover. A cheap guard here avoids turning an unsupported shape into a device fault.
Suggested fix
const int splits = compute_splits(M, N, K);
+ TVM_FFI_ICHECK_EQ(K % 16, 0)
+      << "bmm_fp8 SM121 specialized kernel requires K to be divisible by 16";
  const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;

📝 Committable suggestion
const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
const int splits = compute_splits(M, N, K);
TVM_FFI_ICHECK_EQ(K % 16, 0)
    << "bmm_fp8 SM121 specialized kernel requires K to be divisible by 16";
const int64_t required_workspace = splits > 1 ? static_cast<int64_t>(splits) * M * N : 0;
TVM_FFI_ICHECK_GE(workspace.numel(), required_workspace)
    << "workspace is too small for bmm_fp8 specialized kernel";

ffi::CUDADeviceGuard device_guard(A.device().device_id);
cudaStream_t stream = get_stream(A.device());
const int64_t A_batch_stride = static_cast<int64_t>(M) * K;
const int64_t B_batch_stride = static_cast<int64_t>(K) * N;
const int64_t O_batch_stride = static_cast<int64_t>(M) * N;
void* workspace_ptr = required_workspace > 0 ? workspace.data_ptr() : nullptr;

for (int b = 0; b < batch; ++b) {
  const void* Ap = static_cast<const uint8_t*>(A.data_ptr()) + b * A_batch_stride;
  const void* Bp = static_cast<const uint8_t*>(B.data_ptr()) + b * B_batch_stride;
  void* Op = static_cast<uint8_t*>(out.data_ptr()) + b * O_batch_stride * sizeof(__nv_bfloat16);
  launch_fp8_gemm(Ap, Bp, Op, A_scale.data_ptr(), B_scale.data_ptr(), M, N, K, workspace_ptr,
                  splits, stream);
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/binding.cu` around
lines 122 - 139, Add a pre-dispatch guard that rejects shapes where K is not a
multiple of 16 to avoid kernel loads reading past row tails: before creating the
device guard / entering the batch loop (after computing splits and workspace),
check if (K % 16 != 0) and fail fast with a clear error (e.g. via TVM_FFI_ICHECK
or similar) indicating unsupported K for the fp8 kernel; keep this check near
compute_splits/required_workspace and mention launch_fp8_gemm in the message so
callers know which backend requires the constraint.
```cuda
__nv_bfloat162 v01, v23;
v01.x = __float2bfloat16(acc[mf][nf][0] * scale);
v01.y = __float2bfloat16(acc[mf][nf][1] * scale);
v23.x = __float2bfloat16(acc[mf][nf][2] * scale);
v23.y = __float2bfloat16(acc[mf][nf][3] * scale);

if (row0 < M && col_base < N) {
  *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row0 * N + col_base]) = v01;
}
if (row1 < M && col_base < N) {
  *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row1 * N + col_base]) = v23;
```
Guard the second BF16 write in the non-split-K epilogue.
Once col_base == N - 1, these __nv_bfloat162 stores still write two outputs and overrun the row tail by one element. The split-K epilogue above already handles col_base + 1 < N; this direct-write path needs the same tail logic.
Suggested fix
- if (row0 < M && col_base < N) {
- *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row0 * N + col_base]) = v01;
- }
- if (row1 < M && col_base < N) {
- *reinterpret_cast<__nv_bfloat162*>(&Out_bf[row1 * N + col_base]) = v23;
- }
+ if (row0 < M && col_base < N) {
+ Out_bf[row0 * N + col_base] = v01.x;
+ if (col_base + 1 < N) Out_bf[row0 * N + col_base + 1] = v01.y;
+ }
+ if (row1 < M && col_base < N) {
+ Out_bf[row1 * N + col_base] = v23.x;
+ if (col_base + 1 < N) Out_bf[row1 * N + col_base + 1] = v23.y;
+  }

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cuda
__nv_bfloat162 v01, v23;
v01.x = __float2bfloat16(acc[mf][nf][0] * scale);
v01.y = __float2bfloat16(acc[mf][nf][1] * scale);
v23.x = __float2bfloat16(acc[mf][nf][2] * scale);
v23.y = __float2bfloat16(acc[mf][nf][3] * scale);

if (row0 < M && col_base < N) {
  Out_bf[row0 * N + col_base] = v01.x;
  if (col_base + 1 < N) Out_bf[row0 * N + col_base + 1] = v01.y;
}
if (row1 < M && col_base < N) {
  Out_bf[row1 * N + col_base] = v23.x;
  if (col_base + 1 < N) Out_bf[row1 * N + col_base + 1] = v23.y;
}
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cuda/kernel.cu` around
lines 270 - 280, The two-element BF16 stores using __nv_bfloat162 (writing
v01/v23 into Out_bf) can overrun when col_base == N-1 because they always write
two outputs; update the non-split-K epilogue to guard the second element the
same way as the split-K path: keep the existing check for the first element
(col_base < N) for each row (row0,row1) but add an additional condition col_base
+ 1 < N (or equivalent) before performing the __nv_bfloat162 write that writes
two adjacent outputs into Out_bf, or alternatively only write a single BF16
element when col_base == N-1 so the tail element is not overwritten (adjust the
code around v01/v23 and the reinterpret_cast to enforce this).
```python
    # SIMT vec1+VEC=16 for M=1 (any N) or M<5 with M*N<=8192
    if (K & 15) == 0 and (M == 1 or (M < 5 and M * N <= 8192)):
        if _SIMT1_V16_COMPILED is None:
            _SIMT1_V16_COMPILED = _compile_simt1_v16()
        _SIMT1_V16_COMPILED(
            _fp8_ptr(A),
            _fp8_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # 2-row SIMT for M=5-15 with N>=1024 N<5120 (B reuse across 2 A rows)
    if M >= 5 and M < 16 and N >= 1024 and N < 5120 and (N & 7) == 0:
        if _SIMT2_COMPILED is None:
            _SIMT2_COMPILED = _compile_simt2()
        _SIMT2_COMPILED(
            _fp8_ptr(A),
            _fp8_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # 2-row SIMT also for M=16-56 with low N (TINY MMA underfills SMs at low N)
    if M >= 16 and M <= 56 and N <= 2048 and (N & 7) == 0:
        if _SIMT2_COMPILED is None:
            _SIMT2_COMPILED = _compile_simt2()
        _SIMT2_COMPILED(
            _fp8_ptr(A),
            _fp8_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # SIMT for M<5 (large N), or M=5-15 with N<5120 (fallback)
    if M < 5 or (M < 16 and N < 5120):
        if _SIMT8_COMPILED is None:
            _SIMT8_COMPILED = _compile_simt8()
        _SIMT8_COMPILED(
            _fp8_ptr(A),
            _fp8_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # Small-M MMA path.
    if M >= 256 and M < 1024 and N <= 2048 and (N & (BN_S - 1)) == 0:
        if _MMA_S_COMPILED is None:
            _MMA_S_COMPILED = _compile_mma_small()
        _MMA_S_COMPILED(
            _u32_ptr(A),
            _u32_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # Large MMA for M>=256 with N divisible by 128
    if M >= 256 and (N & (BN_L - 1)) == 0:
        if _MMA_L_COMPILED is None:
            _MMA_L_COMPILED = _compile_mma_large()
        _MMA_L_COMPILED(
            _u32_ptr(A),
            _u32_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # Med MMA for M=64-255 N>=16384 - 32x32 register tile
    if M >= 64 and M < 256 and N >= 16384 and (N & 31) == 0:
        if _MMA_M_COMPILED is None:
            _MMA_M_COMPILED = _compile_mma_med()
        _MMA_M_COMPILED(
            _u32_ptr(A),
            _u32_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return

    # Tiny MMA for M=5..255 - register-only, no barriers
    if M >= 5 and M < 256:
        if _MMA_T_COMPILED is None:
            _MMA_T_COMPILED = _compile_mma_tiny()
        _MMA_T_COMPILED(
            _u32_ptr(A),
            _u32_ptr(B),
            _f32_ptr(A_scale),
            _f32_ptr(B_scale),
            _bf16_ptr(out),
            M,
            N,
            K,
            stream,
        )
        return
```
Match each dispatch branch to the launcher's divisibility contract.
Several branches can currently call a launcher with weaker guards than that launcher declares. For example, _launch_simt1_v16 requires K % 512 == 0 but the branch only checks K % 16, _launch_simt2 / _launch_simt8 require N % 64 == 0 but only gate on N % 8, and the tiny MMA fallback never enforces _launch_mma_tiny’s N % 32 / K % 256 contract. That can route unsupported shapes into kernels that assume stricter tile boundaries.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cute_dsl/kernel.py` around
lines 1771 - 1905, The dispatch branches call launchers with weaker divisibility
checks than the launchers require; update each branch guard to match the
launcher's contract: for the SIMT1 path (symbols _SIMT1_V16_COMPILED /
_compile_simt1_v16) require K % 512 == 0 (use (K & 511) == 0) instead of only K
% 16; for both SIMT2 and SIMT8 paths (symbols _SIMT2_COMPILED / _compile_simt2
and _SIMT8_COMPILED / _compile_simt8) require N % 64 == 0 (use (N & 63) == 0)
not just N % 8; and for the tiny MMA fallback (symbols _MMA_T_COMPILED /
_compile_mma_tiny) add the missing N % 32 == 0 and K % 256 == 0 checks (use (N &
31) == 0 and (K & 255) == 0) before invoking the compiled kernel so kernels
never receive shapes that violate their tile/divisibility contracts.
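As a plain-Python illustration of the tightened guards this asks for (the branch names and divisibility requirements are taken from the comment above; this is a sketch, not the actual kernel.py dispatch code):

```python
def pick_branch(M: int, N: int, K: int) -> str:
    # SIMT1 vec16 path: launcher assumes K % 512 == 0, not just K % 16.
    if (K & 511) == 0 and (M == 1 or (M < 5 and M * N <= 8192)):
        return "simt1_v16"
    # SIMT2 path: launcher assumes N % 64 == 0, not just N % 8.
    if 5 <= M < 16 and 1024 <= N < 5120 and (N & 63) == 0:
        return "simt2"
    # SIMT8 fallback: same N % 64 == 0 requirement.
    if (M < 5 or (M < 16 and N < 5120)) and (N & 63) == 0:
        return "simt8"
    # Tiny MMA fallback: launcher assumes N % 32 == 0 and K % 256 == 0.
    if 5 <= M < 256 and (N & 31) == 0 and (K & 255) == 0:
        return "mma_tiny"
    return "generic_fallback"
```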
```python
def _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid):
    num_bid_m = ct.cdiv(M, tm)
    num_bid_n = ct.cdiv(N, tn)
    num_bid_in_group = group_size_m * num_bid_n
    group_id = bid // num_bid_in_group
    first_bid_m = group_id * group_size_m
    current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
    bid_m = first_bid_m + (bid % current_group_size_m)
    bid_n = (bid % num_bid_in_group) // current_group_size_m
    return bid_m, bid_n
```
Fix the grouped swizzle math for the final partial M-group.
bid_m needs to be derived from the CTA's local id inside the current group. Using bid % current_group_size_m only works when the tail-group height happens to divide group_size_m * num_bid_n; otherwise the last group gets duplicate/skipped tiles.
🛠️ Suggested fix
def _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid):
num_bid_m = ct.cdiv(M, tm)
num_bid_n = ct.cdiv(N, tn)
num_bid_in_group = group_size_m * num_bid_n
group_id = bid // num_bid_in_group
first_bid_m = group_id * group_size_m
current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
- bid_m = first_bid_m + (bid % current_group_size_m)
- bid_n = (bid % num_bid_in_group) // current_group_size_m
+ local_bid = bid % num_bid_in_group
+ bid_m = first_bid_m + (local_bid % current_group_size_m)
+ bid_n = local_bid // current_group_size_m
     return bid_m, bid_n

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py` around
lines 12 - 21, The grouped swizzle in _swizzle_2d_from_bid computes bid_m using
bid % current_group_size_m which fails for a final partial M-group; instead
compute the CTA-local id within the group (local_id = bid - group_id *
num_bid_in_group) and derive bid_m = first_bid_m + (local_id %
current_group_size_m) and bid_n = local_id // current_group_size_m so the last
partial M-group maps tiles correctly; update the calculations in
_swizzle_2d_from_bid to use local_id and current_group_size_m as described.
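A standalone check (plain Python, no cuTile dependency) can confirm that the corrected mapping produces every (bid_m, bid_n) tile exactly once, including a partial final M-group:

```python
def swizzle(num_bid_m, num_bid_n, group_size_m, bid):
    # Same arithmetic as the suggested fix, using the CTA-local id within the group.
    num_bid_in_group = group_size_m * num_bid_n
    group_id = bid // num_bid_in_group
    first_bid_m = group_id * group_size_m
    current_group_size_m = min(num_bid_m - first_bid_m, group_size_m)
    local_bid = bid % num_bid_in_group
    return first_bid_m + (local_bid % current_group_size_m), local_bid // current_group_size_m


# 7 M-tiles with group_size_m=3 leaves a final group of height 1.
num_bid_m, num_bid_n, group_size_m = 7, 4, 3
tiles = {swizzle(num_bid_m, num_bid_n, group_size_m, b) for b in range(num_bid_m * num_bid_n)}
assert len(tiles) == num_bid_m * num_bid_n  # every tile covered exactly once
```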
```python
@ct.kernel
def bmm_fp8_kernel(
    A,
    B,
    A_scale,
    B_scale,
    out,
    tm: ConstInt,
    tn: ConstInt,
    tk: ConstInt,
    num_tiles_k: ConstInt,
    group_size_m: ConstInt,
):
    bid = ct.bid(0)
    M = A.shape[1]
    N = B.shape[2]
    sa = ct.load(A_scale, (0,), shape=(1,)).astype(ct.float32)
    sb = ct.load(B_scale, (0,), shape=(1,)).astype(ct.float32)
    scale = sa.item() * sb.item()
    bid_m, bid_n = _swizzle_2d_from_bid(M, N, tm, tn, group_size_m, bid)
    accumulator = ct.full((tm, tn), 0.0, dtype=ct.float32)
    zero_pad = ct.PaddingMode.ZERO
    for k in range(num_tiles_k):
        a = ct.load(
            A, index=(0, bid_m, k), shape=(1, tm, tk), padding_mode=zero_pad, latency=10
        )
        a = ct.reshape(a, (tm, tk))
        b = ct.load(
            B, index=(0, k, bid_n), shape=(1, tk, tn), padding_mode=zero_pad, latency=10
        )
        b = ct.reshape(b, (tk, tn))
        accumulator = ct.mma(a, b, acc=accumulator)
    accumulator = accumulator * scale
    result = ct.astype(accumulator, ct.bfloat16)
    result_3d = ct.reshape(result, (1, tm, tn))
    ct.store(out, index=(0, bid_m, bid_n), tile=result_3d)
```
This backend silently computes only batch 0.
All four kernels hard-code batch index 0 in their ct.load/ct.store calls, and run() launches only a 1D grid over M/N tiles. For any A.shape[0] > 1, later batches are never produced, so the BMM result is wrong.
🛡️ Minimal safe guard until batched launch support exists
def run(A, B, A_scale, B_scale, out):
+ if A.shape[0] != 1 or B.shape[0] != 1 or out.shape[0] != 1:
+ raise ValueError("cuTile bmm_fp8_sm121 currently supports batch size 1")
+
M = A.shape[1]
K = A.shape[2]
     N = B.shape[2]

Also applies to: 62-93, 96-131, 134-182, 285-327
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/bmm_fp8_sm121/cutile/kernel.py` around
lines 24 - 59, The kernel currently hard-codes batch 0 via ct.load/ct.store
index=(0, ...), so for A.shape[0] > 1 later batches are never computed; update
bmm_fp8_kernel (and the other kernels referenced) to read the batch index from
the kernel launch (e.g. use ct.bid(1) or an additional bid variable instead of
the literal 0), replace all occurrences of index=(0, bid_m, k) and index=(0, k,
bid_n) and the store index=(0, bid_m, bid_n) with index=(batch_id, bid_m, k),
index=(batch_id, k, bid_n), and index=(batch_id, bid_m, bid_n) respectively, and
ensure the host run()/launch code creates a 2D/3D grid that includes the batch
dimension so the kernel’s batch_id maps correctly; apply the same change to the
other kernels mentioned (lines ~62-93, 96-131, 134-182, 285-327).
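A hedged sketch of the host-side half of that change (the helper name is hypothetical and the actual cuTile launch API used by run() may differ); the point is only that the launch grid gains a batch axis that the kernel reads via ct.bid(1), while ct.bid(0) keeps addressing the M/N tile id:

```python
def batched_grid(batch: int, M: int, N: int, tm: int, tn: int) -> tuple[int, int]:
    # Axis 0 -> flattened M/N tile id (consumed by ct.bid(0) and the swizzle),
    # axis 1 -> batch index (consumed by ct.bid(1) inside the kernel).
    tiles_mn = ((M + tm - 1) // tm) * ((N + tn - 1) // tn)
    return (tiles_mn, batch)
```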
```python
    if not (a.is_cuda and b.is_cuda and a_descale.is_cuda and b_descale.is_cuda):
        return None
    if torch.cuda.get_device_capability(a.device) != (12, 1):
        return None
    if a.dtype != torch.uint8 or b.dtype != torch.uint8:
        return None
    if a_descale.dtype != torch.uint8 or b_descale.dtype != torch.uint8:
        return None
    if a.ndim != 2 or b.ndim != 2 or a_descale.ndim != 2 or b_descale.ndim != 2:
        return None

    m = int(a.shape[0])
    k = int(a.shape[1] * 2)
    n = int(b.shape[1])
    key = (
        m,
        k,
        n,
        block_size,
        out_dtype,
        use_8x4_sf_layout,
        _normalize_backend(backend),
        use_nvfp4,
    )
    impl = _WORKLOAD_LUT.get(key)
    if impl is None or not _impl_available(impl):
        return None
    if tuple(b.shape) != (k // 2, n):
        return None

    sf_m = ((m + 127) // 128) * 128
    sf_n = ((n + 127) // 128) * 128
    if tuple(a_descale.shape) != (sf_m, k // block_size):
        return None
    if tuple(b_descale.shape) != (k // block_size, sf_n):
        return None
    if not a.is_contiguous() or not a_descale.is_contiguous():
        return None
    if not _is_column_major_view(b) or not _is_column_major_view(b_descale):
        return None

    if out is not None:
        if out.dtype != torch.bfloat16 or tuple(out.shape) != (m, n):
            return None
        if not out.is_cuda or not out.is_contiguous():
            return None
```
Reject mixed-device tensors in _select_impl.
The predicate only checks is_cuda, so b, the descales, or out can still live on a different GPU than a. The specialized runner later uses a.device for stream selection and forwards every tensor's raw pointer into the same launch, which makes mixed-device inputs unsafe instead of ineligible.
Suggested fix
if not (a.is_cuda and b.is_cuda and a_descale.is_cuda and b_descale.is_cuda):
return None
+ devices = {a.device, b.device, a_descale.device, b_descale.device}
+ if out is not None:
+ devices.add(out.device)
+ if len(devices) != 1:
+ return None
if torch.cuda.get_device_capability(a.device) != (12, 1):
         return None

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py` around
lines 110 - 156, The current device checks in _select_impl only verify tensors
are CUDA but not that they all live on the same GPU as a, allowing unsafe
mixed-device launches; update the predicate to ensure b.device,
a_descale.device, b_descale.device (and out.device when out is provided) equal
a.device before accepting the specialized impl. Locate the selection logic in
function _select_impl (and the final out validation block) and add explicit
device-equality checks (e.g., compare .device or .get_device() values) for b,
a_descale, b_descale, and out to reject cases where any tensor is on a different
CUDA device than a. Ensure these checks run before returning the impl.
```python
def _prepare_alpha(alpha: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
    device = _device_key(device)
    if alpha is None:
        cached = _ALPHA_ONE_CACHE.get(device)
        if cached is None:
            cached = torch.tensor([1.0], dtype=torch.float32, device=device)
            _ALPHA_ONE_CACHE[device] = cached
        return cached
    if alpha.dim() == 0:
        return alpha.unsqueeze(0)
    return alpha.reshape(1)
```
Normalize alpha onto CUDA float32 before exposing its pointer to the kernel.
_prepare_alpha() currently preserves the caller's device and dtype. If a user passes a CPU scalar or an fp16/fp64 tensor, cute_dsl.kernel.run() still builds a device float32* from that data_ptr(), which turns into an invalid pointer/type reinterpretation at launch time.
Suggested fix
def _prepare_alpha(alpha: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
device = _device_key(device)
if alpha is None:
cached = _ALPHA_ONE_CACHE.get(device)
if cached is None:
cached = torch.tensor([1.0], dtype=torch.float32, device=device)
_ALPHA_ONE_CACHE[device] = cached
return cached
- if alpha.dim() == 0:
- return alpha.unsqueeze(0)
- return alpha.reshape(1)
+ alpha = alpha.to(device=device, dtype=torch.float32)
+ if alpha.dim() == 0:
+ alpha = alpha.unsqueeze(0)
+ else:
+ alpha = alpha.reshape(1)
+    return alpha.contiguous()

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _prepare_alpha(alpha: Optional[torch.Tensor], device: torch.device) -> torch.Tensor:
    device = _device_key(device)
    if alpha is None:
        cached = _ALPHA_ONE_CACHE.get(device)
        if cached is None:
            cached = torch.tensor([1.0], dtype=torch.float32, device=device)
            _ALPHA_ONE_CACHE[device] = cached
        return cached
    alpha = alpha.to(device=device, dtype=torch.float32)
    if alpha.dim() == 0:
        alpha = alpha.unsqueeze(0)
    else:
        alpha = alpha.reshape(1)
    return alpha.contiguous()
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/gemm/specialized_kernels/mm_fp4_sm121/mm_fp4_sm121.py` around
lines 200 - 210, _prepare_alpha currently preserves caller device and dtype
which can yield invalid pointers when kernel expects a CUDA float32*; change it
to always return a CUDA float32 1-D tensor (device normalized via _device_key)
before exposing data_ptr. Concretely, when alpha is None, ensure the cached
value in _ALPHA_ONE_CACHE is torch.tensor([1.0], dtype=torch.float32,
device=device); when alpha is provided, move/convert it to device and dtype
torch.float32 (e.g., alpha.to(device=device, dtype=torch.float32)), then
normalize shape with unsqueeze(0) or reshape(1) as before; keep using the same
cache key (_ALPHA_ONE_CACHE) and keep function name _prepare_alpha so callers
(and cute_dsl.kernel.run) receive a CUDA float32 pointer.
```python
input = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
mat2 = torch.randn((batch, n, k), device="cuda", dtype=torch.bfloat16).transpose(
    -2, -1
)
mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=torch.float8_e4m3fn)
reference = torch.bmm(input, mat2)
```
Avoid shadowing Python builtins in tests.
Line 213 uses input as a variable name, which triggers Ruff A001 and can fail lint-gated CI.
Proposed patch
- input = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
- input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
+ inp = torch.randn((batch, m, k), device="cuda", dtype=torch.bfloat16)
+ input_fp8, input_inv_s = to_float8(inp, dtype=torch.float8_e4m3fn)
@@
- reference = torch.bmm(input, mat2)
+    reference = torch.bmm(inp, mat2)

🧰 Tools
🪛 Ruff (0.15.12)
[error] 213-213: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/gemm/test_specialized_gemm_routing.py` around lines 213 - 219, The test
uses the builtin name "input" as a local variable (seen around creation of the
tensor and calls to to_float8 and torch.bmm), which shadows Python's builtin and
triggers lint error A001; rename that variable (and its uses: input_fp8,
input_inv_s, and the reference bmm call) to a non-builtin name like input_tensor
or inp throughout the block so to_float8(input_tensor, ...) and
torch.bmm(input_tensor, mat2) are used instead.
Code Review
This pull request introduces specialized GEMM kernels for SM121 architectures, supporting both FP4 and FP8 precisions across multiple backends including CUDA, CUTE DSL, and cuTile. The implementation includes a runtime routing mechanism that detects specific workloads and dispatches them to optimized kernels, along with an environment flag to disable this behavior for benchmarking. Key infrastructure updates include enhancements to the autotuner's caching logic for better disambiguation and the addition of comprehensive benchmarking and testing scripts. Review feedback highlighted the need for improved robustness in the benchmark script, specifically suggesting guards against division by zero and invalid logarithmic inputs in the speedup and geometric mean calculations.
disabled_cosine = cosine(reference, disabled_snapshot)
enabled_cosine = cosine(reference, enabled_snapshot)
routed_cosine = cosine(disabled_snapshot, enabled_snapshot)
speedup = disabled_ms / enabled_ms
def geomean(values: list[float]) -> float:
    return math.exp(sum(math.log(v) for v in values) / len(values))
The geomean function may fail in two scenarios:
- If values contains a zero or negative number, math.log(v) will raise a ValueError. This could happen if disabled_ms is 0, leading to a speedup of 0.
- If values is an empty list, len(values) will be zero, causing a ZeroDivisionError.
Please consider adding checks for these cases to improve the script's robustness.
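One possible shape for those guards is sketched below; this is illustrative only, and the safe_speedup helper plus the NaN fallbacks are assumptions rather than code from the benchmark script.

import math


def safe_speedup(disabled_ms: float, enabled_ms: float) -> float:
    # Guard the division so a zero enabled_ms cannot raise ZeroDivisionError.
    return disabled_ms / enabled_ms if enabled_ms > 0.0 else float("nan")


def geomean(values: list[float]) -> float:
    # math.log only accepts strictly positive inputs, so filter first.
    positive = [v for v in values if v > 0.0]
    # An empty (or all non-positive) list would otherwise divide by zero.
    if not positive:
        return float("nan")
    return math.exp(sum(math.log(v) for v in positive) / len(positive))

Silently dropping non-positive entries is one option; raising a ValueError with a descriptive message would make bad measurements more visible, which may be preferable for a benchmarking tool.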
/bot run
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/autotuner.py (1)
887-938: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Legacy file-key fallback block is unreachable.

_get_file_cache_key(cache_key) already returns the legacy 3-field key when cache_key[4] == () (see lines 1547–1548). So at line 890, file_key is identical to what _get_legacy_file_cache_key(cache_key) would produce in that branch. The block at lines 918–937 only runs when cache_key[4] == (), computes the same string again, and re-queries self._file_configs; it will never hit when line 895 missed.
_get_file_cache_key(cache_key)already returns the legacy 3-field key whencache_key[4] == ()(see lines 1547–1548). So at line 890,file_keyis identical to what_get_legacy_file_cache_key(cache_key)would produce in that branch. The block at lines 918–937 only runs whencache_key[4] == (), computes the same string again, and re-queriesself._file_configs— it will never hit when line 895 missed.In its current shape this code is dead: the "legacy config file" log path is unreachable, and no actual backward-compat lookup happens for runners that do return non-empty
extras(which is presumably the case that motivated the new SM121-specialized runners). Two plausible fixes, depending on intent:
- If the intent is to always match legacy 3-field cache files (including for newer runners that now emit extras), make the legacy lookup run when cache_key[4] != () instead, and have _get_file_cache_key always return the new 4-field format:

♻️ Option A: always use the 4-field key, fall back to legacy when extras present
@staticmethod
def _get_file_cache_key(cache_key: Tuple) -> str:
-    if cache_key[4] == ():
-        return AutoTuner._get_legacy_file_cache_key(cache_key)
    return str((cache_key[0], cache_key[1], cache_key[3], cache_key[4]))

-        # Preserve compatibility with older cache files only for
-        # runners that do not need extra key material. Reusing a
-        # legacy key when extras are non-empty can apply a tactic to a
-        # shape or dtype it was never profiled for.
-        if cache_key[4] == ():
+        # Preserve compatibility with older cache files written before
+        # the `extras` field existed. Only safe when the runner has no
+        # extras, since legacy keys cannot disambiguate dtype/etc.
+        if cache_key[4] == ():
            legacy_file_key = AutoTuner._get_legacy_file_cache_key(cache_key)
            if legacy_file_key in self._file_configs:
                ...

Note: this would also require updating save_configs so newly written files always use the 4-field format (which then needs migration semantics for already-saved 3-field files).
- If the intent is that legacy 3-field files should only be honored when the runner has no extras (and the current _get_file_cache_key behavior is correct for write compatibility), then lines 914–937 are simply redundant and should be removed:

♻️ Option B: drop the dead legacy fallback block
-        # Preserve compatibility with older cache files only for
-        # runners that do not need extra key material. Reusing a
-        # legacy key when extras are non-empty can apply a tactic to a
-        # shape or dtype it was never profiled for.
-        if cache_key[4] == ():
-            legacy_file_key = AutoTuner._get_legacy_file_cache_key(cache_key)
-            if legacy_file_key in self._file_configs:
-                runner_name, tactic = self._file_configs[legacy_file_key]
-                runner_id = next(
-                    (
-                        i
-                        for i, runner in enumerate(runners)
-                        if runner.__class__.__name__ == runner_name
-                    ),
-                    0,
-                )
-                log_key = (custom_op, runner_name)
-                if log_key not in self._logged_file_hits:
-                    self._logged_file_hits.add(log_key)
-                    logger.info(
-                        f"[Autotuner]: Config cache hit for {custom_op} "
-                        f"(runner={runner_name}, source=legacy config file)"
-                    )
-                return True, runner_id, tactic, None
-

Worth confirming the intent before picking one; the resulting behavior for cross-version cache files differs materially.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/autotuner.py` around lines 887 - 938, The legacy-file fallback block is unreachable because AutoTuner._get_file_cache_key(cache_key) already returns the 3-field legacy key when cache_key[4] == (), so the subsequent legacy lookup using _get_legacy_file_cache_key never adds new matches; either remove the redundant block (delete the branch that computes legacy_file_key and its logging/return) if legacy keys should only be honored when extras are empty, or change the condition to run the legacy lookup when cache_key[4] != () so legacy 3-field configs are matched for runners that now emit extras (and ensure save_configs consistently writes the chosen format); update references to file_key, cache_key, AutoTuner._get_file_cache_key, AutoTuner._get_legacy_file_cache_key, self._file_configs and self._logged_file_hits accordingly.
🧹 Nitpick comments (1)
flashinfer/autotuner.py (1)
1650-1651: 💤 Low value

Prefix unused unpacked names with _.

Ruff flags custom_op, profile, extras (line 1650) and runner_id (line 1651) as unused. _get_file_cache_key consumes the whole cache_key tuple directly, and only runner_class_name / tactic are read.

♻️ Suggested rename
- custom_op, runner_class_name, _runner_hash, profile, extras = cache_key
- runner_id, tactic, _opt_profile = cache_value
+ _custom_op, runner_class_name, _runner_hash, _profile, _extras = cache_key
+ _runner_id, tactic, _opt_profile = cache_value

Or, since the helper already takes the full tuple, drop the unpack entirely and use cache_value[1] for tactic.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/autotuner.py` around lines 1650 - 1651, The unpacking of cache_key/cache_value exposes unused names; either prefix the unused variables with an underscore (e.g., _custom_op, _runner_hash, _profile, _extras, _runner_id) or remove the unpack and index the tuple directly (use cache_key and cache_value[1] for tactic) in the loop where _get_file_cache_key is called; update the variables in that block (runner_class_name, tactic) so only the used names remain to satisfy Ruff.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@flashinfer/autotuner.py`:
- Around line 887-938: The legacy-file fallback block is unreachable because
AutoTuner._get_file_cache_key(cache_key) already returns the 3-field legacy key
when cache_key[4] == (), so the subsequent legacy lookup using
_get_legacy_file_cache_key never adds new matches; either remove the redundant
block (delete the branch that computes legacy_file_key and its logging/return)
if legacy keys should only be honored when extras are empty, or change the
condition to run the legacy lookup when cache_key[4] != () so legacy 3-field
configs are matched for runners that now emit extras (and ensure save_configs
consistently writes the chosen format); update references to file_key,
cache_key, AutoTuner._get_file_cache_key, AutoTuner._get_legacy_file_cache_key,
self._file_configs and self._logged_file_hits accordingly.
---
Nitpick comments:
In `@flashinfer/autotuner.py`:
- Around line 1650-1651: The unpacking of cache_key/cache_value exposes unused
names; either prefix the unused variables with an underscore (e.g., _custom_op,
_runner_hash, _profile, _extras, _runner_id) or remove the unpack and index the
tuple directly (use cache_key and cache_value[1] for tactic) in the loop where
_get_file_cache_key is called; update the variables in that block
(runner_class_name, tactic) so only the used names remain to satisfy Ruff.
[FAILED] Pipeline #50936022: 12/20 passed

seems clean? ready for auto-merge?

swapped labels for 0.6.11.post1
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests
Tools