First version of FP8 scaled_mm. #428


Closed

Conversation


@cfgfung cfgfung commented Jun 11, 2025

This is the initial implementation of scaled_mm(), enabling element-wise scaling of the A and B matrices (the scale tensors for A and B must be the same size as A and B, respectively, so this implementation is agnostic of the quantization scheme).

Based on the existing W8A8 normal GEMM, this version focuses on delivering core functionality.

@sanchitintel will use this PR as a foundation for further optimizations in subsequent PRs:

  1. Passing FP32 scales instead of the FP16 scales used in this PR.
  2. This PR requires the scale tensors to be the same size as A and B. They are created on the framework side and passed to cutlass, so there is a creation overhead, the memory bandwidth requirement increases, and register spilling also occurs. Ideally, the framework should pass the scales as is.


@sanchitintel sanchitintel left a comment


@cfgfung, I'll defer to the cutlass team for review and will add some comments later.
In the meantime, you may want to change the state of this PR to draft and rebase it later.
The corresponding IPEX integration code is failing accuracy checks, BTW.

@sanchitintel sanchitintel marked this pull request as draft June 12, 2025 08:22

cfgfung commented Jun 13, 2025

> @cfgfung, I'll defer to the cutlass team for review and will add some comments later. In the meantime, you may want to change the state of this PR to draft and rebase it later. The corresponding IPEX integration code is failing accuracy checks, BTW.

You may wish to follow up with the IPEX team to check what the root cause of the accuracy failure is. This code passes verify() and the unit test. I have granted you the rights, so you should be able to commit changes to the code.

*
**************************************************************************************************/
/*
* This implements the scaled_mm for W8A8 GEMM (FP8 weights and activations) using FP16 compute as a workaround,
Collaborator

What is scaled_mm? Can you express that in CUTLASS terms?


@cfgfung cfgfung Jun 20, 2025


Thanks for the review.
It means matrix multiplication with scaling factors applied to the operands:
(scaleA .* TensorA) @ (scaleB .* TensorB), where .* denotes element-wise multiplication and @ denotes matrix multiplication.
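For illustration, a minimal host-side reference of that semantics (not code from this PR; the function name and plain float types are illustrative stand-ins for the FP8 operands and FP16 scales used in the kernel):

```cpp
#include <cstddef>
#include <vector>

// Naive reference for scaled_mm: D = (scaleA .* A) @ (scaleB .* B).
// A and scaleA are M x K, B and scaleB are K x N, D is M x N (all row-major).
void scaled_mm_reference(const std::vector<float>& A, const std::vector<float>& scaleA,
                         const std::vector<float>& B, const std::vector<float>& scaleB,
                         std::vector<float>& D,
                         std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (std::size_t k = 0; k < K; ++k) {
        // Every operand element is multiplied by its own scale before the MAC,
        // which is why the scale tensors must match A and B in size.
        acc += (A[m * K + k] * scaleA[m * K + k]) * (B[k * N + n] * scaleB[k * N + n]);
      }
      D[m * N + n] = acc;
    }
  }
}
```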


@sanchitintel sanchitintel Jun 25, 2025


@cfgfung, please add info on the granularity of the scales in the code as well as in the PR description, since this scaling is specific to DeepSeek.

For example, for 128x128 A & B blocks, the A scale is applied at the granularity of a per-token, per-128-channel sub-vector of an A block, and the B scale is applied to the whole B block.

For this PR, however, the granularity doesn't matter to the collective & the kernel, because the cutlass library user has to ensure that the scales of A & B are the same size as A & B, so the scaling in this PR is element-wise.

I'll make it more generic in my subsequent PR, so that various workgroup tile sizes can work with various quantization block sizes.
However, can you please rename the collective & kernel files to indicate that the quantization is supposed to be DeepSeek-style block-wise? That way, I can modify the same files later. Thanks!
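For context, a hypothetical sketch of how a caller could expand such block-wise scales into the element-wise scale matrices this PR expects (the names, row-major layouts, and the 128 block size are assumptions taken from the comment above, not code from this PR):

```cpp
#include <cstddef>
#include <vector>

// DeepSeek-style block scales, expanded to per-element scale matrices.
// scaleA_blk holds one value per (row, 128-channel sub-vector) of A: M x ceil(K/128).
// scaleB_blk holds one value per 128x128 block of B: ceil(K/128) x ceil(N/128).
constexpr std::size_t kBlk = 128;

std::vector<float> expand_scale_A(const std::vector<float>& scaleA_blk,
                                  std::size_t M, std::size_t K) {
  const std::size_t kBlocks = (K + kBlk - 1) / kBlk;
  std::vector<float> scaleA(M * K);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t k = 0; k < K; ++k)
      scaleA[m * K + k] = scaleA_blk[m * kBlocks + k / kBlk];
  return scaleA;
}

std::vector<float> expand_scale_B(const std::vector<float>& scaleB_blk,
                                  std::size_t K, std::size_t N) {
  const std::size_t nBlocks = (N + kBlk - 1) / kBlk;
  std::vector<float> scaleB(K * N);
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t n = 0; n < N; ++n)
      scaleB[k * N + n] = scaleB_blk[(k / kBlk) * nBlocks + n / kBlk];
  return scaleB;
}
```

As noted in the PR description, this expansion is exactly the overhead a follow-up PR is meant to remove by consuming the block scales directly in the kernel.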

Collaborator

Shall we call this scaled F8 GEMM? I do not think weights and activations are terms that are relevant at the CUTLASS level. I am also not against adding more details about how the scaling is done.


@sanchitintel sanchitintel Jul 2, 2025


Hi @t4c1, this implementation requires users to pass scales at the granularity of each element of the A and B matrices (so the scale matrices also have to be the same size as A and B). That makes the implementation quantization-scheme agnostic, but it can't be used in real workloads.


template <typename SrcT, typename ScaleT>
void elementwise_multiply_scale(SrcT* d_src, size_t size, ScaleT* d_scale) {
  SrcT* h_src_multiplied = new SrcT[size];
Collaborator

use std::vector instead of directly calling new

Author

sycl::memcpy() cannot take a vector as an argument.
I added delete[] instead to avoid leaking the memory.


@sanchitintel sanchitintel Jun 25, 2025


> sycl::memcpy() cannot take a vector as an argument

A std::vector's data pointer could have been used, though.

Regardless, it doesn't matter, as we shouldn't do this computation on the CPU at all; please change it to something like https://github.com/mehdi-goli/cutlass-fork/blob/884fa6a8b94adddba9f32b41b6a9d011e1642217/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp#L202-L207
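For reference, a minimal sketch of the kind of device-side change being suggested, assuming a sycl::queue and USM device pointers are available (the queue parameter, the in-place update, and the float intermediate are assumptions, not code from this PR):

```cpp
#include <cstddef>
#include <sycl/sycl.hpp>

// Scale d_src element-wise by d_scale directly on the device, so no
// host-side staging buffer (new[] or std::vector) is needed at all.
template <typename SrcT, typename ScaleT>
void elementwise_multiply_scale(sycl::queue& q, SrcT* d_src, std::size_t size,
                                const ScaleT* d_scale) {
  q.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) {
    // Compute in float, then convert back to the (possibly narrower) source type.
    d_src[i] = static_cast<SrcT>(static_cast<float>(d_src[i]) *
                                 static_cast<float>(d_scale[i]));
  }).wait();
}
```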

Collaborator

This comment was not addressed. Is anything unclear here?

cfgfung added 2 commits June 21, 2025 01:29
This will enable the W8A8 Block quantized matmul for DeepSeek-R1 model.
@cfgfung cfgfung force-pushed the raymond/scaled_mm branch from ef7ff20 to 2573e3b on June 20, 2025 17:46

cfgfung commented Jun 20, 2025

@t4c1 @sanchitintel
Thanks for the reviews and comments. I have updated and rebased the code to address them.


@t4c1 t4c1 left a comment


Please add some tests.

Also, can you clarify how this differs from the changes in #450?

using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
Collaborator

Suggested change
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales B must match shape of the copy atom for B in the number of elements


template <typename SrcT, typename ScaleT>
void elementwise_multiply_scale(SrcT* d_src, size_t size, ScaleT* d_scale) {
  SrcT* h_src_multiplied = new SrcT[size];
Collaborator

This comment was not addressed. Is anything unclear here?


cfgfung commented Jul 8, 2025

Closing this, as it duplicates efforts by others.

@cfgfung cfgfung closed this Jul 8, 2025