non-override tactic control #3260
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough

Removes cuDNN override-shape helper exports, adds conditional CuTe‑DSL kernel discovery to exports, threads enable_override_shape and **kwargs through GEMM dispatch, refactors FP4/FP8/BF16/MXFP8 runners to gate override-shape execution vs static fallback, and adds GPU tests exercising single compiled graphs across varying M.

Changes: cuDNN GEMM override-shape control
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request integrates cuDNN override-shape functionality into high-level GEMM APIs, such as mm_bf16 and bmm_bf16, and introduces an enable_override_shape parameter to manage this feature. It also modifies the autotuning logic for cuDNN FP8 and MXFP8 to return a fallback tactic, preventing plan mismatches when graphs are rebuilt for runtime shapes. Feedback suggests that deleting the override-shape tests significantly reduces coverage and recommends updating them to use the new high-level APIs instead. Reviewers also expressed concern that disabling autotuning for FP8 and MXFP8 could lead to performance regressions, suggesting the use of override-shape builders to allow safe profiling. Lastly, the removal of is_cudnn_override_shape_available from the public API was identified as a breaking change.
I am having trouble creating individual review comments, so my feedback is listed below.
tests/gemm/test_cudnn_override_shape.py (1-326)
Removing this test file reduces coverage for cuDNN override-shape functionality. These tests should be updated to use high-level GEMM APIs (e.g., mm_bf16 with enable_override_shape=True) instead of being deleted to ensure no regressions in the cuDNN execution path.
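For illustration, a minimal sketch of the kind of replacement test the review is asking for — the import path, tolerances, and exact `mm_bf16` signature are assumptions; the review only names `mm_bf16` and `enable_override_shape=True`:

```python
import pytest
import torch

from flashinfer.gemm import mm_bf16  # assumed import path


@pytest.mark.parametrize("m", [1, 17, 128, 1000])
def test_mm_bf16_override_shape(m):
    # Drive the cuDNN override-shape path through the high-level API
    # instead of the deleted low-level helpers.
    a = torch.randn(m, 512, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
    out = mm_bf16(a, b, enable_override_shape=True)
    torch.testing.assert_close(out, a @ b, rtol=2e-2, atol=2e-2)
```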
flashinfer/gemm/gemm_base.py (3078-3081)
Disabling autotuning for cuDNN FP8 avoids tactic mismatch but may cause performance regressions. Consider updating _cudnn_gemm_fp8_runner to use the existing override-shape builder (line 2895) to enable safe autotuning. Also, the comment wording 'only cache' is slightly inaccurate as this function determines which tactics to profile.
```python
# Static-shape cuDNN graphs are rebuilt for the runtime M. A
# tactic index profiled for a bucket M may refer to a different
# plan in the actual-M graph, so only profile the fallback tactic.
return [-1]
```
flashinfer/gemm/gemm_base.py (7988-7991)
Disabling autotuning for MXFP8 is suboptimal. Consider updating _cudnn_gemm_mxfp8_runner to use the existing override-shape builder (line 2589) to maintain performance while fixing the hazard.
```python
# Static-shape cuDNN graphs are rebuilt for the runtime M. A
# tactic index profiled for a bucket M may refer to a different
# plan in the actual-M graph, so only profile the fallback tactic.
return [-1]
```
flashinfer/gemm/__init__.py (23-31)
Removing is_cudnn_override_shape_available from the public API is a breaking change. It is a useful utility for users to check feature support at runtime; consider keeping it exposed.
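As a sketch of the pattern the removal breaks (names taken from the review; the exact `mm_bf16` signature is an assumption):

```python
import torch

from flashinfer.gemm import is_cudnn_override_shape_available, mm_bf16

a = torch.randn(64, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

# Gate the override-shape path on runtime support instead of hard-coding it;
# dropping the helper from the public API breaks exactly this check.
out = mm_bf16(a, b, enable_override_shape=is_cudnn_override_shape_available())
```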
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4808-4810: The comment claiming "currently cudnn backend does not
support alpha for dynamic-shape" is stale for the FP4 cuDNN override-shape path;
update or remove those comments around the override-shape conditional (the block
guarded by self._use_override_shape) and the similar comment in the nearby
override-shape graph build/execute path so they accurately state that alpha is
handled in the override-shape path (or simply remove the incorrect restriction
note). Ensure any new comment clearly distinguishes dynamic-shape cuDNN
limitations from the override-shape implementation that includes alpha support.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8f40fd23-7ed7-4d4d-bbd0-aadeeb0a8704
📥 Commits
Reviewing files that changed from the base of the PR and between bb41dc1 and d55d8437992220d766d50cc0cb7758546470fe26.
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
- tests/gemm/test_cudnn_override_shape.py
- flashinfer/gemm/__init__.py
/bot run

[FAILED] Pipeline #50605869: 1/20 passed
Force-pushed from b53e8f3 to aea9260
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
7976-7990: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Restore the cuDNN availability guard in the refactored MXFP8 builder.

`build_cudnn_gemm_mxfp8_graph()` now touches `cudnn.build_plan_policy` before any availability check. If this helper is reached on a build without cuDNN, it will fail with a raw `NameError` instead of the consistent `_check_cudnn_availability()` error used by the other cuDNN graph builders.

Suggested fix:

```diff
 @functools.lru_cache(maxsize=1024)
 def build_cudnn_gemm_mxfp8_graph(
     a_shape,
     a_stride,
     a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
     @@
     device,
     policy=None,
 ):
+    _check_cudnn_availability()
     if policy is None:
         policy = cudnn.build_plan_policy.HEURISTICS_CHOICE
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/gemm_base.py` around lines 7976 - 7990, The function build_cudnn_gemm_mxfp8_graph currently references cudnn.build_plan_policy before verifying cuDNN is available; add the existing availability guard by calling _check_cudnn_availability() at the start of build_cudnn_gemm_mxfp8_graph (or move the policy default assignment so it happens after that check) so that missing cuDNN raises the consistent _check_cudnn_availability() error instead of a NameError; reference the build_cudnn_gemm_mxfp8_graph function and the cudnn.build_plan_policy assignment when making this change.
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
3195-3197: ⚡ Quick win — Partition the autotuner cache by override-shape mode.

These `get_cache_key_extras()` methods still ignore `self._use_override_shape`, but this PR makes that flag change the valid tactic set from plan indices to the heuristic-only fallback path. If the same shape is first cached with `enable_override_shape=False`, later calls with it enabled can reuse that entry and skip profiling the override-shape plans entirely.

At minimum, add `self._use_override_shape` to the cache key for these cuDNN runners so the two execution modes do not alias in `AutoTuner`.

Also applies to: 3702-3706, 4949-4954, 8169-8171
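A minimal sketch of that change, assuming the current extras are just the three dtypes (the unpacking of `inputs` here is hypothetical and varies per runner):

```python
def get_cache_key_extras(self, inputs):
    a, b, out = inputs[0], inputs[1], inputs[-1]
    # Including the flag partitions the AutoTuner cache so override-shape
    # and static-graph entries for the same shapes never alias each other.
    return (a.dtype, b.dtype, out.dtype, self._use_override_shape)
```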
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/gemm_base.py` around lines 3195 - 3197, The get_cache_key_extras methods for the cuDNN GEMM runners currently return only tensor dtypes and therefore do not distinguish override-shape mode; update each get_cache_key_extras (e.g., the method shown and the other occurrences) to include self._use_override_shape in the returned tuple (for example return (a.dtype, b.dtype, out.dtype, self._use_override_shape)) so the autotuner cache is partitioned by override-shape mode and the two execution modes do not alias in AutoTuner.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 7976-7990: The function build_cudnn_gemm_mxfp8_graph currently
references cudnn.build_plan_policy before verifying cuDNN is available; add the
existing availability guard by calling _check_cudnn_availability() at the start
of build_cudnn_gemm_mxfp8_graph (or move the policy default assignment so it
happens after that check) so that missing cuDNN raises the consistent
_check_cudnn_availability() error instead of a NameError; reference the
build_cudnn_gemm_mxfp8_graph function and the cudnn.build_plan_policy assignment
when making this change.
---
Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3195-3197: The get_cache_key_extras methods for the cuDNN GEMM
runners currently return only tensor dtypes and therefore do not distinguish
override-shape mode; update each get_cache_key_extras (e.g., the method shown
and the other occurrences) to include self._use_override_shape in the returned
tuple (for example return (a.dtype, b.dtype, out.dtype,
self._use_override_shape)) so the autotuner cache is partitioned by
override-shape mode and the two execution modes do not alias in AutoTuner.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b59f858c-0d29-45fe-a098-c857c7b85d10
📥 Commits
Reviewing files that changed from the base of the PR and between b53e8f3f9d5c9fda80eb61a74cf7b03590c804f4 and aea926091b535d11199bec8dd80feb2f29c12748.
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
- tests/gemm/test_cudnn_override_shape.py
- flashinfer/gemm/__init__.py
Force-pushed from aea9260 to b086df4
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
3201-3216: ⚡ Quick win — Return only the fallback tactic on the non-override path.

In these branches, `get_valid_tactics()` still exposes numbered plan indices, but `forward()` ignores them and always executes the static graph with `tactic=-1`. That keeps the autotuner profiling meaningless tactic IDs instead of the single heuristic fallback this change is trying to preserve.

Suggested direction:

```diff
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> List[int]:
     a, b, _, _, out, _ = inputs
     if self._use_override_shape:
         graph = self._get_override_graph(a, b, out)
-    else:
-        graph = build_cudnn_gemm_fp8_graph(
-            a_shape=a.shape,
-            a_stride=a.stride(),
-            b_shape=b.shape,
-            b_stride=b.stride(),
-            a_type=_torch_data_type_to_cudnn_data_type(a.dtype),
-            b_type=_torch_data_type_to_cudnn_data_type(b.dtype),
-            o_type=_torch_data_type_to_cudnn_data_type(out.dtype),
-            device=a.device,
-            policy=cudnn.build_plan_policy.HEURISTICS_CHOICE,
-        )
-
-    return list(range(graph.get_execution_plan_count()))
+        return list(range(graph.get_execution_plan_count()))
+    return [-1]
```

Mirror that pattern in the BF16 / FP4 / MXFP8 runners as well.
Also applies to: 3238-3248, 3694-3719, 3742-3744, 4966-4999, 5037-5050, 8192-8208, 8230-8240
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/gemm/gemm_base.py` around lines 3201 - 3216, The non-override branch currently returns a list of all graph plan indices (range(graph.get_execution_plan_count())), but forward() always executes the static graph with tactic=-1, so change the non-override return to expose only the fallback tactic (return [-1] or equivalent single-element list containing the fallback ID) so autotuner sees the single heuristic fallback; update the same pattern in the other locations mentioned (e.g., the BF16/FP4/MXFP8 runner methods around lines 3238-3248, 3694-3719, 3742-3744, 4966-4999, 5037-5050, 8192-8208, 8230-8240) and keep the override branch using _get_override_graph(...) and returning full plan indices there.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3163-3171: The constructors for the cuDNN runners (e.g., the
__init__(self, m_bucket_mapper) shown) ignore the enable_override_shape flag and
always set self._use_override_shape = is_cudnn_override_shape_available(), so
callers cannot opt out; modify these constructors to accept an
enable_override_shape (or similar) parameter and set self._use_override_shape =
bool(enable_override_shape) if provided, otherwise fall back to
is_cudnn_override_shape_available(); update the corresponding BF16/FP4/MXFP8
cuDNN runner factory functions and their call sites to pass this parameter
through (apply the same plumbing pattern used elsewhere) so callers can force
the static fallback path when needed.
---
Nitpick comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 3201-3216: The non-override branch currently returns a list of all
graph plan indices (range(graph.get_execution_plan_count())), but forward()
always executes the static graph with tactic=-1, so change the non-override
return to expose only the fallback tactic (return [-1] or equivalent
single-element list containing the fallback ID) so autotuner sees the single
heuristic fallback; update the same pattern in the other locations mentioned
(e.g., the BF16/FP4/MXFP8 runner methods around lines 3238-3248, 3694-3719,
3742-3744, 4966-4999, 5037-5050, 8192-8208, 8230-8240) and keep the override
branch using _get_override_graph(...) and returning full plan indices there.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 122ea434-c0e5-46a8-af0c-c1e2263d82d6
📥 Commits
Reviewing files that changed from the base of the PR and between aea926091b535d11199bec8dd80feb2f29c12748 and b086df4.
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
💤 Files with no reviewable changes (2)
- flashinfer/gemm/__init__.py
- tests/gemm/test_cudnn_override_shape.py
```python
] = "cudnn",
backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn",
*,
enable_override_shape: bool = True,
```
Is this API level change strictly necessary?
```python
tuner = AutoTuner.get()
effective_m_bucket_mapper = tuner.get_effective_map_to_tuning_buckets(
    _FP8_GEMM_SM100_TUNING_CONFIG, spec_idx=0
)
```
If this is only used by cuDNN, can it be moved under the cuDNN runner?
Yes, it would be cleaner to move its lookup into the cuDNN runner construction path. We can update this in a follow-up change.
📌 Description
This PR adds an enable_override_shape control to the high-level GEMM APIs and fixes a tactic mismatch hazard in the non-override path. Autotuning buckets dynamic M values, but static cuDNN graphs are rebuilt with the actual runtime M, so a tactic index profiled on a bucket-M graph may not refer to the same execution plan in the actual-M graph. To avoid applying an invalid or mismatched tactic, the non-override cuDNN paths now expose only the fallback tactic -1, forcing runtime to use the cuDNN heuristic path for the actual static graph.
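To make the hazard concrete, a small illustration (the numbers and names here are hypothetical):

```python
bucket_m = 1024  # the autotuner profiles tactics on this bucketed M
actual_m = 1000  # at runtime the static cuDNN graph is rebuilt for this M

# cuDNN enumerates execution plans per graph, and plan ordering is not
# stable across graphs built for different shapes: tactic index 3 from the
# bucket-M graph may name a different (or nonexistent) plan in the
# actual-M graph. Exposing only -1 sidesteps the mismatch by letting the
# cuDNN heuristic choose the plan for the graph that actually runs.
valid_tactics = [-1]
```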
This update also brings the cuDNN FP8 and MXFP8 GEMM paths in line with the existing BF16/FP4 override-shape design.
For FP8, the per-tensor quantized cuDNN graph builders were renamed to follow the same naming convention as the other GEMM paths: build_cudnn_gemm_fp8_graph and build_cudnn_gemm_fp8_graph_override_shape. The cuDNN FP8 runner now supports the override-shape execution path, using the autotuner’s effective M-bucket mapper to build a reusable bucketed graph and passing the runtime shapes through override_shapes.
For MXFP8, the cuDNN graph construction was refactored to match the FP4/BF16 structure. The previous create_cudnn_execution_plans_mxfp8_gemm + _get_cudnn_mxfp8_gemm_graph split was replaced with a single build_cudnn_gemm_mxfp8_graph builder that owns graph creation, support checking, and plan building. The MXFP8 runner now also supports the override-shape path with bucketed graph reuse.
For both FP8 and MXFP8, the non-override static cuDNN path no longer participates in autotune plan-index profiling, since static graphs are rebuilt for the actual runtime M and cannot safely reuse tactic indices profiled on bucketed M shapes.
🔍 Related Issues
test_bmm_bf16 failure on B300
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks on all files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
API Changes
Improvements
Tests