
fix fp8 kv cache dequantize kernels #3896


Closed
wants to merge 1 commit

Conversation

@mxz297 mxz297 commented Mar 28, 2025

Summary:
Fix the fp8 KV cache dequantization kernel and enable the unit test on AMD.

The kernel uses each thread to dequantize 4 elements for both K and V, with one warp per head. The head dim is always 128, so this works on NVIDIA, where a warp has 32 threads (4 * 32 = 128).

On AMD, each wavefront (warp) has 64 threads, so the second 32 threads would all make out-of-bounds memory accesses (lane 32 and above would index 4 * lane >= 128, past the end of the 128-element head).

This diff simply masks those threads so they do nothing. The perf is obviously not ideal, but E2E testing suggests it does not matter. If we need to optimize perf on AMD, we can let threads 0~31 each dequantize 4 elements of K and threads 32~63 each dequantize 4 elements of V.
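
For reference, a minimal sketch of the masking approach (this is not the actual FBGEMM kernel; the kernel name, layout, and placeholder fp8 decode below are simplified assumptions):

// Sketch assuming one warp/wavefront per head, 4 elements per lane,
// and a fixed head dim of 128. warpSize is 32 on NVIDIA and 64 on AMD.
constexpr int kHeadDim = 128;

__global__ void dequantize_fp8_kv_sketch(
    const unsigned char* k_quant,
    const unsigned char* v_quant,
    float k_scale,
    float v_scale,
    float* k_out,
    float* v_out) {
  int lane = threadIdx.x % warpSize;
  // Mask: only lanes whose 4-element slice fits inside the head do work.
  // On AMD this makes lanes 32..63 idle instead of accessing out of bounds.
  if (4 * lane >= kHeadDim) {
    return;
  }
  int head = blockIdx.x;  // one warp handles one head in this sketch
  int base = head * kHeadDim + 4 * lane;
  for (int i = 0; i < 4; ++i) {
    // Placeholder dequantization; real fp8 (e4m3/e5m2) decoding is more involved.
    k_out[base + i] = k_scale * static_cast<float>(k_quant[base + i]);
    v_out[base + i] = v_scale * static_cast<float>(v_quant[base + i]);
  }
}

The suggested AMD optimization would replace the early return with a split, e.g. lanes 0~31 handle K and lanes 32~63 handle V, so the whole wavefront stays busy.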

Differential Revision: D72062745

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D72062745


netlify bot commented Mar 28, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: 8556aa5
🔍 Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67e71b0d5fa2a90008d70819
😎 Deploy Preview: https://deploy-preview-3896--pytorch-fbgemm-docs.netlify.app

@facebook-github-bot

This pull request has been merged in a303797.

@facebook-github-bot

This pull request has been reverted by 47635cf.

q10 pushed a commit to q10/FBGEMM that referenced this pull request Apr 10, 2025
Summary:
X-link: pytorch#3896

Pull Request resolved: facebookresearch/FBGEMM#987

Fix the fp8 KV cache dequantization kernel and enable the unit test on AMD (see summary above).

Reviewed By: Aya-ZIbra

Differential Revision: D72062745

fbshipit-source-id: 1b813057586054a13df4e9088be00b08f912bc57