How to calibrate a w8a8 quantized model? #1002

Closed

chenghuaWang opened this issue Oct 3, 2024 · 5 comments

Comments

@chenghuaWang

I used the following code to quantize an LLM with a w8a8 quantization setting:

model = AutoModelForCausalLM.from_pretrained("./Qwen1.5-0.5B-Chat").to(dtype=torch.bfloat16, device='cpu')
quantize_(model, int8_dynamic_activation_int8_weight())

Everything is running smoothly, but the model's accuracy has decreased significantly. How can I calibrate a quantized model to enhance its accuracy?


I have another question:

I printed out a parameter and noticed that the weights are quantized per channel. What is the purpose of the bfloat16 AffineQuantizedTensor? Shouldn't the activation only need a single scale when it is quantized per tensor?

I'm not very familiar with the quantization mechanism in PyTorch, and I hope you can give me some tips.

Parameter Name: model.layers.0.self_attn.q_proj.weight
Parameter Shape: torch.Size([1024, 1024])
Parameter Values: LinearActivationQuantizedTensor(AffineQuantizedTensor(data=tensor([[ 0.2148, -0.1196, -0.0898,  ..., -0.0388,  0.0869,  0.0898],
        [ 0.0830, -0.2188, -0.1436,  ...,  0.0566,  0.0679,  0.0830],
        [ 0.0552, -0.2480, -0.1621,  ...,  0.0242,  0.0688,  0.0830],
        ...,
        [ 0.0742, -0.0417, -0.1641,  ..., -0.0356,  0.1543, -0.0566],
        [-0.0640,  0.0771,  0.2695,  ...,  0.0537, -0.1982,  0.0938],
        [-0.1216,  0.1025, -0.1074,  ..., -0.0327,  0.1592, -0.1123]],
       dtype=torch.bfloat16)..., shape=torch.Size([1024, 1024]), block_size=(1, 1024), device=cpu, dtype=torch.bfloat16, requires_grad=False, layout_tensor=PlainAQTLayout(data=tensor([[ 72, -40, -30,  ..., -13,  29,  30],
        [ 22, -58, -38,  ...,  15,  18,  22],
        [ 16, -72, -47,  ...,   7,  20,  24],
        ...,
        [ 25, -14, -55,  ..., -12,  52, -19],
        [-19,  23,  80,  ...,  16, -59,  28],
        [-26,  22, -23,  ...,  -7,  34, -24]], dtype=torch.int8)... , scale=tensor([0.0030, 0.0038, 0.0034,  ..., 0.0030, 0.0034, 0.0047],
       dtype=torch.bfloat16)... , zero_point=tensor([0, 0, 0,  ..., 0, 0, 0])... , layout_type=PlainLayoutType())), <function _int8_symm_per_token_reduced_range_quant at 0x751a4815fe20>)
@jerryzh168
Contributor

  1. int8_dynamic_activation_int8_weight quantizes the activation dynamically, so you don't need to do calibration (that is for static activation quantization). You can, however, try skipping some sensitive layers by passing a filter_fn to the quantize_ API (filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): a function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run `apply_tensor_subclass` on it); see the sketch after this list.
  2. int8_dynamic_activation_int8_weight uses _int8_symm_per_token_reduced_range_quant to quantize the activation dynamically, and the weight is quantized per axis (axis=1), because block_size=(1, 1024) and 1024 is the size of axis=1. Also, dtype=torch.bfloat16 does not mean the quantized tensor is actually quantized to bfloat16; that is the source dtype. We do this to keep autograd happy, I think. Here are some discussions on this specific topic: What should .dtype for tensor subclass return? #442
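
A minimal sketch of the filter_fn idea above, assuming `model` is a freshly loaded (not yet quantized) Qwen1.5-0.5B-Chat instance like the one in the first comment; the layer name to skip is only illustrative, so inspect your own model's fully qualified module names before choosing which layers to exclude:

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Illustrative filter: quantize nn.Linear modules only, and leave out layers that
# turn out to be accuracy-sensitive (the "lm_head" name below is an assumption).
def skip_sensitive_layers(module: torch.nn.Module, fqn: str) -> bool:
    if "lm_head" in fqn:
        return False
    return isinstance(module, torch.nn.Linear)

quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=skip_sensitive_layers)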

@chenghuaWang
Author

Got it, thanks for the reply!

@chenghuaWang
Author

I used the following code to test the performance of w8a8.

import time
import torch

@torch.no_grad()
def generate(model, tokenizer, device, prompt, max_new_tokens):
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    start = time.time()
    outputs = model.generate(
        input_ids=inputs.input_ids.to(device),
        max_new_tokens=max_new_tokens,
        attention_mask=inputs.attention_mask.to(device),
        do_sample=True,
        top_k=50,
        top_p=0.9,
    )
    end = time.time()
    generated_text = tokenizer.decode(outputs[0])
    print(f"Generated '{generated_text}' in [{end - start:.2f} s]")

But I ran into a performance problem. I tested on an Intel CPU: with the same prompt, the Hugging Face FP16 model finishes in about 3 seconds, while the quantized model takes about 60 seconds. Am I missing any steps?

@chenghuaWang reopened this Oct 3, 2024
@jerryzh168
Contributor

@chenghuaWang can you try running torch.compile(model, mode="max-autotune") before benchmarking? Also, our current optimization efforts are mostly focused on CUDA, I think.
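
For reference, a minimal sketch of that suggestion (not an official recipe): compile once, run a warm-up call so the one-time compilation cost is excluded from the measurement, and only then time the model. The generate helper, tokenizer, device, and prompt are assumed to be the ones from the previous comment.

import torch

# Compile the quantized model; "max-autotune" spends extra compile time
# searching for faster kernels.
model = torch.compile(model, mode="max-autotune")

# Warm-up call: triggers compilation, so its timing should be discarded.
generate(model, tokenizer, device, prompt, max_new_tokens=32)

# Subsequent calls are the ones worth timing.
generate(model, tokenizer, device, prompt, max_new_tokens=32)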

@chenghuaWang
Author

Unfortunately, after using torch.compile, there was not much speed improvement; the inference time went from 60 seconds to 42 seconds. It is still much slower than the model using FP16.

Also, our current optimization efforts are mostly focused on CUDA, I think.

I will test it on some accelerators. Thank you for your answer.
