Skip to content

Commit 174cce6

Browse files
committed
Fix slice and padding for TensorCoreTiledLayout for int4 weight only quantization
Summary: Previously some of these code paths were not exercised, so the bugs went undiscovered. There are bugs related to the slice operation and padding: scale and zero_point were not padded, which results in errors when padding is required. Test Plan: python test/dtypes/test_affine_quantized.py -k test_slice Reviewers: Subscribers: Tasks: Tags:
1 parent 0231a68 commit 174cce6

File tree

4 files changed

+59
-17
lines changed

4 files changed

+59
-17
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torchao.core.config import AOBaseConfig
1818
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
1919
from torchao.quantization import (
20+
Int4WeightOnlyConfig,
2021
Int8DynamicActivationInt8WeightConfig,
2122
float8_weight_only,
2223
int4_dynamic_activation_int4_weight,
@@ -307,6 +308,18 @@ def test_alias(self, device, dtype):
307308
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
308309
_ = dummy.weight[...]
309310

311+
@common_utils.parametrize("device", ["cuda"] if torch.cuda.is_available() else [])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_slice(self, device, dtype):
    """Smoke-test slicing of an int4 weight-only quantized linear weight.

    The layer is built with 256 input features (not divisible by 1024) and
    321 output features (not divisible by 8) so that slicing is exercised
    together with padding for int4 weight-only quantization.
    """
    linear = nn.Linear(256, 321, dtype=dtype, device=device)
    quantize_(linear, Int4WeightOnlyConfig())
    quantized_weight = linear.weight
    # Nothing is asserted on the results: the test only verifies that
    # narrowing along each of the two dims runs without raising.
    quantized_weight.narrow(0, 0, 64)
    quantized_weight.narrow(1, 0, 128)
322+
310323

311324
common_utils.instantiate_parametrized_tests(TestAffineQuantized)
312325
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def from_hp_to_intx(
284284
)
285285
# Note: output will be uint8 tensor for sub byte tensors for now
286286

287-
data = _layout.post_process(data)
287+
data, scale, zero_point = _layout.post_process(
288+
data, scale, zero_point, block_size
289+
)
288290
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
289291
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
290292
return cls(

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,30 @@ def pre_process_static(
153153
zero_point = torch.nn.functional.pad(zero_point, padding_changes)
154154
return input, scale, zero_point
155155

156-
def post_process(self, input: torch.Tensor) -> torch.Tensor:
156+
def post_process(
    self,
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pad quantized data together with its scale and zero_point.

    ``input`` is padded at the end of both dims up to the sizes returned by
    ``find_multiple(in_features, 1024)`` and ``find_multiple(out_features, 8)``.
    ``scale`` and ``zero_point`` are padded at the end of the same dims by
    the data pad amount floor-divided by the matching ``block_size`` entry.

    Args:
        input: 2-D quantized data, unpacked as (out_features, in_features).
        scale: scale tensor to pad alongside ``input``.
        zero_point: zero-point tensor to pad alongside ``input``.
        block_size: two-element block size; entry 0 applies to dim 0 and
            entry 1 to dim 1.

    Returns:
        The padded ``(input, scale, zero_point)`` triple.

    Raises:
        AssertionError: if ``block_size`` does not have exactly two entries.
    """
    rows, cols = input.shape
    extra_cols = find_multiple(cols, 1024) - cols
    extra_rows = find_multiple(rows, 8) - rows
    # F.pad takes pad amounts for the last dim first, then the first dim.
    input = torch.nn.functional.pad(input, (0, extra_cols, 0, extra_rows))
    assert (
        len(block_size) == 2
    ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
    qparam_extra_rows = extra_rows // block_size[0]
    qparam_extra_cols = extra_cols // block_size[1]
    # scale and zero_point share one padding spec.
    qparam_pad = (0, qparam_extra_cols, 0, qparam_extra_rows)
    scale = torch.nn.functional.pad(scale, qparam_pad)
    zero_point = torch.nn.functional.pad(zero_point, qparam_pad)
    return input, scale, zero_point
165180

166181
def extra_repr(self):
167182
return f"inner_k_tiles={self.inner_k_tiles}"
@@ -335,31 +350,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
335350

336351
if func is aten.slice.Tensor:
337352
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
338-
if dim == 0:
339-
int_data, scale, zero_point = self.get_plain()
340-
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
341-
# this is to handle padding
342-
int_data = self._layout.post_process(int_data)
343-
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
344-
return return_and_correct_aliasing(func, args, kwargs, sliced)
345-
elif dim == 1:
353+
if dim in [0, 1]:
346354
int_data, scale, zero_point = self.get_plain()
347-
assert step == 1, "Only step == 1 is supported in slicing right now"
348355
data_len = int_data.shape[dim]
349356
scale_len = scale.shape[dim]
350357
ratio = data_len / scale_len
351358
start_scale = int(start / ratio)
352359
end_scale = int(end / ratio)
353360

354361
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
355-
# this is to handle padding
356-
int_data = self._layout.post_process(int_data)
357362
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
358363
zero_point = aten.slice.Tensor(
359364
zero_point, dim, start_scale, end_scale, step
360365
)
366+
# this is to handle padding
367+
int_data, scale, zero_point = self._layout.post_process(
368+
int_data, scale, zero_point, self.block_size
369+
)
361370
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
362-
return sliced
371+
return return_and_correct_aliasing(func, args, kwargs, sliced)
363372
else:
364373
raise NotImplementedError(
365374
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
@@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
371380

372381
__torch_function__ = torch._C._disabled_torch_function_impl
373382

383+
@property
def block_size(self):
    """Return the quantization block size as ``(1, groupsize)``.

    The group size is derived from this tensor's 4-D ``self.shape`` and
    from the scale unpacked out of ``self.scale_and_zero``: the input
    feature count is reconstructed as ``shape[1] * inner_k_tiles * 16``
    with ``inner_k_tiles = shape[-1] * 2``, then divided by
    ``scale.shape[-2]``.
    NOTE(review): presumably ``scale.shape[-2]`` is the number of groups
    along the input dim -- confirm against
    ``unpack_tinygemm_scales_and_zeros``.

    Raises:
        AssertionError: if ``self.shape`` is not 4-dimensional.
    """
    from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

    # Only the scale's shape is needed; the unpacked zero point is unused.
    scale, _ = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
    packed_shape = self.shape
    assert len(packed_shape) == 4
    inner_k_tiles = packed_shape[-1] * 2
    in_features = packed_shape[1] * (inner_k_tiles * 16)
    # Integer floor division instead of a float round trip through int(a / b).
    groupsize = in_features // scale.shape[-2]
    return (1, groupsize)
394+
374395
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
375396
from torchao.quantization.quant_primitives import (
376397
ZeroPointDomain,

torchao/dtypes/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,14 @@ class Layout:
4444
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
4545
return input
4646

47-
def post_process(self, input: torch.Tensor) -> torch.Tensor:
48-
return input
47+
def post_process(
    self,
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Default post-processing hook: return the arguments unchanged.

    Args:
        input: quantized data.
        scale: scale associated with ``input``.
        zero_point: zero point associated with ``input``.
        block_size: accepted as part of the signature; unused here.

    Returns:
        The ``(input, scale, zero_point)`` triple, with each element being
        the very object that was passed in.
    """
    untouched = (input, scale, zero_point)
    return untouched
4955

5056
def pre_process_static(
5157
self,

0 commit comments

Comments
 (0)