
Very large discrepancy in the quantized model's output compared to the original model when quantizing on CPU #1335


Closed
JohnnyRacer opened this issue Nov 23, 2024 · 9 comments

Comments

@JohnnyRacer

JohnnyRacer commented Nov 23, 2024

Quantization on GPU works as expected with very small errors, but on CPU there seems to be a problem with the quantized model's output. Here is the code to replicate the problem.

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchao.quantization.quant_api import (
    quantize_,
    int4_weight_only,
)

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 30)
        self.relu = nn.ReLU()
        self.seq = nn.Sequential(nn.Linear(30,40), nn.ReLU())

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.seq(x)
        return x

model = TestModel()
cpu_quant_model = TestModel()

device = "cuda:0"

model.to(device)
cpu_quant_model.cpu()

test_input = torch.randn((10, 10), device=device)
original_output = model(test_input)

quantize_(model, int4_weight_only()) # Quantize the model on GPU

quanted_output = model(test_input)
print(F.mse_loss(original_output, quanted_output)) # Only a very small difference of 6.8689e-08 

quantize_(cpu_quant_model, int4_weight_only()) # Quantize the model on CPU

cpu_quanted_output = cpu_quant_model(test_input.cpu())
print(F.mse_loss(original_output, cpu_quanted_output.to(device))) # Getting a large difference of 0.0281 (Close to 50000 times larger error compared to the original?)
@jerryzh168
Contributor

quant_model is not defined, so I'm not sure what you mean there.

This might be a known issue: #1117, which we are fixing in #1278; the fix will land soon.

@JohnnyRacer
Author

@jerryzh168 Sorry, I changed a few things during testing and left out the line where the model was quantized on the CPU. The point is that the model's output differs significantly depending on whether it was quantized on the CPU or the GPU. I don't think it's related to #1117, since the difference is the same when I execute the CPU-quantized model on the CPU itself rather than quantizing on the CPU and running the model on the GPU.


@jerryzh168
Contributor

How do you get cpu_quant_model? int4_weight_only only works on CUDA, IIRC.

@JohnnyRacer
Author

quantize_(cpu_quant_model, int4_weight_only()) runs fine without any errors or warnings on the CPU. You can run the code snippet I provided above and it should show the difference. cpu_quant_model was always on the CPU and was never moved to the GPU.

@supriyar
Contributor

supriyar commented Dec 4, 2024

@jerryzh168 is this expected behavior on CPU? Or a bug?

@jerryzh168
Contributor

In the example code nothing is actually quantized, I think. You can check by printing model.linear1.weight and seeing whether it is an AffineQuantizedTensor or not.

The reason it's not quantized is that int4_weight_only uses group_size=128 by default, which means a linear layer's input channels must be a multiple of 128 for the layer to be quantized; it also only works with bfloat16 weights.
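
Here is a minimal sketch of that check, assuming the TestModel from the snippet above (the AffineQuantizedTensor import path is an assumption and may differ between torchao versions):

from torchao.dtypes import AffineQuantizedTensor  # assumed import path

# If quantize_ actually replaced the weight, it will be an AffineQuantizedTensor;
# with in_features=10 (less than group_size 128) the layer is skipped and stays a plain Tensor.
print(type(model.linear1.weight))
print(isinstance(model.linear1.weight, AffineQuantizedTensor))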

I fixed those issues in a simpler example; you can repro the CPU error with the following code:

import torch
from torchao import quantize_
from torchao.quantization import int4_weight_only
import torch.nn.functional as F

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 10).to(torch.bfloat16)

    def forward(self, x):
        return self.linear(x)

model = TestModel()
cpu_quant_model = TestModel()

device = "cuda:0"

model.to(device)
cpu_quant_model.cpu()

test_input = torch.randn((10, 128), device=device, dtype=torch.bfloat16)
original_output = model(test_input)

quantize_(model, int4_weight_only()) # Quantize the model on GPU

quanted_output = model(test_input)
print(F.mse_loss(original_output, quanted_output)) # Only a very small difference of 6.8689e-08

quantize_(cpu_quant_model, int4_weight_only()) # Quantize the model on CPU

cpu_quanted_output = cpu_quant_model(test_input.cpu())
print(F.mse_loss(original_output, cpu_quanted_output.to(device))) # Getting a large difference of 0.0281 (close to 50000 times larger error compared to the original?)

you will be able to see:

NotImplementedError: Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_convert_weight_to_int4pack' is only available for these backends: [CUDA, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

@JohnnyRacer
Author

@jerryzh168 What version of torch and ao are you using? This does not raise any errors for me. I have tried other quantization precisions such as INT8 weight-only, INT8 dynamic activation, and FP6, and they all show the same large quantization error.
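
For reference, a rough sketch of how those other configs were applied; the API names below are assumed from torchao.quantization.quant_api and may differ between versions, especially the FP6 call:

from torchao.quantization.quant_api import (
    quantize_,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    fpx_weight_only,
)

# Fresh copies of the TestModel defined earlier in this thread.
m_int8 = TestModel()
m_int8_dyn = TestModel()
m_fp6 = TestModel()

quantize_(m_int8, int8_weight_only())                         # INT8 weight-only
quantize_(m_int8_dyn, int8_dynamic_activation_int8_weight())  # INT8 dynamic activation + INT8 weight
quantize_(m_fp6, fpx_weight_only(3, 2))                       # FP6, assuming 3 exponent / 2 mantissa bits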

@jerryzh168
Contributor

cpu_quanted_output = cpu_quant_model(test_input.cpu())

Mine is a recent torch and torchao, not exactly sure which versions, but I think you should be able to repro this on the nightly builds of both.

Can you paste the exact code that you are running and the errors? Is the code in #1335 (comment) up to date?

@JohnnyRacer
Author

JohnnyRacer commented Dec 13, 2024

Just realized the real cause: in the snippet the two models actually had different state_dicts, so it only looked like a quantization error. Loading the same state_dict into both models solved the problem.

# The two models were initialized independently, so they did not share a state_dict
model = TestModel()
cpu_quant_model = TestModel()
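
A minimal sketch of the fix, i.e. giving both copies the same weights before quantizing (using the TestModel from the repro above):

# Fix: load one state_dict into both copies so the CPU/GPU comparison is apples to apples.
model = TestModel()
cpu_quant_model = TestModel()
cpu_quant_model.load_state_dict(model.state_dict())  # same weights in both models

model.to(device)
cpu_quant_model.cpu()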
