Use cudnn 9.23 new API to query workspace with override shape #3291

Open

yanqinz2 wants to merge 2 commits into main from yanqinz/cudnn-override-shape-workspace

Use cudnn 9.23 new API to query workspace with override shape#3291
yanqinz2 wants to merge 2 commits into
mainfrom
yanqinz/cudnn-override-shape-workspace

Conversation


@yanqinz2 yanqinz2 commented May 11, 2026

📌 Description

This PR makes two cleanups/improvements to the cuDNN GEMM backend:

  1. Move effective M-bucket mapper lookup into the cuDNN runners

    The effective map_to_tuning_buckets lookup is now owned by the cuDNN GEMM runners instead of the higher-level GEMM dispatch functions. This keeps the bucket-to-cache_m logic local to the backend that uses it, while still respecting active autotune overrides such as custom tuning_buckets / round_up.

  2. Query override-shape workspace size dynamically on cuDNN 9.23+

    For cuDNN override-shape GEMM execution, cuDNN 9.23+ can query the workspace requirement for the actual runtime problem shape via get_workspace_size_plan_at_index(...) with override shapes and strides. The code now uses this API when available, so workspace allocation matches the executed dynamic problem size. Older cuDNN versions continue to query workspace by execution plan index without override-shape metadata.
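The version-gated workspace query described in item 2 can be sketched as below. This is a minimal illustration, not the gemm_base.py source: the function name, the stub objects in the usage, and the exact keyword arguments to get_workspace_size_plan_at_index are assumptions; only the API name, the override shapes/strides idea, and the 92300 backend-version threshold come from the PR summary.

```python
OVERRIDE_SHAPE_QUERY_MIN_BACKEND = 92300  # cuDNN 9.23.0 backend version


def get_override_shape_workspace_size(cudnn, graph, plan_index,
                                      uids, shapes, strides):
    """Return the workspace size (bytes) for the runtime problem shape."""
    if cudnn.backend_version() >= OVERRIDE_SHAPE_QUERY_MIN_BACKEND:
        # cuDNN 9.23+: query with the actual override shapes/strides so the
        # workspace allocation matches the executed dynamic problem size.
        return graph.get_workspace_size_plan_at_index(
            plan_index, uids=uids, shapes=shapes, strides=strides
        )
    # Older cuDNN: fall back to the per-plan query without override-shape
    # metadata (the conservative, statically sized path).
    return graph.get_workspace_size_plan_at_index(plan_index)
```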

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Improvements
    • Better cuDNN version compatibility checks for more reliable GPU support across driver versions.
    • More robust and efficient workspace sizing and allocation for matrix-multiply workloads, improving performance and memory use.
    • Autotuning and runner initialization simplified so tuning behavior is computed automatically, improving reliability across BF16/FP8/FP4/MXFP8 workloads.

Review Change Stack


coderabbitai Bot commented May 11, 2026

📝 Walkthrough

Runner factories (BF16/FP8/MXFP8/FP4) now compute effective AutoTuner mappers internally; override-shape availability and workspace sizing are cuDNN-backend-version aware; standard and override execution helpers resize workspaces and pass explicit cudnn handles; call sites updated to new runner signatures.

Changes

  • Version & Workspace Sizing Helpers (flashinfer/gemm/gemm_base.py): is_cudnn_override_shape_available() now conditionally requires a newer cudnn-frontend when cudnn.backend_version() >= 92300. _get_cudnn_override_shape_workspace_size(...) selects the workspace-sizing API by backend version.
  • Standard Graph Workspace Sizing (flashinfer/gemm/gemm_base.py): non-override FP4/MXFP8/FP8/BF16 execution paths now use _get_cudnn_workspace_size(graph, tactic) and resize the provided workspace tensor.
  • Override-Shape Execution Updates (flashinfer/gemm/gemm_base.py): FP4, MXFP8, FP8, and BF16 override-shape helpers compute workspace via _get_cudnn_override_shape_workspace_size(...), resize/allocate the workspace, obtain and pass an explicit cudnn_handle, and call graph.execute_plan_at_index(...) with override uids/shapes/strides.
  • Runner Factory Mapper Internalization (flashinfer/gemm/gemm_base.py): _cudnn_gemm_fp8_runner, _cudnn_gemm_bf16_runner, and _cudnn_gemm_mxfp8_runner drop the external m_bucket_mapper parameter and compute the effective mapper from AutoTuner.get().get_effective_map_to_tuning_buckets(...). _cudnn_gemm_fp4_runner now accepts tuning_config and derives the mapper internally.
  • Call Site Updates (flashinfer/gemm/gemm_base.py): call sites (bf16_gemm_sm100, fp8_gemm_sm100, mxfp8_gemm_sm100, mm_fp4) stop precomputing mappers and instantiate runners with the new signatures (passing layout flags or tuning_config where required).
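The mapper internalization above can be sketched in Python. This is illustrative only: AutoTuner here is a minimal stand-in for FlashInfer's autotuner singleton, and the round-up bucketing policy is an assumption, not the exact gemm_base.py logic; the grounded part is that the runner factory now calls AutoTuner.get().get_effective_map_to_tuning_buckets(...) itself instead of accepting an m_bucket_mapper parameter.

```python
import bisect


class AutoTuner:
    """Minimal stand-in for FlashInfer's AutoTuner singleton (illustrative)."""
    _instance = None

    def __init__(self):
        self.tuning_buckets = None  # an active autotune override, if any

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def get_effective_map_to_tuning_buckets(self, default_buckets):
        # Honor a custom tuning_buckets override when one is active.
        buckets = sorted(self.tuning_buckets or default_buckets)

        def mapper(m):
            # Round M up to the smallest bucket >= m, capping at the largest.
            i = bisect.bisect_left(buckets, m)
            return buckets[min(i, len(buckets) - 1)]

        return mapper


def make_cudnn_gemm_runner(default_buckets=(8, 16, 32, 64, 128)):
    # The factory owns the bucket -> cache_m mapping, so call sites no
    # longer precompute and pass a mapper.
    mapper = AutoTuner.get().get_effective_map_to_tuning_buckets(default_buckets)

    def run(m):
        cache_m = mapper(m)  # plans would be cached per effective bucket
        return cache_m

    return run
```

Keeping the mapper inside the factory means every runner (BF16/FP8/MXFP8/FP4) applies the same bucket-to-cache_m policy without each dispatch function duplicating the lookup.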

Possibly related PRs

  • flashinfer-ai/flashinfer#3260: Modifies cuDNN override-shape handling and GEMM runner APIs (BF16/FP4/FP8/MXFP8) related to override vs. non-override behavior and tactic/cache logic.
  • flashinfer-ai/flashinfer#2948: Changes BF16 GEMM runner and override-shape execution paths in gemm_base.py; related to runner/graph handling updates.
  • flashinfer-ai/flashinfer#3192: Previously injected m_bucket_mapper via AutoTuner into runners; this PR internalizes that computation inside runner factories.

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • bkryu
  • dhiraj113

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰
I hop through tensors, neat and spry,
Mappers tucked inside—no need to pry.
Workspaces sized by backend's song,
Runners hum along and run strong.
🍃

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 28.95%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check ✅: the PR title accurately describes the main change, moving to cuDNN 9.23's new API for querying workspace with override shapes, which aligns with the primary objective of the changeset.
  • Description check ✅: the PR description covers both objectives clearly, moving mapper lookup into runners and querying override-shape workspace dynamically. Pre-commit checks are marked complete, but tests are not marked complete, which may warrant attention but doesn't fail the description.
  • Linked Issues check ✅: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅: skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.






@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the cuDNN GEMM implementation by moving the effective_m_bucket_mapper logic into runner factories and updating version checks to support cuDNN backend 9.23.0 with frontend 1.24. It also introduces a helper function for workspace size calculation. Feedback focuses on using workspace.resize_() instead of reassigning the local variable to ensure in-place updates are reflected in the caller's reference, which avoids repeated allocations and skewed performance measurements during autotuning.
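The in-place resize the feedback asks for can be illustrated with a small helper. The name ensure_workspace is hypothetical; the point it encodes is the one the review makes: torch.Tensor.resize_() mutates the tensor the caller holds, whereas rebinding the local variable would leave the caller's reference at the old size and force a fresh allocation on every call.

```python
import torch


def ensure_workspace(workspace: torch.Tensor, required_bytes: int) -> None:
    # Grow the caller's tensor in place. Rebinding the local name instead
    # (workspace = torch.empty(required_bytes, ...)) would not be visible
    # to the caller, so every subsequent call would reallocate, skewing
    # performance measurements during autotuning.
    if workspace.numel() < required_bytes:
        workspace.resize_(required_bytes)
```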

Four comment threads on flashinfer/gemm/gemm_base.py, now marked outdated.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around lines 2094-2101: the probe that checks cudnn.backend_version() and
  parses cudnn.__version__ currently swallows all exceptions with a bare
  `except Exception`, which can mask real errors. Update the exception
  handling in that block (the code referencing backend_version, version_str,
  major, minor, required_frontend_version, and cudnn.__version__) to catch
  only the expected failure types (e.g., AttributeError, ValueError,
  TypeError, OSError) when probing and parsing, and let other exceptions
  propagate (or re-raise) so unexpected errors aren't hidden.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a666c109-a296-4313-af7c-5b21db207d1f

📥 Commits

Reviewing files that changed from the base of the PR and between 7016955 and d4ed677.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

One open comment thread on flashinfer/gemm/gemm_base.py.
