Commit afe8b62
Add AffineQuantizedObserver
Summary: In our static_quant flow tutorial we were still using observers from `torch.ao`, which we plan to deprecate. This PR adds a more general observer for `AffineQuantizedTensor` and shows that it can replace the old observers (the min/max observers). Possible future work includes improving perf and adding new kinds of observation, e.g. tracking stats other than just min/max, a moving average observer, or a histogram observer.

Test Plan:
python test/quantization/test_observer.py
python tutorials/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 433cd14 commit afe8b62
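In spirit, the change lets callers swap a `torch.ao` observer for the new torchao one and get equivalent qparams. A minimal before/after sketch (the input shape mirrors the tests below; this snippet itself is not part of the commit):

import torch
from torch.ao.quantization.observer import MinMaxObserver  # old, slated for deprecation
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType

x = torch.randn(10, 2048)

# Before: torch.ao observer
old_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
old_obs(x)

# After: torchao observer with matching settings
new_obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float, zero_point_dtype=torch.int,
)
new_obs(x)

# The tests below assert these match via torch.allclose
print(old_obs.calculate_qparams())
print(new_obs.calculate_qparams())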

File tree

4 files changed: +278 −20 lines

test/quantization/test_observer.py

Lines changed: 39 additions & 0 deletions
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
    PerAxis,
)
from torchao.quantization.quant_primitives import (
    MappingType,
)
import unittest

# NOTE: we can copy-paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
    def _test_obs_helper(self, obs1, obs2):
        # feed the same batches through both observers, then compare their qparams
        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
        for example_input in example_inputs:
            obs1(example_input)
            obs2(example_input)

        scale1, zero_point1 = obs1.calculate_qparams()
        scale2, zero_point2 = obs2.calculate_qparams()
        self.assertTrue(torch.allclose(scale1, scale2))
        self.assertTrue(torch.allclose(zero_point1, zero_point2))

    def test_min_max_per_tensor_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
            granularity_type=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
        )
        ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
        self._test_obs_helper(obs, ref_obs)

    def test_min_max_per_channel_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
            granularity_type=PerAxis(axis=0),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
        )
        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
        self._test_obs_helper(obs, ref_obs)


if __name__ == "__main__":
    unittest.main()

torchao/quantization/observer.py

Lines changed: 154 additions & 0 deletions
import torch
from .quant_primitives import (
    _get_reduction_params,
    choose_qparams_affine_with_min_max,
    MappingType,
    ZeroPointDomain,
)

from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional
from functools import partial


@dataclass(frozen=True)
class GranularityType:
    pass


@dataclass(frozen=True)
class PerTensor(GranularityType):
    pass


@dataclass(frozen=True)
class PerAxis(GranularityType):
    axis: int


# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
    def __init__(self, p):
        self.p = p

    def __call__(self, *args, **keywords):
        return self.p(*args, **keywords)

    def __repr__(self):
        return self.p.__repr__()

    def with_args(self, *args, **kwargs):
        return _with_args(self, *args, **kwargs)


def _with_args(cls_or_self, *args, **kwargs):
    r"""Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Example::

        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
    return r


def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
    if isinstance(granularity_type, PerTensor):
        return input_shape
    elif isinstance(granularity_type, PerAxis):
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    raise ValueError(f"Unsupported GranularityType: {granularity_type}")


class AffineQuantizedObserver(torch.nn.Module):
    with_args = classmethod(_with_args)

    def __init__(self,
        update_stats: Callable[[Callable, torch.Tensor], None],
        calculate_qparams: Callable[[Callable], Tuple[torch.Tensor, torch.Tensor]],
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        block_size: Optional[Tuple[int, ...]] = None,
        granularity_type: Optional[GranularityType] = None,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain = ZeroPointDomain.INT,
    ):
        """Observer for affine quantization: `update_stats` records statistics from
        each observed Tensor, and `calculate_qparams` turns those statistics into
        (scale, zero_point) for the given mapping type, target dtype, and granularity.
        """
        super().__init__()
        assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
        self._update_stats = update_stats
        self._calculate_qparams = calculate_qparams
        self.mapping_type = mapping_type
        self.target_dtype = target_dtype
        self.block_size = block_size
        self.granularity_type = granularity_type
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.eps = eps
        self.scale_dtype = scale_dtype
        self.zero_point_dtype = zero_point_dtype
        self.preserve_zero = preserve_zero
        self.zero_point_domain = zero_point_domain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self._update_stats(self, input)
        return input

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._calculate_qparams(self)


def get_min_max_funcs():

    def update_stats_min_max(self, input: torch.Tensor):
        if input.numel() == 0:
            return

        input = input.detach()
        if self.block_size is None:
            self.block_size = get_block_size(input.shape, self.granularity_type)

        shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input.size())
        input = input.view(shape_for_reduction)
        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
            self.min_val = min_val
            self.max_val = max_val
        else:
            # keep the running min/max across all observed batches
            min_val = torch.min(self.min_val, min_val)
            max_val = torch.max(self.max_val, max_val)
            self.min_val.copy_(min_val)
            self.max_val.copy_(max_val)

    def calculate_qparams_min_max(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer to have min_val and max_val; please run the observer before calling calculate_qparams"

        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            self.block_size,
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain
        )

    return update_stats_min_max, calculate_qparams_min_max


_update_stats_min_max, _calculate_qparams_min_max = get_min_max_funcs()
AffineQuantizedMinMaxObserver = AffineQuantizedObserver.with_args(_update_stats_min_max, _calculate_qparams_min_max)
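
Since `AffineQuantizedMinMaxObserver` is a `_PartialWrapper` factory rather than a class, each call constructs a fresh `AffineQuantizedObserver` with the min/max callbacks pre-bound. A minimal usage sketch (shapes illustrative, mirroring the tests above; not part of this diff):

import torch
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerAxis
from torchao.quantization.quant_primitives import MappingType

obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity_type=PerAxis(axis=0),  # get_block_size gives (1, 128) for a (16, 128) input
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)

for _ in range(3):
    obs(torch.randn(16, 128))  # forward() records the running min/max and returns the input

scale, zero_point = obs.calculate_qparams()
print(scale.shape, zero_point.shape)  # both torch.Size([16]) for per-axis-0 granularity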

torchao/quantization/quant_primitives.py

Lines changed: 70 additions & 14 deletions
@@ -21,6 +21,7 @@
     "safe_int_mm",
     "int_scaled_matmul",
     "choose_qparams_affine",
+    "choose_qparams_affine_with_min_max",
     "quantize_affine",
     "dequantize_affine",
     "fake_quantize_affine",
@@ -570,9 +571,51 @@ def choose_qparams_affine(
         zero_point_domain.name
     )
 
+
+def choose_qparams_affine_with_min_max(
+    min_val: torch.Tensor,
+    max_val: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: Tuple[int, ...],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
+    that takes min_val and max_val directly instead of deriving them from a single input.
+    This is used for observers in static quantization, where min_val and max_val may be
+    obtained by tracking all the data in the calibration data set.
+
+    Args:
+        Mostly the same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`,
+        with one difference: instead of passing in an `input` Tensor and using it to calculate
+        min_val/max_val and then scale/zero_point, we pass in min_val/max_val directly.
+    """
+    return _choose_qparams_affine(
+        None,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name,
+        min_val,
+        max_val,
+    )
+
+
 @register_custom_op
 def _choose_qparams_affine(
-    input: torch.Tensor,
+    input: Optional[torch.Tensor],
     mapping_type: str,
     block_size: List[int],
     target_dtype: torch.dtype,
@@ -583,23 +626,38 @@ def _choose_qparams_affine(
     zero_point_dtype: Optional[torch.dtype] = None,
     preserve_zero: bool = True,
     zero_point_domain: str = "INT",
+    min_val: Optional[torch.Tensor] = None,
+    max_val: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """op definition that has compatible signatures with custom op library
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
-    if scale_dtype is None:
-        scale_dtype = input.dtype
-    if zero_point_dtype is None:
-        zero_point_dtype = input.dtype
+    if input is not None:
+        if scale_dtype is None:
+            scale_dtype = input.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = input.dtype
+        if eps is None:
+            eps = torch.finfo(input.dtype).eps
 
-    assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
-    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
-    input = input.view(shape_for_reduction)
+        assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
+        shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+        input = input.view(shape_for_reduction)
+
+        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    else:
+        assert min_val is not None and max_val is not None, f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
+        assert min_val.dtype == max_val.dtype, f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"
 
-    min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
-    max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+        if scale_dtype is None:
+            scale_dtype = min_val.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = min_val.dtype
+        if eps is None:
+            eps = torch.finfo(min_val.dtype).eps
 
     if preserve_zero:
         min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
@@ -615,10 +673,12 @@ def _choose_qparams_affine(
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
         if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
+        scale = torch.clamp(scale, min=eps)
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.clamp(scale, min=eps)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
@@ -627,8 +687,4 @@ def _choose_qparams_affine(
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
 
-    if eps is None:
-        eps = torch.finfo(input.dtype).eps
-    scale = torch.clamp(scale, min=eps)
-
     return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
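
To make the new primitive concrete, here is a small worked sketch (numbers illustrative; assumes a torchao build that contains this commit):

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_with_min_max,
    MappingType,
)

# Pretend an observer tracked these running stats over a calibration set
# for a (10, 2048) activation tensor, quantized per-tensor.
min_val = torch.tensor(-2.0)
max_val = torch.tensor(6.0)

scale, zero_point = choose_qparams_affine_with_min_max(
    min_val,
    max_val,
    MappingType.ASYMMETRIC,
    block_size=(10, 2048),  # per-tensor: the block covers the whole shape
    target_dtype=torch.uint8,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
# asymmetric uint8 (quant range 0..255): scale = (6 - (-2)) / 255 ≈ 0.0314,
# zero_point = 0 - round(-2 / scale) = 64
print(scale.item(), zero_point.item())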

tutorials/calibration_flow/static_quant.py

Lines changed: 15 additions & 6 deletions
@@ -4,16 +4,21 @@
 import torch
 import copy
 
-# TODO: use the generalized observer for affine qunatization in the future
-from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 import torch.nn.functional as F
 from torch import Tensor
 from torchao.dtypes import to_affine_quantized_static
 from torchao.quantization.utils import compute_error
 from torchao.quantization import quantize_
 from torchao.quantization import to_linear_activation_quantized
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+    PerTensor,
+    PerAxis,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+)
 
 
 class ObservedLinear(torch.nn.Linear):
@@ -105,16 +110,20 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
+torch.manual_seed(0)
+
 dtype = torch.bfloat16
 m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
+
+m_for_test = copy.deepcopy(m)
+
 m_bf16 = copy.deepcopy(m)
 example_inputs = m.example_inputs(dtype=dtype, device="cuda")
 
 m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 
-# TODO: use the generalized observer for affine qunatization in the future
-act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
-weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")
+act_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
+weight_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
 
 before_quant = m(*example_inputs)
