
perf: cache cudaGetDeviceProperties in gdn_prefill to avoid per-call overhead #2509

Merged
yzh119 merged 2 commits into flashinfer-ai:main from xutizhou:perf/cache-cudaGetDeviceProperties
Feb 9, 2026
Conversation

@xutizhou (Contributor) commented Feb 6, 2026

Summary

  • Cache device_properties.major and device_properties.multiProcessorCount using
    static thread_local variables in gdn_prefill_launcher() and gdn_prefill()
    respectively, instead of calling cudaGetDeviceProperties on every invocation.
  • Each cudaGetDeviceProperties call adds ~1ms of synchronous CPU overhead, which
    accumulates to ~72ms across 36 GDN layers during prefill. This change eliminates
    that repeated overhead.
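
For context, a hypothetical micro-benchmark (not from the PR; the ~1ms figure is the author's measurement and varies by driver version and GPU) that one could use to compare the two query paths:

    // Hypothetical timing harness; not part of this PR.
    #include <chrono>
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      int dev = 0;
      cudaSetDevice(dev);
      cudaFree(0);  // initialize the CUDA context before timing
      constexpr int iters = 100;

      auto t0 = std::chrono::steady_clock::now();
      for (int i = 0; i < iters; ++i) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, dev);  // full property-struct query
      }
      auto t1 = std::chrono::steady_clock::now();
      for (int i = 0; i < iters; ++i) {
        int major;
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev);  // single-attribute query
      }
      auto t2 = std::chrono::steady_clock::now();

      auto us = [](auto a, auto b) {
        return std::chrono::duration_cast<std::chrono::microseconds>(b - a).count();
      };
      std::printf("cudaGetDeviceProperties: %lld us/call\n", (long long)(us(t0, t1) / iters));
      std::printf("cudaDeviceGetAttribute:  %lld us/call\n", (long long)(us(t1, t2) / iters));
      return 0;
    }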

Changes

  • csrc/gdn_prefill_launcher.cu: Replace per-call cudaGetDeviceProperties with
    static thread_local cached values for device_major (used for SM architecture
    dispatch) and sm_count (used for kernel launch configuration).

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Streamlined GPU device attribute retrieval mechanisms for accessing compute capability and processor information.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @xutizhou, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the performance of the GDN prefill process by optimizing how CUDA device properties are accessed. By caching frequently used device properties, it eliminates redundant and costly cudaGetDeviceProperties calls, thereby reducing CPU overhead and speeding up the prefill operation.

Highlights

  • Performance Optimization: Introduced caching for device_properties.major and device_properties.multiProcessorCount using static thread_local variables in gdn_prefill_launcher() and gdn_prefill() respectively.
  • Overhead Reduction: Eliminated repeated cudaGetDeviceProperties calls, which previously added approximately 1ms of synchronous CPU overhead per call, accumulating to about 72ms across 36 GDN layers during the prefill process.


Changelog
  • csrc/gdn_prefill_launcher.cu
    • Implemented static thread_local int cached_device_major to store the CUDA device's major version, avoiding repeated calls to cudaGetDeviceProperties for architecture dispatch.
    • Implemented static thread_local int32_t cached_sm_count to store the number of streaming multiprocessors, preventing redundant cudaGetDeviceProperties calls for kernel launch configuration.
Activity
  • No human activity (comments, reviews, etc.) has been recorded for this pull request yet.

@coderabbitai Bot (Contributor) commented Feb 6, 2026

📝 Walkthrough

CUDA device property retrieval in gdn_prefill_launcher.cu is refactored to use cudaDeviceGetAttribute instead of cudaGetDeviceProperties, retrieving compute capability and multiprocessor count directly via targeted attribute lookups.

Changes

  • CUDA Device Property Refactoring (csrc/gdn_prefill_launcher.cu): Replaces cudaGetDeviceProperties calls with cudaDeviceGetAttribute for compute capability major and multiprocessor count retrieval. Maintains control flow and function signatures.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

Poem

🐰 Device attributes, now direct and quick,
No heavyweight properties—a cleaner trick!
CUDA calls optimized, lean and tight,
Compute power whispers, perfectly right! ✨

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed
❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 70.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check (❓ Inconclusive): The title accurately describes the main change: caching cudaGetDeviceProperties to avoid per-call overhead in gdn_prefill functions. However, the actual implementation uses cudaDeviceGetAttribute instead of static caching, creating a discrepancy. Resolution: clarify whether the PR implements static thread_local caching (as the title suggests) or uses cudaDeviceGetAttribute (as the summary indicates), and update the title if the implementation differs from the caching approach.
✅ Passed checks (1 passed)
  • Description check (✅ Passed): PR description follows the template structure with Summary, Changes, and completed Pre-commit Checklist sections.


@gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces a performance optimization by caching the results of cudaGetDeviceProperties to avoid its overhead on repeated calls. The use of static thread_local is a good approach for this.

However, I've identified a potential issue in the current implementation. The caching mechanism doesn't account for scenarios where a thread's associated CUDA device might change. This could lead to using stale device properties, causing incorrect behavior in multi-GPU environments.

My review includes suggestions to make the caching robust by also tracking the device ID, ensuring the cache is invalidated if the device changes. These changes preserve the performance gain while adding correctness for more complex use cases.

Comment thread csrc/gdn_prefill_launcher.cu Outdated
Comment on lines +44 to +51
static thread_local int cached_device_major = -1;
if (cached_device_major < 0) {
  int dev_id;
  cudaGetDevice(&dev_id);
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_device_major = device_properties.major;
}
Severity: high

The current caching implementation with static thread_local doesn't account for cases where a thread's active CUDA device might change between calls. If cudaSetDevice() is used to switch the device for a thread, this function will continue to use the stale cached device_major from the first device it encountered. This could lead to incorrect behavior, such as dispatching to a kernel for the wrong SM architecture.

To make this robust, I suggest caching the device ID as well and re-fetching the properties only when the device ID changes. cudaGetDevice() is a very lightweight call, so this check will have negligible overhead while ensuring correctness.

    static thread_local int cached_device_major = -1, cached_dev_id = -1;
    int dev_id;
    cudaGetDevice(&dev_id);
    if (cached_dev_id != dev_id) {
      cudaDeviceProp device_properties;
      cudaGetDeviceProperties(&device_properties, dev_id);
      cached_device_major = device_properties.major;
      cached_dev_id = dev_id;
    }

Comment thread csrc/gdn_prefill_launcher.cu Outdated
Comment on lines +166 to +174
static thread_local int32_t cached_sm_count = -1;
if (cached_sm_count < 0) {
  int dev_id;
  cudaGetDevice(&dev_id);
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_sm_count = device_properties.multiProcessorCount;
}
int32_t sm_count = cached_sm_count;
Severity: high

Similar to my other comment, the caching for sm_count has a potential issue. It doesn't handle cases where the thread's CUDA device is changed between calls. This could result in using an incorrect sm_count for kernel launches if a thread is reused for a different GPU.

To prevent this, we should also cache the device ID and re-query cudaGetDeviceProperties when the device ID for the thread changes. This ensures the cached sm_count is always correct for the current device.

  static thread_local int32_t cached_sm_count = -1;
  static thread_local int cached_dev_id = -1;
  int dev_id;
  cudaGetDevice(&dev_id);
  if (cached_dev_id != dev_id) {
    cudaDeviceProp device_properties;
    cudaGetDeviceProperties(&device_properties, dev_id);
    cached_sm_count = device_properties.multiProcessorCount;
    cached_dev_id = dev_id;
  }
  int32_t sm_count = cached_sm_count;

@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@csrc/gdn_prefill_launcher.cu`:
- Around lines 165-174: cached_sm_count is a thread_local scalar that is never invalidated when the active CUDA device changes, so stale values survive device switches. Make the cache per-device: also store a thread_local cached_sm_dev_id (initialized to -1), compare the dev_id returned by cudaGetDevice against cached_sm_dev_id, and only call cudaGetDeviceProperties and update cached_sm_count (and cached_sm_dev_id) when they differ. Reference the existing symbols cached_sm_count, cudaGetDevice, cudaGetDeviceProperties, and sm_count, and include error checking on cudaGetDevice/cudaGetDeviceProperties before assigning the cached values.
- Around lines 43-51: The thread_local caches cached_device_major and cached_sm_count become stale if the active CUDA device changes. Have each cache also store a thread_local cached_device_id (initialized to -1); before using the cached values, call cudaGetDevice(&dev_id), and if dev_id != cached_device_id, re-query the device properties via cudaGetDeviceProperties to refresh both cached_device_id and the corresponding cached value (device_major or sm_count). Update both caches whenever the device differs so cudaSetDevice() switches are handled correctly.
🧹 Nitpick comments (1)
csrc/gdn_prefill_launcher.cu (1)

43-51: Consider consolidating the two independent device-property caches.

Both gdn_prefill_launcher and gdn_prefill independently call cudaGetDevice + cudaGetDeviceProperties to cache different fields. Since gdn_prefill already passes sm_count to the launcher, it could also resolve device_major and pass it as an additional parameter — eliminating one cache site and one redundant cudaGetDeviceProperties call on first invocation per thread.

Also applies to: 165-174
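
A minimal sketch of the consolidation this nitpick describes, assuming a hypothetical GetDeviceInfo() helper (the actual launcher interface may differ):

    #include <cuda_runtime.h>

    struct DeviceInfo {
      int major;     // compute capability major, for SM architecture dispatch
      int sm_count;  // multiprocessor count, for kernel launch configuration
    };

    // One device-aware, thread-local cache shared by both call sites.
    static DeviceInfo GetDeviceInfo() {
      static thread_local int cached_dev_id = -1;
      static thread_local DeviceInfo cached{};
      int dev_id;
      cudaGetDevice(&dev_id);  // lightweight; safe to call on every invocation
      if (dev_id != cached_dev_id) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, dev_id);  // paid once per device per thread
        cached = {prop.major, prop.multiProcessorCount};
        cached_dev_id = dev_id;
      }
      return cached;
    }

Under this shape, gdn_prefill would resolve both fields once and pass them down to the launcher.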

Comment thread csrc/gdn_prefill_launcher.cu Outdated
@yzh119 (Collaborator) commented Feb 6, 2026

Hi @xutizhou , yes we can use cache to reduce the overhead.

But how about a simpler fix: instead of relying on cudaGetDeviceProperties (which has huge runtime overhead), we can try:

FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));

Which should be much faster.
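
The same attribute-based approach covers the launcher's SM-count query as well; a minimal sketch, assuming the same FLASHINFER_CUDA_CALL macro:

    int sm_count;
    FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));

cudaDeviceGetAttribute reads a single cached attribute rather than populating the full cudaDeviceProp struct, which is why it is so much cheaper.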

@xutizhou force-pushed the perf/cache-cudaGetDeviceProperties branch from 62a7c27 to 74255c1 on February 6, 2026, 11:31
Comment thread flashinfer/gdn_decode.py

Collaborator: For changes to this file, we can probably create another PR?

@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around lines 1277-1278: The reshape call creating h0_source from initial_state can silently create a copy when initial_state is non-contiguous, so in-place kernel updates to h0_source won't propagate back. To fix, check initial_state.is_contiguous() before reshaping; if it is non-contiguous, create a contiguous copy (e.g., tmp = initial_state.contiguous()) and reshape that to h0_source, then after the kernel updates copy the modified values back into the original initial_state (e.g., initial_state.copy_(h0_source.view(pool_size, HV, V, K))) so the returned initial_state reflects the in-place updates. Apply the same guard/roundtrip used in the non-pooled case to the h0_source/initial_state handling.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

237-237: Non-English comment.

V 方向分 tiles — consider replacing with an English equivalent (e.g., "Tile along V dimension") for consistency with the rest of the codebase.

Comment thread flashinfer/gdn_decode.py Outdated
@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around lines 1077-1078: The pooled-mode reshape can produce a copy when `state` is non-contiguous, causing silent state corruption: the kernel updates the copy, but values are only copied back in non-pooled mode. In the gdn_decode code path around the `state.reshape(pool_size * HV, V, K)` call (variables: state, pool_size, HV, V, K, and the flag use_pool_indexing), enforce contiguity before reshaping by either asserting `state.is_contiguous()` when `use_pool_indexing` is True or making an explicit contiguous copy (e.g., `state = state.contiguous()`) so the kernel updates the real buffer, and update the comment/docs to state that pooled mode requires a contiguous `state` tensor.
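
The affected code is Python, but the reshape-copy hazard described above can be demonstrated with a minimal libtorch (C++) analogue of the same semantics (illustrative only):

    #include <torch/torch.h>
    #include <cassert>

    int main() {
      // A permuted view is non-contiguous, so reshape must materialize a copy.
      auto state = torch::arange(24).reshape({2, 3, 4}).permute({0, 2, 1});
      assert(!state.is_contiguous());

      auto flat = state.reshape({2, 12});  // silent copy, not a view
      flat.add_(1);                        // mutates the copy only

      // The original tensor is unchanged: the in-place update did not propagate.
      assert(!torch::equal(flat.reshape({2, 4, 3}), state));

      // Guard: make the buffer contiguous first, then copy results back.
      auto contiguous_state = state.contiguous();
      auto flat2 = contiguous_state.reshape({2, 12});
      flat2.add_(1);
      state.copy_(flat2.reshape({2, 4, 3}));  // propagate updates to the original
      return 0;
    }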
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

237-237: Non-English comment in kernel code.

Line 237 contains a Chinese comment # V 方向分 tiles (meaning "split tiles in V direction"). For consistency and accessibility, consider using English: # Split V dimension into tiles.

Suggested fix
-        # V 方向分 tiles
+        # Split V dimension into tiles

Same applies to line 526 in the big batch kernel.

Comment thread flashinfer/gdn_decode.py Outdated
Commits

…overhead

Cache device_major and sm_count using static thread_local variables instead of calling cudaGetDeviceProperties on every invocation. Each call adds ~1ms of synchronous CPU overhead, which accumulates to ~72ms across 36 GDN layers during prefill.

…n gdn_prefill

Replace cudaGetDeviceProperties (which has ~1ms overhead) with cudaDeviceGetAttribute for querying device_major and sm_count. cudaDeviceGetAttribute is much faster, eliminating the need for thread_local caching entirely.

Addresses review comment from yzh119 on PR flashinfer-ai#2509.
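
A sketch of what this attribute-based query path might look like at both former cache sites (illustrative only; see commit 076282c for the actual code):

    int dev_id, device_major, sm_count;
    FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
    FLASHINFER_CUDA_CALL(
        cudaDeviceGetAttribute(&device_major, cudaDevAttrComputeCapabilityMajor, dev_id));
    FLASHINFER_CUDA_CALL(
        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id));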
@xutizhou force-pushed the perf/cache-cudaGetDeviceProperties branch from ad893d6 to 076282c on February 8, 2026, 13:41
@yzh119 merged commit 30bf78e into flashinfer-ai:main on Feb 9, 2026 (18 checks passed)