
Understanding 8da4w #430

Closed
DzAvril opened this issue Jun 24, 2024 · 3 comments

DzAvril commented Jun 24, 2024

Hi there,

I'm new to quantization. From my understanding, "8da4w" means that the weights are pre-quantized to 4 bits, and the activations are quantized to 8 bits at runtime. Following this, the GEMM (General Matrix Multiply) operation between weights and activations is computed in the int8 data type. Do I have this correct?

However, I'm confused by the code for Int8DynActInt4WeightQuantizer. The forward method of Int8DynActInt4WeightLinear calls a method named per_token_dynamic_quant, which can be found here. In this method, the input is first quantized to int8 and then immediately converted back to its original data type without further processing, so I don't understand the purpose of this function. Furthermore, I launched a program using Int8DynActInt4WeightQuantizer and observed the data types of x and w_dq in the method linear_forward_8da4w, which can be found here: both are float32. This seems to contradict my understanding of the computations involved in '8da4w'.
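For reference, this is roughly what I understand that quantize-then-dequantize step to be doing (my own minimal sketch, not the actual torchao per_token_dynamic_quant; the exact scale/zero-point math here is just for illustration):

```python
import torch

def per_token_dynamic_quant_sketch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of per-token dynamic quantization followed by
    # immediate dequantization. One scale/zero-point pair is computed per
    # token (per row over the last dimension), the input is quantized to
    # int8, then dequantized back to the original floating-point dtype.
    orig_dtype = x.dtype
    qmin, qmax = -128, 127

    # per-token min/max over the last dimension
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)

    # asymmetric quantization parameters
    scale = (max_val - min_val).clamp(min=1e-6) / (qmax - qmin)
    zero_point = qmin - torch.round(min_val / scale)

    # quantize to int8 ...
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

    # ... and immediately dequantize back to the original dtype
    x_dq = (x_q.to(torch.float32) - zero_point) * scale
    return x_dq.to(orig_dtype)
```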

I realize that I'm likely missing some fundamental aspects of dynamic quantization. Could anyone kindly clarify this process for me?

Thank you!

@supriyar (Contributor)

> Following this, the GEMM (General Matrix Multiply) operation between weights and activations is computed in the int8 data type.

This probably depends on the specific backend. For 8da4w we've tested it to work with the ExecuTorch runtime (XNNPack backend), which I believe does the computation in the integer bitwidths directly (8-bit activations x 4-bit weights).

@jerryzh168 can probably confirm this and help answer the other questions.

@jerryzh168 (Contributor)

It's true that we need to use integer compute to speed things up; that's what we are doing in our int8_dynamic_activation_int8_weight API (running on CUDA): https://github.com/pytorch/ao/tree/main/torchao/quantization#a8w8-dynamic-quantization
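For reference, a minimal usage sketch of that path (assuming the quantize_ and int8_dynamic_activation_int8_weight entry points from the torchao quantization README; exact names may differ across torchao versions, and this assumes a CUDA device is available):

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# A toy float model; any nn.Module containing Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Swap the Linear layers for versions that dynamically quantize activations
# to int8 and keep int8 weights; on CUDA this dispatches to int8 GEMM kernels.
quantize_(model, int8_dynamic_activation_int8_weight())

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)
```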

But specifically for 8da4w, we don't expect an immediate speedup after quantization on server, since it is targeted at ExecuTorch (https://github.com/pytorch/ao/tree/main/torchao/quantization#to-be-deprecated-a8w8-dynamic-quantization), and the requirement there is that we produce a representation of the quantized model so that it can be matched and lowered to a specific library (e.g. xnnpack). Here is a bit more context on the reasoning behind producing a pattern for further downstream consumption: https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md
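To make that concrete, here is a minimal sketch of what such a quantize/dequantize "reference" pattern could look like for 8da4w (my own illustration, not torchao's actual linear_forward_8da4w or per_token_dynamic_quant; shapes and scale/zero handling are simplified). It also shows why x and w_dq show up as float32: the integer values only exist transiently, and the matmul itself runs in floating point until a backend lowers the pattern to real integer kernels.

```python
import torch

def reference_8da4w_linear(x, w_int4, scales, zeros, group_size=32):
    # Sketch of the reference representation for 8da4w:
    #   1. dynamically fake-quantize the activation per token (int8 -> float)
    #   2. dequantize the int4 weight groups back to float (w_dq)
    #   3. run a plain floating-point matmul
    # A backend like xnnpack pattern-matches the quant/dequant ops around the
    # matmul and replaces them with 8-bit-activation x 4-bit-weight kernels.

    # 1. per-token dynamic quant -> dequant of the activation ("8da")
    scale_x = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127
    x_q = torch.clamp(torch.round(x / scale_x), -128, 127)
    x_dq = x_q * scale_x                      # back to float

    # 2. groupwise dequant of the int4 weight ("4w")
    out_f, in_f = w_int4.shape
    w = w_int4.to(torch.float32).reshape(out_f, in_f // group_size, group_size)
    w_dq = (w - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)
    w_dq = w_dq.reshape(out_f, in_f)          # float again

    # 3. floating-point matmul over the dequantized operands
    return torch.nn.functional.linear(x_dq, w_dq.to(x_dq.dtype))

# example shapes: weight (64, 128) holding 4-bit values in [-8, 7], group_size 32
x = torch.randn(2, 128)
w_int4 = torch.randint(-8, 8, (64, 128), dtype=torch.int8)
scales = torch.rand(64, 128 // 32) * 0.1
zeros = torch.zeros(64, 128 // 32)
y = reference_8da4w_linear(x, w_int4, scales, zeros)
```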

@jerryzh168 (Contributor)

Closing since the question is answered; feel free to reach out with more questions.
