perf: cache cudaGetDeviceProperties in gdn_prefill to avoid per-call overhead #2509
Conversation
Summary of Changes

Hello @xutizhou, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request improves the performance of the GDN prefill process by optimizing how CUDA device properties are accessed: by caching frequently used device properties, it eliminates redundant and costly `cudaGetDeviceProperties` calls.
Activity
📝 Walkthrough

CUDA device property retrieval in `csrc/gdn_prefill_launcher.cu` is now cached in `static thread_local` variables instead of being queried on every invocation.
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~3 minutes
Code Review
This pull request introduces a performance optimization by caching the results of cudaGetDeviceProperties to avoid its overhead on repeated calls. The use of static thread_local is a good approach for this.
However, I've identified a potential issue in the current implementation. The caching mechanism doesn't account for scenarios where a thread's associated CUDA device might change. This could lead to using stale device properties, causing incorrect behavior in multi-GPU environments.
My review includes suggestions to make the caching robust by also tracking the device ID, ensuring the cache is invalidated if the device changes. These changes preserve the performance gain while adding correctness for more complex use cases.
```cpp
static thread_local int cached_device_major = -1;
if (cached_device_major < 0) {
  int dev_id;
  cudaGetDevice(&dev_id);
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_device_major = device_properties.major;
}
```
The current caching implementation with static thread_local doesn't account for cases where a thread's active CUDA device might change between calls. If cudaSetDevice() is used to switch the device for a thread, this function will continue to use the stale cached device_major from the first device it encountered. This could lead to incorrect behavior, such as dispatching to a kernel for the wrong SM architecture.
To make this robust, I suggest caching the device ID as well and re-fetching the properties only when the device ID changes. cudaGetDevice() is a very lightweight call, so this check will have negligible overhead while ensuring correctness.
```cpp
static thread_local int cached_device_major = -1;
static thread_local int cached_dev_id = -1;
int dev_id;
cudaGetDevice(&dev_id);
if (cached_dev_id != dev_id) {
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_device_major = device_properties.major;
  cached_dev_id = dev_id;
}
```
```cpp
static thread_local int32_t cached_sm_count = -1;
if (cached_sm_count < 0) {
  int dev_id;
  cudaGetDevice(&dev_id);
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_sm_count = device_properties.multiProcessorCount;
}
int32_t sm_count = cached_sm_count;
```
Similar to my other comment, the caching for sm_count has a potential issue. It doesn't handle cases where the thread's CUDA device is changed between calls. This could result in using an incorrect sm_count for kernel launches if a thread is reused for a different GPU.
To prevent this, we should also cache the device ID and re-query cudaGetDeviceProperties when the device ID for the thread changes. This ensures the cached sm_count is always correct for the current device.
```cpp
static thread_local int32_t cached_sm_count = -1;
static thread_local int cached_dev_id = -1;
int dev_id;
cudaGetDevice(&dev_id);
if (cached_dev_id != dev_id) {
  cudaDeviceProp device_properties;
  cudaGetDeviceProperties(&device_properties, dev_id);
  cached_sm_count = device_properties.multiProcessorCount;
  cached_dev_id = dev_id;
}
int32_t sm_count = cached_sm_count;
```
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@csrc/gdn_prefill_launcher.cu`:
- Around lines 165-174: `cached_sm_count` is a thread_local scalar that is never invalidated when the active CUDA device changes, causing stale values across device switches. Make the cache per-device: also store a thread_local `cached_sm_dev_id` (initialized to -1), call `cudaGetDevice`, compare the current `dev_id` to `cached_sm_dev_id`, and only call `cudaGetDeviceProperties` and update `cached_sm_count` (and `cached_sm_dev_id`) when they differ. Include error checking on `cudaGetDevice`/`cudaGetDeviceProperties` before assigning the cached values.
- Around lines 43-51: the thread_local caches `cached_device_major` and `cached_sm_count` become stale if the active CUDA device changes. For each, also store a thread_local `cached_device_id` (initialized to -1); before using the cached value, call `cudaGetDevice(&dev_id)` and, if `dev_id != cached_device_id`, re-query the device properties via `cudaGetDeviceProperties` to refresh both `cached_device_id` and the corresponding cached value (`device_major` or `sm_count`). Update both caches whenever the device differs so `cudaSetDevice()` switches are handled correctly.
🧹 Nitpick comments (1)
csrc/gdn_prefill_launcher.cu (1)
43-51: Consider consolidating the two independent device-property caches.

Both `gdn_prefill_launcher` and `gdn_prefill` independently call `cudaGetDevice` + `cudaGetDeviceProperties` to cache different fields. Since `gdn_prefill` already passes `sm_count` to the launcher, it could also resolve `device_major` and pass it as an additional parameter, eliminating one cache site and one redundant `cudaGetDeviceProperties` call on first invocation per thread.

Also applies to: 165-174
Hi @xutizhou, yes we can use a cache to reduce the overhead. But how about a simpler fix: instead of relying on `cudaGetDeviceProperties`, query the attributes directly:

```cpp
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
```

which should be much faster.
Force-pushed: 62a7c27 to 74255c1 (Compare)
For changes to this file, we can probably create another PR?
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around lines 1277-1278: the reshape creating `h0_source` from `initial_state` can silently create a copy when `initial_state` is non-contiguous, so in-place kernel updates to `h0_source` won't propagate back. To fix, check `initial_state.is_contiguous()` before reshaping; if it is non-contiguous, make a contiguous copy (e.g., `tmp = initial_state.contiguous()`) and reshape that into `h0_source`. After the kernel updates, copy the modified values back into the original tensor (e.g., `initial_state.copy_(h0_source.view(pool_size, HV, V, K))`) so the returned `initial_state` reflects the in-place updates. Apply the same guard/roundtrip used in the non-pooled case to the `h0_source`/`initial_state` handling.
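The view-vs-copy pitfall described in this comment can be reproduced in a few lines. The sketch below uses NumPy, whose `reshape` has the same view-when-possible, copy-otherwise behavior as PyTorch's here; all names are illustrative stand-ins, not the actual flashinfer code.

```python
import numpy as np

# Stand-in for the tensors in gdn_decode.py: a transpose makes the
# "initial_state" non-contiguous, just like a permuted PyTorch tensor.
initial_state = np.arange(12).reshape(3, 4).T
assert not initial_state.flags["C_CONTIGUOUS"]

# Reshaping a non-contiguous array silently returns a COPY, not a view:
h0_source = initial_state.reshape(-1)
h0_source[0] = 999           # simulated in-place kernel update
lost = initial_state[0, 0]   # still 0: the update never reached the original

# Guard + explicit copy-back, mirroring the suggested fix:
tmp = np.ascontiguousarray(initial_state)    # like tensor.contiguous()
h0_source = tmp.reshape(-1)
h0_source[0] = 999                           # kernel writes into the copy
initial_state[...] = h0_source.reshape(initial_state.shape)  # copy back

assert initial_state[0, 0] == 999  # the update is now visible to the caller
```

The first half shows the silent data loss; the second half is the guard/roundtrip pattern the comment proposes.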
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)
237: Non-English comment.

`# V 方向分 tiles`: consider replacing with an English equivalent (e.g., "Tile along the V dimension") for consistency with the rest of the codebase.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around lines 1077-1078: the pooled-mode reshape can produce a copy when `state` is non-contiguous, causing silent state corruption: the kernel updates the copy, but the result is only copied back in non-pooled mode. In the `gdn_decode` code path around the `state.reshape(pool_size * HV, V, K)` call (variables: `state`, `pool_size`, `HV`, `V`, `K`, and the flag `use_pool_indexing`), enforce contiguity before reshaping, either by asserting `state.is_contiguous()` when `use_pool_indexing` is True or by making an explicit contiguous copy (e.g., `state = state.contiguous()`) so the kernel updates the real buffer. Update the comment/docs to state that pooled mode requires a contiguous `state` tensor.
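The flip side of this comment is that when the state tensor *is* contiguous, the reshape is a true view and in-place kernel writes land in the real buffer, so a fail-fast guard is enough. A NumPy sketch (shapes and the `as_kernel_buffer` helper are hypothetical, not flashinfer API):

```python
import numpy as np

# Pooled-mode stand-in: the kernel writes through a flattened view of `state`.
pool_size, HV, V, K = 2, 3, 4, 5
state = np.zeros((pool_size, HV, V, K))
assert state.flags["C_CONTIGUOUS"]

# Contiguous tensor -> reshape is a view, so kernel writes hit the real buffer:
flat = state.reshape(pool_size * HV, V, K)
flat[0, 0, 0] = 1.0               # simulated in-place kernel write
assert state[0, 0, 0, 0] == 1.0   # visible through the original tensor

def as_kernel_buffer(t):
    """Fail fast instead of mutating a hidden copy (mirrors the fix above)."""
    if not t.flags["C_CONTIGUOUS"]:
        raise ValueError("pooled mode requires a contiguous state tensor")
    return t
```

Asserting contiguity (rather than copying) keeps pooled mode zero-copy and surfaces caller bugs immediately.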
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)
237: Non-English comment in kernel code.

Line 237 contains a Chinese comment `# V 方向分 tiles` (meaning "split tiles in the V direction"). For consistency and accessibility, consider using English: `# Split V dimension into tiles`.

Suggested fix:

```diff
- # V 方向分 tiles
+ # Split V dimension into tiles
```

The same applies to line 526 in the big batch kernel.
…overhead

Cache device_major and sm_count using static thread_local variables instead of calling cudaGetDeviceProperties on every invocation. Each call adds ~1ms of synchronous CPU overhead, which accumulates to ~72ms across 36 GDN layers during prefill.
…n gdn_prefill

Replace cudaGetDeviceProperties (which has ~1ms overhead) with cudaDeviceGetAttribute for querying device_major and sm_count. cudaDeviceGetAttribute is much faster, eliminating the need for thread_local caching entirely.

Addresses review comment from yzh119 on PR flashinfer-ai#2509.
Force-pushed: ad893d6 to 076282c (Compare)
Summary

- Cache `device_properties.major` and `device_properties.multiProcessorCount` using `static thread_local` variables in `gdn_prefill_launcher()` and `gdn_prefill()` respectively, instead of calling `cudaGetDeviceProperties` on every invocation.
- Each `cudaGetDeviceProperties` call adds ~1ms of synchronous CPU overhead, which accumulates to ~72ms across 36 GDN layers during prefill. This change eliminates that repeated overhead.

Changes

- `csrc/gdn_prefill_launcher.cu`: Replace per-call `cudaGetDeviceProperties` with `static thread_local` cached values for `device_major` (used for SM architecture dispatch) and `sm_count` (used for kernel launch configuration).

🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes