int4wo can't use same packed weight for cpu and cuda #1117

Closed
HDCharles opened this issue Oct 18, 2024 · 20 comments

Comments

@HDCharles
Contributor

HDCharles commented Oct 18, 2024

This is mostly to keep track of a problem that has been around for a while.

If you ever do something like 1) quantize a cpu model with int4, then 2) move it to cuda, the output of the model will be nonsense.

e.g. if in https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py#L231

you did

quantize_(model.cpu(), int4_weight_only(group_size=groupsize))
model.cuda()

the output of the model is nonsensical:

Hello, my name is♠ zewnętrz zewnętrz@{ zewnętrz zewnętrz zewnętrz))] ord zewnętrzŻ zewnętrz zewnętrz zewnętrz zewnętrzŻ zewnętrz zewnętrz Хронологи

because it simply moves the same packed weight from cpu to cuda without addressing the fact that the packed format is numerically different for each backend:

https://github.com/pytorch/pytorch/blob/912ea5601bb3e7d360202927cb2de1ddc1d72cf6/aten/src/ATen/native/native_functions.yaml#L4144-L4148

Despite the different packing paths, there's no metadata to detect which backend's packing algorithm was actually used, so we can't even error out intelligently.

We could manually keep track of this in AffineQuantizedTensor and add code to unpack and repack if someone calls .to(device), but that doesn't fully solve the issue because, again, we can't detect how the weight was packed. Users can do things like serialize the model on cuda and reload it on cpu, and then we're in the same situation: when they call .cuda() we would want to unpack->repack, but we would use the cpu unpacking, which won't work since the original packing was done on cuda. You'd have to further add a field tracking which device the packed weight was most recently packed on, and when someone calls .to(device) you'd check what that original device was; if it differs from the current device, you first move the data there before the unpack->repack. We should either implement such a solution or identify whether this is going to be rectified in some other way.
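
A rough sketch of that bookkeeping, with hypothetical helper and field names (not an existing torchao implementation):

import torch

def _unpack_int4(packed, device):
    # hypothetical helper: undo the backend-specific packing; must run on `device`
    raise NotImplementedError

def _pack_int4(unpacked, device):
    # hypothetical helper: apply the backend-specific packing for `device`
    raise NotImplementedError

class PackedInt4Weight:
    # sketch of the bookkeeping described above: remember which device's packing
    # algorithm produced the data, and unpack -> repack on .to()
    def __init__(self, packed: torch.Tensor, packed_device: str):
        self.packed = packed
        self.packed_device = packed_device  # device whose packed format `packed` is in

    def to(self, device: str) -> "PackedInt4Weight":
        if device == self.packed_device:
            # format already matches, a plain data move is enough
            return PackedInt4Weight(self.packed.to(device), self.packed_device)
        # the data may have been serialized and reloaded elsewhere, so first move it
        # back to the device it was packed on, unpack there, then repack with the
        # target device's kernel
        data = self.packed.to(self.packed_device)
        unpacked = _unpack_int4(data, self.packed_device)
        return PackedInt4Weight(_pack_int4(unpacked.to(device), device), device)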

small repro:

import copy

import torch
from torchao.quantization import quantize_, int4_weight_only

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
input = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
model_q_cpu = copy.deepcopy(model)
model_q_cuda = copy.deepcopy(model.cuda())
quantize_(model_q_cpu, int4_weight_only())   # packs with the cpu packing kernel
quantize_(model_q_cuda, int4_weight_only())  # packs with the cuda packing kernel
out = model_q_cpu.to("cuda")(input)  # AQT actually doesn't let you run the model on cpu
out2 = model_q_cuda(input)
print(out - out2)  # large differences: the cpu-packed weight is misread on cuda
@leslie-fang-intel
Collaborator

cc @yanbing-j

@yanbing-j
Contributor

int4 woq indeed cannot use the same packed weight for CPU (AVX2/AVX512) and CUDA, because of the different packing methods across ISAs and devices.

We raised a PR pytorch/pytorch#129940, which uses a common serialized layout ([n][k/2] uint8) across devices and ISAs as the input weight of _convert_weight_to_int4pack, and each backend chooses how to interpret it as its compute layout. The conversion to the packed weight is therefore postponed until after the model is loaded onto a specific device, e.g. model = model.to(device), and then _convert_weight_to_int4pack runs in quantize_. The serialized weight ([n][k/2] uint8) can then be shared across different test machines without re-generating it on one particular platform.
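
A sketch of that flow (the op signature and the uint8 input support depend on the PyTorch version carrying pytorch/pytorch#129940, and on newer PyTorch the CPU path uses a separate _for_cpu op as discussed later in this thread, so treat the details as assumptions):

import torch

n, k, inner_k_tiles = 1024, 1024, 8

# common serialized layout proposed in pytorch/pytorch#129940: [n, k/2] uint8,
# identical on every device/ISA, so the saved weight can be shared across machines
serialized_w = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8)

# only once the target device is settled is the weight converted into that
# backend's compute layout (tensor-core tiles on CUDA, ISA-friendly blocks on CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
packed_w = torch.ops.aten._convert_weight_to_int4pack(serialized_w.to(device), inner_k_tiles)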

@jerryzh168
Contributor

@yanbing-j I don't think that solves the problem; what we want is to do a device conversion after _convert_weight_to_int4pack. I'm wondering why you use the same op but produce numerically different tensors for cpu and cuda? Would it be easier for cpu to just use a different packing op and quantized linear op?

@yanbing-j
Contributor

@jerryzh168
I suppose the different packing methods are specific to devices and ISAs to achieve the best performance.

cc @mingfeima Mingfei, do you have any other comments about this issue?

@jerryzh168
Contributor

If the packing is specific to a device, I feel the most natural way to structure it is to define them as different layouts. A packing format (or layout) in principle should not be specific to a device, since it only describes how we rearrange the tensor values; people should be able to implement the same algorithm on different devices, even if you only use that format on a specific device.

@yanbing-j
Contributor

cc @mingfeima @jgong5

@mingfeima

@jerryzh168 I agree that the packing format is not specific to a device; that's exactly why we want to decouple it.

The thing is, on CUDA, int4 is packed to a very special format: [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype), see https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/int4mm.cu#L1284

I assume that your concept of a non-device-specific packing layout refers to [n][k / 2] (uint8 dtype). This is how we propose to save the int4 model weights. During model init, once the device to use is settled, the weight is converted to the device-specific layout, which in the CUDA case is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype) and in the CPU case is [N/64, K, 32].

@jerryzh168
Contributor

jerryzh168 commented Oct 28, 2024

@mingfeima I think using [n][k / 2] (uint8 dtype) as a packing format and having additional conversions during model init is not the UX we are looking for. With a real packing format, the UX should just be:

quantize:

# pack the weights to a format that can be consumed by kernel
quantize(model, ...)
# save the weights
torch.save(model.state_dict(), ...)

inference:

model_weights = torch.load(...)
# note: no additional conversions
model.load_state_dict(model_weights)

I believe the introduction of [n][k / 2] (uint8 dtype) is a hack to combine the cpu and cuda packing formats. But it's not a requirement that cpu and cuda kernels use the same packing format (layout), so I'm proposing for cpu to just have its own packing format from the very beginning ([N/64, K, 32]), without this additional indirection through [n][k / 2] (uint8 dtype).

Basically:
cuda (TensorCoreTiledLayout): [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
cpu (IntCPULayout, naming TBD): [N/64, K, 32]

Conversion between these layouts should be explicit in the user code, or we can have a util for that, I think.

would this work for you?
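
For illustration, a sketch of what that could look like from the user side, assuming the Int4CPULayout name settled on later in this thread and a layout argument on int4_weight_only (exact names and defaults are assumptions about the torchao API):

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout  # Int4CPULayout: proposed, name TBD

def int4_quantize_for_device(model: torch.nn.Module, device: str) -> torch.nn.Module:
    # pick the packing layout that matches the device the model will run on,
    # so the packed weight is produced by that device's own kernel
    model = model.to(device)
    layout = TensorCoreTiledLayout() if device == "cuda" else Int4CPULayout()
    quantize_(model, int4_weight_only(group_size=128, layout=layout))
    return model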

@mingfeima

I am afraid this won't work. This was actually the first version of my int4 weight-only quantization patch on CPU; I tried to use a different output shape for CPU and CUDA. See pytorch/pytorch@30befa5#diff-6dffed1ade0ba3e887f9a4eafa3bfcec267ab2365b8adcb91bd391f49b3fd2e3R3429

The problem comes from torch/_meta_registrations.py, which requires an op to have an identical output shape across devices (maybe I am not precise here, I can't remember the details, it has been too long).

@jerryzh168
Contributor

The problem comes from torch/_meta_registrations.py, which requires an op to have an identical output shape across devices (maybe I am not precise here, I can't remember the details, it has been too long).

It's reasonable that the same op should output the same shape, but I'm talking about different ops here. Would that work? Is "CPU and CUDA using the same packing op" a requirement?

@mingfeima

If you think using different packing ops for CPU and CUDA is OK, it certainly works for us.

@jerryzh168
Contributor

@mingfeima Yeah, I think so. Could you or @yanbing-j help with refactoring the cpu packing and int4mm into different ops?

@yanbing-j
Contributor

Hi @jerryzh168 ,

As for your request to separate cpu packing and cuda packing into different ops, we can do it as shown in the following figure. The figure does not use [n, k / 2] uint8; it goes back to [n, k] int32 as the input of the convert function, and the output is [n / 64, k, 32] uint8 on CPU.
[figure: proposed flow for the separate CPU packing op]

However, this still cannot solve the initial repro of this issue, which quantizes first and then changes device. Actually, it is never possible to support quantize_ first and then changing the device. I'm confused about how this relates to the request of separating the packing functions. Could you please give some more comments? Thanks!

cc @mingfeima

@jerryzh168
Contributor

@yanbing-j

Yeah, the proposed format for the cpu packed weight looks good; also, the packing op should be described in terms of layout instead of device, I think.

However, this still cannot solve the initial repro of this issue, which quantizes first and then changes device. Actually, it is never possible to support quantize_ first and then changing the device. I'm confused about how this relates to the request of separating the packing functions. Could you please give some more comments? Thanks!

Yeah, it's correct that we are never going to support just moving the packed tinygemm weight from cuda to cpu and expecting that to work on cpu. The proposed API is:

We have two layouts:

layout1: TensorCoreTiledLayout (for tinygemm, only available for cuda device)
layout2 (name TBD): Int4CPULayout (for cpu int4 weight only linear op, only available for cpu device)
model = quantize_(model, int4_weight_only(layout=TensorCoreTiledLayout(...)))
model.weight is an AffineQuantizedTensor with TensorCoreTiledLayout, on cuda device

model = quantize_(model, int4_weight_only(layout=Int4CPULayout(...)))
model.weight is an AffineQuantizedTensor with Int4CPULayout, on cpu device


def move_int4_weight_from_cuda_to_cpu(affine_quantized_tensor):
    assert isinstance(affine_quantized_tensor.layout, TensorCoreTiledLayout)
    assert affine_quantized_tensor.device.type == "cuda"
    # this will unpack TensorCoreTiledLayout packed weights in cuda, move the data to cpu, and repack in cpu
    affine_quantized_tensor = affine_quantized_tensor.to(_layout=Int4CPULayout(...), device="cpu")
    return affine_quantized_tensor
    
Users can use the `move_int4_weight_from_cuda_to_cpu` util to move from TensorCoreTiledLayout to Int4CPULayout and from the cuda device to the cpu device.
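
For illustration, a usage sketch of that util (hypothetical until the API above lands; it assumes the linear's weight is the AffineQuantizedTensor produced by quantize_):

import torch
from torchao.quantization import quantize_, int4_weight_only

linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16).cuda()
quantize_(linear, int4_weight_only())  # weight: AffineQuantizedTensor, TensorCoreTiledLayout, on cuda

# instead of linear.to("cpu"), explicitly unpack/repack the quantized weight
linear.weight = torch.nn.Parameter(
    move_int4_weight_from_cuda_to_cpu(linear.weight), requires_grad=False
)
linear = linear.to("cpu")  # any remaining non-quantized state moves as usual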

@yanbing-j
Contributor

@jerryzh168 Thanks for the confirmation!

I drafted a PR pytorch/pytorch#139611 to split the int4wo weight packing. For convenience, I set the input weight of _convert_weight_to_int4pack_for_cpu to [n, k] int32 and the output to [n, k / 2] uint8. The input packed weight of _weight_int4pack_mm_for_cpu is [n, k / 2] uint8. The output of _convert_weight_to_int4pack_for_cpu is a little different from the previously mentioned [n / 64, k, 32]. Is this okay for you?
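
A sketch of what calling those ops might look like (op names as described above; the innerKTiles argument and the scales/zeros shape are assumptions and may differ from the final PR):

import torch

n, k, group_size, inner_k_tiles = 1024, 1024, 128, 8

# pack: [n, k] int32 in, [n, k / 2] uint8 out (two int4 values per byte)
w_int32 = torch.randint(0, 16, (n, k), dtype=torch.int32)
packed_cpu = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int32, inner_k_tiles)

# the matmul consumes the [n, k / 2] uint8 packed weight directly
x = torch.randn(8, k, dtype=torch.bfloat16)
scales_and_zeros = torch.ones(k // group_size, n, 2, dtype=torch.bfloat16)
y = torch.ops.aten._weight_int4pack_mm_for_cpu(x, packed_cpu, group_size, scales_and_zeros)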

cc @mingfeima.

@jerryzh168
Contributor

@yanbing-j Thanks. Why do you use [n, k/2] as the shape of the packed weight instead of [n/64, k, 32]? Doesn't this incur an extra performance cost, since it's not the final format needed by the weight_int4pack_mm op?

@yanbing-j
Contributor

@jerryzh168 I suppose [n, k / 2] is the packed format and [n/64, k, 32] is the compute format. As shown in pytorch/pytorch@5391b4b#diff-db2b91c51cffd1fba933933ac20de256f5f0ced095890a1a49a17d73037001a2L3412, the first version used [n, k / 2] and then changed to a 4D format like the CUDA one. I think there is no extra performance cost.

cc @mingfeima Do you have any other comments?

@jerryzh168
Contributor

OK, if there is no extra cost then that's fine. I just want to say that we don't need [n, k/2] to align with other formats or anything, so you are free to choose the packed shape for the cpu layout here; whatever gives the best perf / is most convenient for you would make sense, I think.

@mingfeima

@jerryzh168 I suppose [n, k / 2] is the packed format and [n/64, k, 32] is the compute format. As shown in pytorch/pytorch@5391b4b#diff-db2b91c51cffd1fba933933ac20de256f5f0ced095890a1a49a17d73037001a2L3412, the first version used [n, k / 2] and then changed to a 4D format like the CUDA one. I think there is no extra performance cost.

cc @mingfeima Do you have any other comments?

sounds good to me ~

@yanbing-j
Contributor

@jerryzh168 @mingfeima Thanks for the comments!

Please review pytorch/pytorch#139611; I will fix the CI failures in the meantime, and will also include Nikita and Sanchit when the PR is ready.

zhangxiaoli73 pushed a commit to zhangxiaoli73/pytorch that referenced this issue Nov 13, 2024
Fixes pytorch/ao#1117.

This PR is to separate int4wo weight packing between CPU and other devices, to help implement `INT4CPULayout` in torchao based on pytorch/ao#1117 (comment).

Now, for CPU, the input `weight` of `_convert_weight_to_int4pack_for_cpu` is [n, k] int32, output is [n, k / 2] uint8. The input packed weight of `_weight_int4pack_mm_for_cpu` is [n, k / 2] uint8.

Pull Request resolved: pytorch#139611
Approved by: https://github.com/jerryzh168