Merged
59 changes: 42 additions & 17 deletions test/quantization/test_qat.py
@@ -12,6 +12,7 @@

import torch
import torch.nn.functional as F
from parameterized import parameterized
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

from torchao import quantize_
@@ -40,7 +41,6 @@
Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
@@ -53,12 +53,16 @@
MappingType,
TorchAODType,
ZeroPointDomain,
choose_qparams_affine,
dequantize_affine,
fake_quantize_affine,
quantize_affine,
)
from torchao.quantization.unified import (
TwoStepQuantizer,
)
from torchao.quantization.utils import (
_get_per_token_block_size,
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
@@ -134,12 +138,13 @@ def forward(self, x):


class M4(torch.nn.Module):
def __init__(self):
def __init__(self, dtype: torch.dtype = torch.float32):
super().__init__()
self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
self.dtype = dtype
self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

def example_inputs(self):
return (torch.randn(1, 512).to(torch.float),)
return (torch.randn(1, 512).to(self.dtype),)

def forward(self, x):
return self.linear(x)
@@ -219,30 +224,41 @@ def test_fake_quantize_per_token(self):
torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
x2 = copy.deepcopy(x)
# TODO: use torch.ops.aten.quantized_decomposed version instead
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
block_size = _get_per_token_block_size(x)
(s, zp) = choose_qparams_affine(
x,
mapping_type=MappingType.ASYMMETRIC,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=torch.float32,
zero_point_dtype=torch.int32,
)

# fake quant op
out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
out.sum().backward()

# compare against PTQ ops
out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
out_ptq = quantize_affine(
x2,
block_size,
s,
zp,
torch.int8,
qmin,
qmax,
torch.int8,
)
out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
out_ptq = dequantize_affine(
out_ptq,
block_size,
s,
zp,
torch.int8,
qmin,
qmax,
torch.int8,
torch.float32,
output_dtype=torch.float32,
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
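Note: for context, the per-token path this test now exercises can be written out directly against the public affine primitives. A minimal sketch, assuming the imports added in this diff (torchao.quantization.quant_primitives and torchao.quantization.utils); keyword handling may vary slightly across torchao versions:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)
from torchao.quantization.utils import _get_per_token_block_size

x = torch.randn(100, 256)
# Per-token block size: 1 along every leading dim and the full size of the
# last dim, so each token (row of the last dim) gets its own scale/zero point.
block_size = _get_per_token_block_size(x)
scale, zero_point = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8, -128, 127)
xdq = dequantize_affine(
    xq, block_size, scale, zero_point, torch.int8, -128, 127,
    output_dtype=torch.float32,
)

This is the same quantize/dequantize sequence the updated test compares _fake_quantize_per_token against.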

@@ -1004,8 +1020,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
"""
# activations
(s, zp) = _choose_qparams_per_token_asymmetric(
x, torch.float32, torch.int32
(s, zp) = choose_qparams_affine(
x,
mapping_type=MappingType.ASYMMETRIC,
block_size=_get_per_token_block_size(x),
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=torch.float32,
zero_point_dtype=torch.int32,
)
(qmin, qmax) = _get_qmin_qmax(8)
x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -1427,10 +1450,11 @@ def test_qat_linear_bias(self):
example_inputs = m.example_inputs()
m(*example_inputs)

@parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_fake_quantize_per_token_vs_convert(self):
def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
"""
Test that the following produce the exact same numerics:
1. FakeQuantizer with asymmetric per_token config
@@ -1439,17 +1463,18 @@ def test_fake_quantize_per_token_vs_convert(self):
from torchao.quantization.utils import per_token_dynamic_quant

torch.manual_seed(self.SEED)
x = torch.randn(1, 235, 2048)
x = torch.randn(1, 235, 2048).to(dtype)
config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
fake_quantizer = FakeQuantizer(config)
fake_quantizer_out = fake_quantizer(x)
baseline_out = per_token_dynamic_quant(x)
torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

@parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_qat_8da4w_prepare_vs_convert(self):
def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
"""
Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
numerics that match exactly over N trials.
@@ -1463,7 +1488,7 @@

for seed in range(self.SEED, self.SEED + num_trials):
torch.manual_seed(seed)
m = M4()
m = M4(dtype)
torch.manual_seed(seed)
x = m.example_inputs()

19 changes: 14 additions & 5 deletions torchao/quantization/qat/fake_quantizer.py
@@ -16,8 +16,11 @@
from torchao.quantization.quant_primitives import (
_DTYPE_TO_BIT_WIDTH,
_DTYPE_TO_QVALUE_BOUNDS,
MappingType,
choose_qparams_affine,
)
from torchao.quantization.utils import (
_get_per_token_block_size,
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
)
@@ -26,7 +29,6 @@
FakeQuantizeConfig,
)
from .utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
)
@@ -69,13 +71,19 @@ def _per_token_forward(self, x: torch.Tensor):
"""
if self.config.is_symmetric:
raise NotImplementedError("Symmetric per token is not supported yet")

qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
if self._should_compute_qparams():
(self.scale, self.zero_point) = _choose_qparams_per_token_asymmetric(
self.scale, self.zero_point = choose_qparams_affine(
x,
self.config.scale_precision,
self.config.zero_point_precision,
mapping_type=MappingType.ASYMMETRIC,
block_size=_get_per_token_block_size(x),
target_dtype=self.config.dtype,
quant_min=qmin,
quant_max=qmax,
scale_dtype=self.config.scale_precision,
zero_point_dtype=self.config.zero_point_precision,
)
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)

def _per_channel_or_group_forward(self, x: torch.Tensor):
@@ -100,6 +108,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
raise ValueError("Unexpected granularity '%s'" % granularity)

# get scales and zero points
# TODO: refactor this to use `choose_qparams_affine`
if self._should_compute_qparams():
bit_width = _DTYPE_TO_BIT_WIDTH[self.config.dtype]
if is_symmetric:
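Note: the module-level behavior of the fake quantizer is unchanged by this refactor; only the per-token qparam computation is routed through choose_qparams_affine, and the input is no longer upcast to float32 inside _fake_quantize_per_token (see the torchao/quantization/qat/utils.py change below). A usage sketch in a reduced-precision dtype, mirroring the test above; the import locations are assumptions and may differ between torchao versions:

import torch
# Assumed import paths; adjust for your torchao version.
from torchao.quantization.qat.api import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
fake_quantizer = FakeQuantizer(config)

# The fake-quant math now runs in the input dtype rather than in an internal
# float32 copy, which is what lets the bf16/fp16 prepare-vs-convert tests
# above compare numerics exactly.
x = torch.randn(1, 235, 2048, dtype=torch.bfloat16)
y = fake_quantizer(x)
assert y.dtype == torch.bfloat16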
54 changes: 2 additions & 52 deletions torchao/quantization/qat/utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple
from typing import List

import torch

@@ -126,9 +126,8 @@ def _fake_quantize_per_token(

_per_token_quant_qparam_dim_check(input, scales, zero_points)
block_size = _get_per_token_block_size(input)
fq_input = input.to(torch.float32)
fq = _GenericFakeQuantize.apply(
fq_input,
input,
block_size,
scales,
zero_points,
@@ -138,55 +137,6 @@ def _fake_quantize_per_token(
return fq.reshape_as(input).to(input.dtype)


# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
# The version in pytorch does not have backward support yet so we add
# it here for now until https://github.com/pytorch/pytorch/pull/123452
# is landed.
def _choose_qparams_per_token_asymmetric(
input: torch.Tensor,
scales_precision: torch.dtype = torch.float32,
zero_points_precision: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)

Args:
input (torch.Tensor): original float32/float16 Tensor
scales_precision (torch.dtype): precision of returned scales
zero_points_precision (torch.dtype): precision of returned zero points

Returns:
scales and zero_points, both float32 Tensors
"""
# Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
qmin, qmax = -128, 127
min_val = torch.amin(input, dim=-1, keepdim=True)
max_val = torch.amax(input, dim=-1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
eps = torch.finfo(torch.float32).eps # use xnnpack eps?

# scale
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = scale.clamp(min=eps)

# zero point
descaled_min = min_val_neg / scale
descaled_max = max_val_pos / scale
zero_point_from_min_error = qmin + descaled_min
zero_point_from_max_error = qmax + descaled_max
zero_point = torch.where(
zero_point_from_min_error + zero_point_from_max_error > 0,
qmin - descaled_min,
qmax - descaled_max,
)
zero_point = torch.clamp(zero_point, qmin, qmax).round()

return scale.to(scales_precision), zero_point.to(zero_points_precision)


def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
if symmetric:
qmin = -(2 ** (n_bit - 1))
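Note: in condensed form, the deleted _choose_qparams_per_token_asymmetric helper implemented the XNNPACK-style asymmetric recipe per token. A summary sketch of the removed math, restated from the deleted lines above (the helper name below is hypothetical, int8 bounds were hard-coded, and the casts to the caller-supplied scale/zero-point dtypes are omitted):

import torch

def per_token_asymmetric_qparams(x: torch.Tensor):
    qmin, qmax = -128, 127
    # Stretch the per-token range to include zero.
    min_val = torch.amin(x, dim=-1, keepdim=True).clamp(max=0)
    max_val = torch.amax(x, dim=-1, keepdim=True).clamp(min=0)
    # Floor the scale at float32 eps so it never collapses to zero.
    scale = (max_val - min_val) / float(qmax - qmin)
    scale = scale.clamp(min=torch.finfo(torch.float32).eps)
    # Pick the zero point from the min or max side based on the combined
    # descaled error, then clamp and round to the integer range.
    zero_point = torch.where(
        (qmin + min_val / scale) + (qmax + max_val / scale) > 0,
        qmin - min_val / scale,
        qmax - max_val / scale,
    ).clamp(qmin, qmax).round()
    return scale, zero_point

The in-tree call sites now pass MappingType.ASYMMETRIC with quant_min=-128 and quant_max=127 to choose_qparams_affine instead, as shown in the hunks above.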
2 changes: 1 addition & 1 deletion torchao/quantization/utils.py
@@ -492,7 +492,7 @@ def get_group_qparams_symmetric(
assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

block_size = (1, groupsize)
eps = torch.finfo(torch.float32).eps
eps = torch.finfo(w.dtype).eps
ranges = {}
ranges[1] = (-1, 0)
# generating ranges for bit 2 to 8
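Note: the one-line eps change in get_group_qparams_symmetric matters because the machine epsilon used to floor the scale differs by orders of magnitude across dtypes; flooring a bf16 or fp16 weight's scale with the float32 eps can make the QAT (prepare) numerics disagree with the converted quantized model in the exact-match tests added above. For reference, the torch.finfo eps values involved:

import torch

# Scale floor now tracks the weight dtype instead of always using float32 eps.
print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
print(torch.finfo(torch.float16).eps)   # ~9.77e-04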