
Expose zero_point_domain as arguments #1401


Merged (10 commits, Dec 17, 2024)
18 changes: 15 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -16,15 +16,17 @@
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
)


def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
def get_quantization_functions(
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
@@ -36,6 +38,14 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout())
)
if int4_zp_int:
base_functions.append(
int4_weight_only(
group_size=32,
layout=Int4CPULayout(),
zero_point_domain=ZeroPointDomain.INT,
)
)
else:
base_functions.append(int4_weight_only(group_size=32))

@@ -71,7 +81,9 @@ def test_tensor_core_layout_transpose(self):
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@common_utils.parametrize(
"apply_quant", get_quantization_functions(True, True, "cuda", True)
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
161 changes: 100 additions & 61 deletions test/quantization/test_quant_primitives.py
@@ -57,7 +57,13 @@ def check_idempotent(self, fn, *args, **kwargs):


# Legacy tinygemm ops
def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
def _get_groupwise_affine_qparams(
w,
n_bit=4,
groupsize=128,
dtype=torch.bfloat16,
zero_point_domain=ZeroPointDomain.FLOAT,
):
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
@@ -70,21 +76,25 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
quant_min = 0
quant_max = max_int
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)
if zero_point_domain == ZeroPointDomain.FLOAT:
zeros = min_val + scales * (2 ** (n_bit - 1))
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
else:
zeros = quant_min - torch.round(min_val / scales)
zeros = torch.clamp(zeros, quant_min, quant_max)
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
scales = scales.to(dtype=dtype).reshape(w.shape[0], -1)
return scales, zeros


def _groupwise_affine_quantize_tensor_from_qparams(
w,
scales,
zeros,
n_bit=4,
groupsize=128,
w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
assert n_bit == 4
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
@@ -97,17 +107,28 @@ def _groupwise_affine_quantize_tensor_from_qparams(

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
min_val = zeros - scales * (2 ** (n_bit - 1))
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
else:
w_int4x8 = (
to_quant.div(scales)
.round()
.add(zeros)
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

if TORCH_VERSION_AT_LEAST_2_5:
if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
@@ -121,6 +142,7 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT,
):
assert groupsize > 1
# needed for GPTQ single column dequantize
@@ -133,12 +155,15 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
else:
w_dq = w_int4x8_grouped.sub(zeros).mul(scales).reshape_as(w_int4x8)
return w_dq


@@ -650,10 +675,8 @@ def test_not_preserve_zero_not_supported(self):
def test_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16
)

zero_point_domains = [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
@@ -662,19 +685,27 @@ def test_get_groupwise_affine_qparams(self):
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)
for zero_point_domain in zero_point_domains:
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input,
n_bit=n_bit,
groupsize=128,
dtype=torch.bfloat16,
zero_point_domain=zero_point_domain,
)
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=zero_point_domain == ZeroPointDomain.INT,
zero_point_domain=zero_point_domain,
)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))
@@ -686,14 +717,15 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)

self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))

def test_groupwise_affine_dequantize_tensor_from_qparams(self):
input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
@@ -702,20 +734,27 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (
is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
else:
if zero_point_domain == ZeroPointDomain.INT:
continue
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

13 changes: 13 additions & 0 deletions torchao/quantization/README.md
@@ -202,6 +202,17 @@ We also have a unified quantized tensor subclass that implements how to get a qu
#### Layouts
We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.

### Zero Point Domains
```ZeroPointDomain``` controls the domain, and therefore the data type, of the zero points. ```ZeroPointDomain.NONE``` means no zero point is used (zero_point is None), ```ZeroPointDomain.FLOAT``` keeps the zero point in the floating-point domain, and ```ZeroPointDomain.INT``` keeps it in the integer domain. For how each domain is computed, see [the reference implementation](../../test/quantization/test_quant_primitives.py).
The following support matrix shows which zero point domains each layout supports; it may change as backend support evolves:

| Layout | None (symmetric) | Float | Int |
|--------|------------------|-------|-----|
| TensorCoreTiledLayout | Yes | Yes (default) | No |
| Int4CPULayout | Yes | Yes (default) | No |
| MarlinSparseLayout | No | No | Yes (default) |


### Full Affine Quantization Flow Example
Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul
as an example:
@@ -239,6 +250,8 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
group_size = 32
# only works for torch 2.4+
quantize_(m, int4_weight_only(group_size=group_size))
## If a different zero_point_domain is needed:
# quantize_(m, int4_weight_only(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))

# temporary workaround for tensor subclass + torch.compile
# NOTE: this is only needed for torch versions < 2.5
37 changes: 32 additions & 5 deletions torchao/quantization/quant_api.py
@@ -28,6 +28,7 @@
from torchao.dtypes import (
AffineQuantizedTensor,
Float8Layout,
Int4CPULayout,
MarlinQQQLayout,
MarlinSparseLayout,
PlainLayout,
@@ -110,6 +111,19 @@
"Int8DynActInt4WeightGPTQQuantizer",
]

# update according to the support matrix
LAYOUT_TO_ZERO_POINT_DOMAIN = {
TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
MarlinSparseLayout: [ZeroPointDomain.INT],
Int4CPULayout: [ZeroPointDomain.FLOAT],
}

LAYOUT_TO_PRESERVE_ZEROS = {
TensorCoreTiledLayout: False,
MarlinSparseLayout: True,
Int4CPULayout: False,
}


######
# TO BE DEPRECATED START
@@ -630,7 +644,8 @@ def int8_dynamic_activation_int4_weight(


def int4_weight_only(
group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
group_size=128,
layout=TensorCoreTiledLayout(inner_k_tiles=8),
use_hqq=False,
zero_point_domain=None,
):
"""
Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
Contributor:

nit: please update docs for zero_point_domain before landing

Collaborator (Author):

update in 51a4505

@@ -650,6 +667,7 @@ def int4_weight_only(
size is more fine grained, choices are [256, 128, 64, 32]
`layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
`use_hqq`: whether to use hqq or default quantization mode, default is False
`zero_point_domain`: domain of the zero points; choices are [None (the value is then determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
"""

def apply_int4_weight_only_quant(weight):
@@ -665,17 +683,26 @@ def apply_int4_weight_only_quant(weight):
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

nonlocal zero_point_domain
assert (
type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
if zero_point_domain is None:
# the first value is the default one
zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
else:
assert (
zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"

# Sparse Marlin only supports symmetric quantization.
# NOTE: If we start having lots of layouts that require different configurations,
# we should consider moving this logic somewhere else.
if isinstance(layout, MarlinSparseLayout):
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
assert (
group_size == 128 or group_size == weight.shape[-1]
), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
Expand Down