INT4 XPU enabling #1577


Merged
airMeng merged 26 commits into pytorch:main from the xpu_int4 branch on Apr 10, 2025

Conversation

airMeng
Collaborator

@airMeng airMeng commented Jan 17, 2025

This PR is currently a draft.

The PR will add two kinds of INT4 support on XPU: floating zero points and integer zero points, following the discussion in #1264.

Integer zero points are natively supported via oneDNN (pytorch/pytorch#137566).

Floating zero points, the default behaviour in this repo, are supported by intel/torch-xpu-ops#1130, with more implementations on the way.
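The difference between the two conventions can be sketched in the dequantization formulas (a minimal illustrative sketch, not the torchao or oneDNN API; function names and the mid-point default are assumptions):

```python
def dequant_int_zp(q, scale, zp):
    # Integer zero-point domain: zp lives on the quantized integer grid,
    # so x = (q - zp) * scale and the zero point is stored as an int.
    return (q - zp) * scale

def dequant_float_zp(q, scale, zp, mid_point=8):
    # Floating zero-point domain (tinygemm-style): zp is a float offset
    # applied after scaling around the assumed mid-point of the uint4 range.
    return (q - mid_point) * scale + zp
```

For example, with `scale=0.5`, the quantized value 10 dequantizes to `(10 - 8) * 0.5 = 1.0` under an integer zero point of 8, and to `(10 - 8) * 0.5 + 1.0 = 2.0` under a floating zero point of 1.0.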


pytorch-bot bot commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1577

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 00c742c with merge base 6726b0b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@airMeng airMeng marked this pull request as draft January 17, 2025 03:20
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 17, 2025
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
layout_list = []
if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
layout_list.append(Int4CPULayout())
elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_6:
Contributor

here as well, 2_6 or 2_7?

Contributor

2.7


__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Contributor
@jerryzh168 jerryzh168 Jan 17, 2025

btw for this one, we have some unpacking op for tensor core tiled layout that we should really be using:

m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);

might be better to do the same instead of hacking with quantize ops

Collaborator Author

Sure, I will give it a check.

@jerryzh168
Contributor

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

@airMeng
Collaborator Author

airMeng commented Jan 17, 2025

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

@mingfeima @EikanWang can you comment?

@mingfeima

btw why the op is added in pytorch/pytorch#137566 instead of in torchao? any plans to move it to torchao?

@mingfeima @EikanWang can you comment?

The situation for XPU (the Intel GPUs) is different from that for CPU and CUDA here. I am not sure whether providing SYCL or oneDNN XPU ops in ao is a feasible solution.

@airMeng airMeng force-pushed the xpu_int4 branch 2 times, most recently from 91067e2 to 895376f Compare February 24, 2025 01:40
@airMeng airMeng marked this pull request as ready for review February 26, 2025 09:53
@sunjiweiswift
Contributor

@jerryzh168 please review again.

_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
# TODO(#1690): delete this once config migration is done
Contributor

cc @vkuzo we can delete these now?

Comment on lines +211 to +219
if self.scale_and_zero is not None:
return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
else:
return ["packed_weight", "scale", "zero"], [self.transposed, self._layout]
Contributor

why do we have two formats here? maybe should split into multiple layouts?

Collaborator Author

Integer zp and floating zp. I didn't split them into two layouts because it would be confusing from the user's side.

current:

quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.Float))

but if different layouts

quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutIntZP(), zero_point_domain=ZeroPointDomain.INT))
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayoutFloatZP(), zero_point_domain=ZeroPointDomain.Float))

I think the current implementation is more straightforward for users.
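The two flatten formats in the snippet above can be sketched as follows (a simplified stand-in class, not the real torchao tensor impl; only the flatten logic is shown):

```python
class Int4XPUAQTTensorImplSketch:
    # Simplified stand-in for the tensor impl being discussed: floating-zp
    # weights carry a fused scale_and_zero tensor, integer-zp weights carry
    # separate scale and zero tensors, and __tensor_flatten__ picks the
    # matching attribute list at runtime.
    def __init__(self, packed_weight, scale_and_zero=None, scale=None, zero=None):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero  # floating zp: fused tensor
        self.scale = scale                    # integer zp: separate tensors
        self.zero = zero
        self.transposed = False
        self._layout = "Int4XPULayout"

    def __tensor_flatten__(self):
        if self.scale_and_zero is not None:
            return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
        return ["packed_weight", "scale", "zero"], [self.transposed, self._layout]
```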

Contributor

But layout actually defines how we store the packed weights; using a single layout for multiple things breaks this abstraction, I feel.

is the concern around specifying zero_point_domain multiple times? we could remove that and just infer the zero_point_domain from layout I think (the latter API)

Collaborator Author

Since only XPU supports integer zp, can I move it to the next PR?

"layout defines how we store the packed weights actually" — it should include the layout of scales and zeros, right?

Contributor

Since only XPU supports integer zp, can I move it in the next PR?

what is this referring to?

layout defines how we store the packed weights actually it should include the layout of scales and zeros, right?

Yeah, that's correct. Ideally, I think we should not use layout to control whether we have packed_weight / scale_and_zero / scale, zero; the duplication should actually happen at the tensor level (we create different tensor subclass tensors), not in the layout. Feel free to go that route if you want.

Collaborator Author

Can I separate into different layouts, and bind the zero point domain into each layout in the next PR?

Contributor

yeah sure

@@ -242,6 +247,11 @@ def matmul(self, x):
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
)
if is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7 \
and not isinstance(self.scales_and_zeros, torch.Tensor):
c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
Contributor
@jerryzh168 jerryzh168 Mar 17, 2025

is this supposed to match line 550 in GPTQ.py?

Collaborator Author

Removed HQQ support in this PR to simplify the logic.

@@ -546,6 +546,14 @@ def linear_forward_int4(
groupsize,
scales_and_zeros.to(scales_precision),
).to(dtype=x.dtype)
elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
Contributor

also why do we have this function? can the slicing (scales_and_zeros[0] and scales_and_zeros[1]) be done in the _weight_int4pack_mm itself?

Collaborator Author

Removed GPTQ support in this PR to simplify the logic. I will open another PR after I separate the layouts.

@airMeng airMeng requested a review from jerryzh168 March 19, 2025 06:46
@@ -166,6 +167,10 @@ def process_hqq_quants(self, W_q, meta):
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
W_q_torch, self.inner_k_tiles
)
if is_device(W_q.device.type, "Xpu") and TORCH_VERSION_AT_LEAST_2_7:
Contributor

nit: "xpu"?

@@ -407,6 +407,8 @@ def _quantize_affine_no_dtype_cast(
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
if shape_after_reduction[0] == 12288:
Contributor

remove?

@@ -954,6 +956,7 @@ def _choose_qparams_affine(
if preserve_zero:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
zero_point_dtype = torch.int32
Contributor

why is this set here?

Collaborator Author
@airMeng airMeng Mar 19, 2025

In fact preserve_zero and the INT zero point domain are coupled here; I think it is somewhat duplicated.
The reason for setting this parameter to an int is that many callers of this function use the default floating parameter.

Contributor

preserve_zero is about whether zero (in the original floating point domain) is exactly representable or not; it's not coupled with the zero point domain, I think. Even if zero is exactly representable, we can still choose zero_point_domain to be float.

Collaborator Author

Yes, from the math side they are not related, but the code here implies it; see the conditional dispatch from line 954 to 966. We need a refactor here.

Contributor

I think it would be better to assert instead of changing the condition if there is coupling?

@@ -315,7 +316,7 @@ def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float3
return dequantized


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16, zero_point_domain_is_int=False):
Contributor

nit: I'm wondering if we should just expose zero_point_domain as an arg directly here

Collaborator Author

which way do you prefer? how about

def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, data_dtype=torch.bfloat16, scale_dtype=torch.bfloat16, zero_dtype=torch.bfloat16):

Contributor

sure, this sounds good. what about preserve_zero and zero_point_domain? I don't think these can be fully inferred?

Contributor

is this going to be updated? I feel it's a bit weird to introduce a boolean flag for zero_point_domain when we can just pass zero_point_domain itself around

Collaborator Author

287d069 is it okay?

@@ -850,6 +860,7 @@ def _int4_weight_only_transform(
zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"

preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] if zero_point_domain!=ZeroPointDomain.INT else True
Contributor

Does zero_point_dtype need to change after this is set?

Collaborator Author

Same as above. In fact preserve_zero and the INT zero point domain are coupled.

Contributor

There are three things: preserve_zero = {True, False}, zero_point_domain = {FLOAT, INT, NONE}, and zero_point_dtype = {float, int, ...}.

It's true that not all combinations are valid, but I don't think they are coupled, see:

preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable or not, this is typically required for ops that needs zero padding, like convolution
it's less important for ops that doesn't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is an integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8]; 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such a
guarantee.
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
zero_point_domain might be something we can remove, though. cc @jainapurva
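The preserve_zero behaviour described above can be sketched as follows (an illustrative asymmetric-quantization sketch, not the torchao implementation; the function name and defaults are assumptions):

```python
def choose_qparams(min_val, max_val, quant_min=-8, quant_max=7, preserve_zero=True):
    # Illustrative asymmetric qparams computation for int4 ranges.
    if preserve_zero:
        # Widen the range to include 0.0 so it can be exactly representable.
        min_val = min(min_val, 0.0)
        max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (quant_max - quant_min)
    if preserve_zero:
        # Round and clamp so float 0.0 maps to an exact integer zero point.
        zero_point = round(quant_min - min_val / scale)
        zero_point = max(quant_min, min(quant_max, zero_point))
    else:
        # No rounding/clamping: the zero point may stay fractional,
        # which suits the floating zero-point domain.
        zero_point = quant_min - min_val / scale
    return scale, zero_point
```

With preserve_zero=True, dequantizing the zero point itself, `(zero_point - zero_point) * scale`, gives exactly 0.0; with preserve_zero=False the zero point is generally not an integer.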

Collaborator Author

how about 946f530, expose it as an independent argument?

Contributor

yeah this makes sense

Contributor
@jerryzh168 jerryzh168 left a comment

looks fine for now, but at some point we will probably add new tensor subclass tensors for these

@airMeng airMeng force-pushed the xpu_int4 branch 3 times, most recently from e8357d5 to b3d985d Compare April 9, 2025 05:02
Signed-off-by: Meng, Hengyu <[email protected]>
@airMeng airMeng force-pushed the xpu_int4 branch 3 times, most recently from 4a7bb7b to 075a34a Compare April 9, 2025 07:40
remove zero_point_dtype assigning

Signed-off-by: Meng, Hengyu <[email protected]>

fix import lint

enable zp dtype: u8/s8/s16/s32/s64

Signed-off-by: Meng, Hengyu <[email protected]>
Signed-off-by: Meng, Hengyu <[email protected]>
@@ -676,6 +676,18 @@ def is_sm_at_least_100():
)


def check_cpu_version(device, version="2.6.0"):
Contributor
@jerryzh168 jerryzh168 Apr 9, 2025

nit: I feel a more descriptive name might be better; this one is a bit vague. It's a boolean, so the name should be is_xxx. Also, the version arg seems not to be used at any of the call sites, so it can be removed I think.

e.g. is_cpu_device_and_after_torch_2_6, similar for the xpu check

Collaborator Author
@airMeng airMeng Apr 10, 2025

also version arg seems to be not used in any of the callsite, it can be removed I think

There might be another check, for example checking cpu and torch version 2.8. And there are a lot of these kinds of checks, especially on CUDA: CUDA+2.4, CUDA+2.5, and so on:

CUDA+2.6

not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
CUDA+2.4

Contributor

Checking cuda is fine, but this check is specific to cpu and xpu, right? And it's related to the change of the int4mm operator.

Contributor

checking cpu and torch version 2.8

when do we need this?

Collaborator Author
@airMeng airMeng Apr 10, 2025

if TORCH_VERSION_AT_LEAST_2_6:
assert (
int_data.dtype == torch.int32
), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
int_data,
1, # TODO:remove
)
elif TORCH_VERSION_AT_LEAST_2_5:

Sorry, it should be 2_6 and 2_5. So even for the same device, there should be checks for different versions.
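One way to keep such device+version gates readable is a single parametrized helper, roughly like this (an illustrative sketch, not the torchao utility; names and signatures are assumptions):

```python
def _version_tuple(version: str):
    # "2.6.0" -> (2, 6, 0); ignore any local/build suffix after "+",
    # e.g. "2.5.1+cu121" -> (2, 5, 1).
    return tuple(int(p) for p in version.split("+")[0].split("."))

def check_device_and_version(device_type: str, torch_version: str,
                             expected_device: str, min_version: str) -> bool:
    # True when the tensor lives on `expected_device` and the running torch
    # is at least `min_version` (covers cpu+2.6, xpu+2.7, cuda+2.4, ...).
    return (device_type == expected_device
            and _version_tuple(torch_version) >= _version_tuple(min_version))
```

Tuple comparison keeps "2.10" greater than "2.6", which a plain string comparison would get wrong.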

Contributor

this is getting too complicated I feel, can we just drop the support for some pytorch versions?

Contributor
@jerryzh168 jerryzh168 Apr 10, 2025

It's fine if you want to keep the version number, but it would still be good to change the function name to make it clearer, I think. Also, this discussion doesn't have to block the PR; please feel free to merge and fix later.

Collaborator Author
@airMeng airMeng Apr 10, 2025

this is getting too complicated I feel, can we just drop the support for some pytorch versions?

There are still regression tests in the current CI, going as far back as torch 2.3. But I agree we should only support the latest versions, since AO is an experimental(?) innovation project.

Contributor
@jerryzh168 jerryzh168 Apr 10, 2025

Yeah, we're only committed to supporting the most recent two versions of pytorch, but here I meant dropping support for the CPU layout on older pytorch versions and keeping just one. torchao started as an experimental project, but now it's more official.

@@ -744,8 +748,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (
is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
if (not (check_cpu_version(input.device))) and (
Contributor
@jerryzh168 jerryzh168 Apr 9, 2025

My comment was actually meant to say that we can encapsulate these format changes, e.g. input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8), into some function (together with the device and version checks), so we don't need to repeat them across the codebase; they are pretty error-prone and we are not sure when to use them.

But this can be done in a separate PR.

@airMeng
Collaborator Author

airMeng commented Apr 10, 2025

@jerryzh168 if there are no more comments, I will squash and merge. Sorry, I missed your previous comments; please review again.

@airMeng airMeng merged commit df46e7a into pytorch:main Apr 10, 2025
18 checks passed
@airMeng airMeng deleted the xpu_int4 branch April 10, 2025 04:32
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature

5 participants