
non-override tactic control#3260

Merged
yanqinz2 merged 6 commits into main from yanqinz/option-to-disable-override on May 8, 2026

Conversation

@yanqinz2
Collaborator

@yanqinz2 yanqinz2 commented May 7, 2026

📌 Description

This PR adds an enable_override_shape option across the cuDNN GEMM paths. It also fixes a tactic mismatch hazard in the non-override path. Autotuning buckets dynamic M values, but static cuDNN graphs are rebuilt with the actual runtime M, so a tactic index profiled on a bucket-M graph may not refer to the same execution plan in the actual-M graph. To avoid applying an invalid or mismatched tactic, the non-override cuDNN paths now expose only the fallback tactic -1, forcing the runtime onto the cuDNN heuristic path for the actual static graph.
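A minimal sketch of the hazard, with made-up bucket sizes and plan names (not FlashInfer's actual values):

```python
# Hypothetical illustration of the bucket-M vs. actual-M hazard described
# above; bucket sizes and plan names are invented for the sketch.

def bucket_m(m, buckets=(64, 256, 1024, 4096)):
    """Round a runtime M up to the nearest autotuning bucket."""
    for b in buckets:
        if m <= b:
            return b
    return buckets[-1]

# Each static graph is built for a concrete M, and each graph exposes its
# own ordered plan list -- so a plan *index* is only meaningful relative
# to the graph it was profiled on.
plans_for_m = {
    256: ["tile_128x128", "tile_64x128", "split_k2"],  # bucket-M graph
    200: ["tile_64x128", "split_k2"],                  # actual-M graph
}

m_runtime = 200
tuned_idx = 0  # "best" index profiled on the bucket-M (256) graph
# The same index names a different plan in the actual-M graph, which is
# why the non-override path falls back to tactic -1 (cuDNN heuristics).
mismatch = plans_for_m[bucket_m(m_runtime)][tuned_idx] != plans_for_m[m_runtime][tuned_idx]
```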

This update also brings the cuDNN FP8 and MXFP8 GEMM paths in line with the existing BF16/FP4 override-shape design.

For FP8, the per-tensor quantized cuDNN graph builders were renamed to follow the same naming convention as the other GEMM paths: build_cudnn_gemm_fp8_graph and build_cudnn_gemm_fp8_graph_override_shape. The cuDNN FP8 runner now supports the override-shape execution path, using the autotuner’s effective M-bucket mapper to build a reusable bucketed graph and passing the runtime shapes through override_shapes.

For MXFP8, the cuDNN graph construction was refactored to match the FP4/BF16 structure. The previous create_cudnn_execution_plans_mxfp8_gemm + _get_cudnn_mxfp8_gemm_graph split was replaced with a single build_cudnn_gemm_mxfp8_graph builder that owns graph creation, support checking, and plan building. The MXFP8 runner now also supports the override-shape path with bucketed graph reuse.
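The consolidation pattern can be sketched roughly as follows; the function body, names, and dict-based "graph" are illustrative stand-ins, not the real cuDNN builder:

```python
import functools

# Illustrative sketch of replacing a create-plans + get-graph split with a
# single cached builder that owns support checking, graph creation, and
# plan building.  All names here are hypothetical.

@functools.lru_cache(maxsize=1024)
def build_gemm_graph(a_shape, b_shape, out_dtype):
    # Support check: reject shapes the backend cannot handle.
    if len(a_shape) != 2 or len(b_shape) != 2 or a_shape[1] != b_shape[0]:
        raise ValueError(f"unsupported GEMM shapes: {a_shape} x {b_shape}")
    # Graph creation and plan building live in the same cached entry
    # point, so callers can no longer obtain a graph without its plans.
    return {"shapes": (a_shape, b_shape), "out": out_dtype, "plans": ["heuristic"]}

# lru_cache keys on the (hashable) arguments, so identical shapes reuse
# the already-built graph instead of rebuilding it.
g1 = build_gemm_graph((8, 16), (16, 32), "bf16")
g2 = build_gemm_graph((8, 16), (16, 32), "bf16")
```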

For both FP8 and MXFP8, the non-override static cuDNN path no longer participates in autotune plan-index profiling, since static graphs are rebuilt for the actual runtime M and cannot safely reuse tactic indices profiled on bucketed M shapes.
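The resulting tactic gating can be sketched as below; the class and method names are illustrative, not the actual runner API:

```python
# Minimal sketch of the tactic gating: override-shape runners expose real
# plan indices for profiling, while the static (non-override) path exposes
# only the heuristic fallback -1.  Names are hypothetical.

class CudnnRunnerSketch:
    def __init__(self, use_override_shape, plan_count):
        self._use_override_shape = use_override_shape
        self._plan_count = plan_count  # plans in the reusable bucketed graph

    def get_valid_tactics(self):
        if self._use_override_shape:
            # The bucketed graph itself is reused at runtime, so profiled
            # plan indices stay valid -- let the autotuner try them all.
            return list(range(self._plan_count))
        # Static graphs are rebuilt for the runtime M; profiled indices may
        # not transfer, so only the cuDNN-heuristic fallback is safe.
        return [-1]
```

With this gating, autotune profiling on the non-override path collapses to a single measurement of the heuristic plan.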

🔍 Related Issues

test_bmm_bf16 failure on B300

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • API Changes

    • Removed legacy cuDNN GEMM override helpers from the public API; core GEMM wrappers remain.
    • Public exports now include available CuTe‑DSL kernels at import time when present.
  • Improvements

    • Safer override‑shape handling with deterministic fallback when dynamic-shape support is unavailable.
    • Improved cache invalidation to avoid stale execution plans.
  • Tests

    • Added GPU tests validating override‑shape execution across dynamic sizes and quantized modes.

@coderabbitai
Contributor

coderabbitai Bot commented May 7, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Removes cuDNN override-shape helper exports, adds conditional CuTe‑DSL kernel discovery to exports, threads enable_override_shape and **kwargs through GEMM dispatch, refactors FP4/FP8/BF16/MXFP8 runners to gate override-shape execution vs static fallback, and adds GPU tests exercising single compiled graphs across varying M.

Changes

cuDNN GEMM override-shape control

  • Module API and exports (flashinfer/gemm/__init__.py): Removes cuDNN override-shape helper exports and appends dynamically discovered CuTe-DSL kernel names to __all__.
  • Backend requirement plumbing (flashinfer/gemm/gemm_base.py): Adds **kwargs and threads enable_override_shape into backend requirement/heuristic signatures to allow flag propagation; updates graph cache invalidation.
  • FP8 graph builders (flashinfer/gemm/gemm_base.py): Renames FP8 per-tensor override-shape builder/executor helpers and ensures cache clearing includes the renamed builders.
  • FP8 runner & execution gating (flashinfer/gemm/gemm_base.py): Introduces _cudnn_gemm_fp8_runner with M-bucket caching and splits execution between the override graph (tactic >= 0) and static fallback (tactic = -1).
  • BF16 runner implementation (flashinfer/gemm/gemm_base.py): Refactors the BF16 runner to compute _use_override_shape, build override graphs when supported, and fall back to static execution with tactic = -1.
  • FP4 public API & runner (flashinfer/gemm/gemm_base.py): Threads enable_override_shape into the FP4 runner and updates it to build/execute override-shape graphs when supported, otherwise use static execution (tactic = -1).
  • MXFP8 graph & runner (flashinfer/gemm/gemm_base.py): Consolidates the MXFP8 graph builder, adds _cudnn_gemm_mxfp8_runner with override-shape gating, and splits forward between override and static execution.
  • BMM call-site updates (flashinfer/gemm/gemm_base.py): Updates the BMM FP8 and bmm_mxfp8 call sites to the refactored dispatcher/backend semantics.
  • Integration tests (tests/gemm/test_cudnn_override_shape.py): GPU-only tests compile one graph per test and execute it across multiple runtime M values for BF16, NVFP4, and MXFP8, validating numeric correctness or cosine similarity.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • dhiraj113
  • aleozlx
  • yzh119
  • bkryu

"🐰 I hopped through GEMM with a tiny flag in paw,
Override-shape listens when you say “on” or “off.”
BF16, FP4, FP8 now choose their fancy route,
One compiled graph reused — no rebuild to flout.
A carrot-sized cheer for tests that pass with glee."

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 19.40%, which is below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed. The title 'non-override tactic control' is partially related to the changeset: it refers to a real aspect of the change (handling tactics when override-shape is disabled) but does not capture the main point, namely adding the enable_override_shape option and fixing tactic mismatches across the cuDNN GEMM paths (FP8, MXFP8, BF16, FP4).
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Description check: ✅ Passed. The PR description is detailed and comprehensive, covering the purpose, technical rationale, related issues, and most checklist items.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request integrates cuDNN override-shape functionality into high-level GEMM APIs, such as mm_bf16 and bmm_bf16, and introduces an enable_override_shape parameter to manage this feature. It also modifies the autotuning logic for cuDNN FP8 and MXFP8 to return a fallback tactic, preventing plan mismatches when graphs are rebuilt for runtime shapes. Feedback suggests that deleting the override-shape tests significantly reduces coverage and recommends updating them to use the new high-level APIs instead. Reviewers also expressed concern that disabling autotuning for FP8 and MXFP8 could lead to performance regressions, suggesting the use of override-shape builders to allow safe profiling. Lastly, the removal of is_cudnn_override_shape_available from the public API was identified as a breaking change.

I am having trouble creating individual review comments; my feedback is consolidated below.

tests/gemm/test_cudnn_override_shape.py (1-326)

high

Removing this test file reduces coverage for cuDNN override-shape functionality. These tests should be updated to use high-level GEMM APIs (e.g., mm_bf16 with enable_override_shape=True) instead of being deleted to ensure no regressions in the cuDNN execution path.

flashinfer/gemm/gemm_base.py (3078-3081)

medium

Disabling autotuning for cuDNN FP8 avoids tactic mismatch but may cause performance regressions. Consider updating _cudnn_gemm_fp8_runner to use the existing override-shape builder (line 2895) to enable safe autotuning. Also, the comment wording 'only cache' is slightly inaccurate as this function determines which tactics to profile.

            # Static-shape cuDNN graphs are rebuilt for the runtime M.  A
            # tactic index profiled for a bucket M may refer to a different
            # plan in the actual-M graph, so only profile the fallback tactic.
            return [-1]

flashinfer/gemm/gemm_base.py (7988-7991)

medium

Disabling autotuning for MXFP8 is suboptimal. Consider updating _cudnn_gemm_mxfp8_runner to use the existing override-shape builder (line 2589) to maintain performance while fixing the hazard.

            # Static-shape cuDNN graphs are rebuilt for the runtime M.  A
            # tactic index profiled for a bucket M may refer to a different
            # plan in the actual-M graph, so only profile the fallback tactic.
            return [-1]

flashinfer/gemm/__init__.py (23-31)

medium

Removing is_cudnn_override_shape_available from the public API is a breaking change. It is a useful utility for users to check feature support at runtime; consider keeping it exposed.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4808-4810: The comment claiming "currently cudnn backend does not
support alpha for dynamic-shape" is stale for the FP4 cuDNN override-shape path;
update or remove those comments around the override-shape conditional (the block
guarded by self._use_override_shape) and the similar comment in the nearby
override-shape graph build/execute path so they accurately state that alpha is
handled in the override-shape path (or simply remove the incorrect restriction
note). Ensure any new comment clearly distinguishes dynamic-shape cuDNN
limitations from the override-shape implementation that includes alpha support.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8f40fd23-7ed7-4d4d-bbd0-aadeeb0a8704

📥 Commits

Reviewing files that changed from the base of the PR and between bb41dc1 and d55d8437992220d766d50cc0cb7758546470fe26.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
  • tests/gemm/test_cudnn_override_shape.py
  • flashinfer/gemm/__init__.py

@yanqinz2 yanqinz2 added the run-ci label May 7, 2026
@yanqinz2
Collaborator Author

yanqinz2 commented May 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !647 has been created, and the CI pipeline #50605869 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #50605869: 1/20 passed

@yanqinz2 yanqinz2 force-pushed the yanqinz/option-to-disable-override branch from b53e8f3 to aea9260 on May 7, 2026 at 22:26
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

7976-7990: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Restore the cuDNN availability guard in the refactored MXFP8 builder.

build_cudnn_gemm_mxfp8_graph() now touches cudnn.build_plan_policy before any availability check. If this helper is reached on a build without cuDNN, it will fail with a raw NameError instead of the consistent _check_cudnn_availability() error used by the other cuDNN graph builders.

Suggested fix
 @functools.lru_cache(maxsize=1024)
 def build_cudnn_gemm_mxfp8_graph(
     a_shape,
     a_stride,
     a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
@@
     device,
     policy=None,
 ):
+    _check_cudnn_availability()
     if policy is None:
         policy = cudnn.build_plan_policy.HEURISTICS_CHOICE
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/gemm_base.py` around lines 7976 - 7990, The function
build_cudnn_gemm_mxfp8_graph currently references cudnn.build_plan_policy before
verifying cuDNN is available; add the existing availability guard by calling
_check_cudnn_availability() at the start of build_cudnn_gemm_mxfp8_graph (or
move the policy default assignment so it happens after that check) so that
missing cuDNN raises the consistent _check_cudnn_availability() error instead of
a NameError; reference the build_cudnn_gemm_mxfp8_graph function and the
cudnn.build_plan_policy assignment when making this change.
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

3195-3197: ⚡ Quick win

Partition the autotuner cache by override-shape mode.

These get_cache_key_extras() methods still ignore self._use_override_shape, but this PR makes that flag change the valid tactic set from plan indices to the heuristic-only fallback path. If the same shape is first cached with enable_override_shape=False, later calls with it enabled can reuse that entry and skip profiling the override-shape plans entirely.

At minimum, add self._use_override_shape to the cache key for these cuDNN runners so the two execution modes do not alias in AutoTuner.

Also applies to: 3702-3706, 4949-4954, 8169-8171
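A sketch of the suggested cache-key change; the return shape mirrors the comment's example, while the class is an illustrative stand-in for the cuDNN runners:

```python
# Illustrative sketch of partitioning the autotuner cache by override-shape
# mode, as suggested above.  Names are hypothetical.

class RunnerKeySketch:
    def __init__(self, use_override_shape):
        self._use_override_shape = use_override_shape

    def get_cache_key_extras(self, a_dtype, b_dtype, out_dtype):
        # Before the fix the key was (a_dtype, b_dtype, out_dtype), so the
        # two execution modes aliased; adding the flag separates them.
        return (a_dtype, b_dtype, out_dtype, self._use_override_shape)

key_on = RunnerKeySketch(True).get_cache_key_extras("fp8", "fp8", "bf16")
key_off = RunnerKeySketch(False).get_cache_key_extras("fp8", "fp8", "bf16")
```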

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/gemm_base.py` around lines 3195 - 3197, The
get_cache_key_extras methods for the cuDNN GEMM runners currently return only
tensor dtypes and therefore do not distinguish override-shape mode; update each
get_cache_key_extras (e.g., the method shown and the other occurrences) to
include self._use_override_shape in the returned tuple (for example return
(a.dtype, b.dtype, out.dtype, self._use_override_shape)) so the autotuner cache
is partitioned by override-shape mode and the two execution modes do not alias
in AutoTuner.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 7976-7990: The function build_cudnn_gemm_mxfp8_graph currently
references cudnn.build_plan_policy before verifying cuDNN is available; add the
existing availability guard by calling _check_cudnn_availability() at the start
of build_cudnn_gemm_mxfp8_graph (or move the policy default assignment so it
happens after that check) so that missing cuDNN raises the consistent
_check_cudnn_availability() error instead of a NameError; reference the
build_cudnn_gemm_mxfp8_graph function and the cudnn.build_plan_policy assignment
when making this change.

---

Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3195-3197: The get_cache_key_extras methods for the cuDNN GEMM
runners currently return only tensor dtypes and therefore do not distinguish
override-shape mode; update each get_cache_key_extras (e.g., the method shown
and the other occurrences) to include self._use_override_shape in the returned
tuple (for example return (a.dtype, b.dtype, out.dtype,
self._use_override_shape)) so the autotuner cache is partitioned by
override-shape mode and the two execution modes do not alias in AutoTuner.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b59f858c-0d29-45fe-a098-c857c7b85d10

📥 Commits

Reviewing files that changed from the base of the PR and between b53e8f3f9d5c9fda80eb61a74cf7b03590c804f4 and aea926091b535d11199bec8dd80feb2f29c12748.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
  • tests/gemm/test_cudnn_override_shape.py
  • flashinfer/gemm/__init__.py

@yanqinz2 yanqinz2 force-pushed the yanqinz/option-to-disable-override branch from aea9260 to b086df4 on May 7, 2026 at 23:51
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

3201-3216: ⚡ Quick win

Return only the fallback tactic on the non-override path.

In these branches, get_valid_tactics() still exposes numbered plan indices, but forward() ignores them and always executes the static graph with tactic=-1. That keeps the autotuner profiling meaningless tactic IDs instead of the single heuristic fallback this change is trying to preserve.

Suggested direction
         def get_valid_tactics(
             self,
             inputs: List[torch.Tensor],
             profile: OptimizationProfile,
         ) -> List[int]:
             a, b, _, _, out, _ = inputs
             if self._use_override_shape:
                 graph = self._get_override_graph(a, b, out)
-            else:
-                graph = build_cudnn_gemm_fp8_graph(
-                    a_shape=a.shape,
-                    a_stride=a.stride(),
-                    b_shape=b.shape,
-                    b_stride=b.stride(),
-                    a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
-                    b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
-                    o_type=_torch_data_type_to_cudnn_data_type(out.dtype),
-                    device=a.device,
-                    policy=cudnn.build_plan_policy.HEURISTICS_CHOICE,
-                )
-
-            return list(range(graph.get_execution_plan_count()))
+                return list(range(graph.get_execution_plan_count()))
+            return [-1]

Mirror that pattern in the BF16 / FP4 / MXFP8 runners as well.

Also applies to: 3238-3248, 3694-3719, 3742-3744, 4966-4999, 5037-5050, 8192-8208, 8230-8240

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/gemm/gemm_base.py` around lines 3201 - 3216, The non-override
branch currently returns a list of all graph plan indices
(range(graph.get_execution_plan_count())), but forward() always executes the
static graph with tactic=-1, so change the non-override return to expose only
the fallback tactic (return [-1] or equivalent single-element list containing
the fallback ID) so autotuner sees the single heuristic fallback; update the
same pattern in the other locations mentioned (e.g., the BF16/FP4/MXFP8 runner
methods around lines 3238-3248, 3694-3719, 3742-3744, 4966-4999, 5037-5050,
8192-8208, 8230-8240) and keep the override branch using
_get_override_graph(...) and returning full plan indices there.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3163-3171: The constructors for the cuDNN runners (e.g., the
__init__(self, m_bucket_mapper) shown) ignore the enable_override_shape flag and
always set self._use_override_shape = is_cudnn_override_shape_available(), so
callers cannot opt out; modify these constructors to accept an
enable_override_shape (or similar) parameter and set self._use_override_shape =
bool(enable_override_shape) if provided, otherwise fall back to
is_cudnn_override_shape_available(); update the corresponding BF16/FP4/MXFP8
cuDNN runner factory functions and their call sites to pass this parameter
through (apply the same plumbing pattern used elsewhere) so callers can force
the static fallback path when needed.

---

Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3201-3216: The non-override branch currently returns a list of all
graph plan indices (range(graph.get_execution_plan_count())), but forward()
always executes the static graph with tactic=-1, so change the non-override
return to expose only the fallback tactic (return [-1] or equivalent
single-element list containing the fallback ID) so autotuner sees the single
heuristic fallback; update the same pattern in the other locations mentioned
(e.g., the BF16/FP4/MXFP8 runner methods around lines 3238-3248, 3694-3719,
3742-3744, 4966-4999, 5037-5050, 8192-8208, 8230-8240) and keep the override
branch using _get_override_graph(...) and returning full plan indices there.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 122ea434-c0e5-46a8-af0c-c1e2263d82d6

📥 Commits

Reviewing files that changed from the base of the PR and between aea926091b535d11199bec8dd80feb2f29c12748 and b086df4.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
  • flashinfer/gemm/__init__.py
  • tests/gemm/test_cudnn_override_shape.py

Comment thread flashinfer/gemm/gemm_base.py
Comment thread flashinfer/gemm/gemm_base.py Outdated
] = "cudnn",
backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn",
*,
enable_override_shape: bool = True,
Collaborator


Is this API level change strictly necessary?

tuner = AutoTuner.get()
effective_m_bucket_mapper = tuner.get_effective_map_to_tuning_buckets(
_FP8_GEMM_SM100_TUNING_CONFIG, spec_idx=0
)
Collaborator


If this is only used by cuDNN, can it be moved under the cuDNN runner?

Collaborator Author


Yes, it would be cleaner to move its lookup into the cuDNN runner construction path. We can update this in a follow-up change.

@yanqinz2 yanqinz2 merged commit f6717ff into main May 8, 2026
37 checks passed
@yanqinz2 yanqinz2 deleted the yanqinz/option-to-disable-override branch May 8, 2026 22:13
