Commit ecfb650

Add generic fake quantized linear for QAT
Summary: This commit adds a generic fake quantized linear module to replace the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we had to create a new linear class every time we wished to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express this easily with a single linear module.

Test Plan:

```
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```

ghstack-source-id: 07c17d3
Pull Request resolved: #1020
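
To illustrate the flexibility described in the summary, here is a minimal usage sketch (not part of the commit). It reuses the import paths and constructor call from the example above and assumes `FakeQuantizedLinear` behaves like a regular `torch.nn.Linear` with fake quantization applied in `forward`; swapping the weight group size is then just a config change rather than a new class.

```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 per-token asymmetric activations + int4 per-group weights (group size 8),
# exactly as in the summary above.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)

x = torch.randn(2, 16)
y = fq_linear(x)  # assumed nn.Linear-like forward with fake quantization applied
assert y.shape == (2, 32)

# Experimenting with a different weight group size requires no new class,
# only a different FakeQuantizeConfig:
coarser_weight_config = FakeQuantizeConfig(torch.int4, group_size=16)
fq_linear_coarse = FakeQuantizedLinear(16, 32, False, activation_config, coarser_weight_config)
```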
1 parent d4b2f33 commit ecfb650

File tree

10 files changed, +903 -211 lines changed


test/integration/test_integration.py (+2)

```diff
@@ -328,6 +328,7 @@ def test_swap(self):
         assert torch.allclose(y_ref, y)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_weight_t_and_non_t_numerics_match(self):
         # verify that numerics match whether weight is stored
         # in transposed format (for cuBLAS) vs non-transposed format
@@ -1126,6 +1127,7 @@ def test_shape_logger(self):
 class SmoothquantIntegrationTest(unittest.TestCase):
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_non_dynamically_quantizable_linear(self):
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
             self.skipTest("test requires SM capability of at least (8, 0).")
```

test/quantization/test_qat.py (+293 -43)

Large diffs are not rendered by default.
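
The new QAT tests themselves are not rendered here. As a rough sketch only (not the actual test code from this commit), the following illustrates the kind of `FakeQuantizeConfig` parsing behavior the granularity tests in the test plan exercise, using only the API visible in the `api.py` diff below; the test class name is made up.

```python
import unittest

import torch

from torchao.quantization.granularity import PerAxis, PerGroup, PerToken
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig


class FakeQuantizeConfigGranularitySketch(unittest.TestCase):
    """Illustrative checks only; the real tests live in test/quantization/test_qat.py."""

    def test_granularity_parsing(self):
        # Strings and the group_size kwarg resolve to Granularity objects.
        self.assertIsInstance(FakeQuantizeConfig(torch.int8, "per_token").granularity, PerToken)
        self.assertIsInstance(FakeQuantizeConfig(torch.int4, "per_channel").granularity, PerAxis)
        config = FakeQuantizeConfig(torch.int4, group_size=32)
        self.assertIsInstance(config.granularity, PerGroup)
        self.assertEqual(config.group_size, 32)

    def test_granularity_error_cases(self):
        # Missing or conflicting granularity settings raise ValueError.
        with self.assertRaises(ValueError):
            FakeQuantizeConfig(torch.int8)  # neither granularity nor group_size
        with self.assertRaises(ValueError):
            FakeQuantizeConfig(torch.int8, "per_token", group_size=32)  # conflicting settings


if __name__ == "__main__":
    unittest.main()
```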

torchao/quantization/README.md (+2 -3)

````diff
@@ -136,8 +136,7 @@ change_linear_weights_to_int8_dqtensors(model)
 
 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.quant_api import PerTensor
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
 
@@ -321,7 +320,7 @@ This API works today but has not been extensively tested and benchmarked yet. Ha
 
 ```python
 # for torch 2.5+
-from torchao.quantization.quant_api import quantize_, PerRow, float8_dynamic_activation_float8_weight
+from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
 ```
````

torchao/quantization/__init__.py (+3 -2)

```diff
@@ -12,11 +12,12 @@
 from .weight_only import * # noqa: F403
 from .unified import *
 from .autoquant import *
-from .linear_activation_quantized_tensor import ( # noqat: F403
+from .granularity import *
+from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .linear_activation_scale import ( # noqat: F403
+from .linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
 
```

torchao/quantization/granularity.py (+19 -5)

```diff
@@ -22,7 +22,7 @@ class PerTensor(Granularity):
     """
     Represents per-tensor granularity in quantization.
 
-    This granularity type calcualtes the quantization parameters
+    This granularity type calculates the quantization parameters
     based off the entire tensor.
     """
     pass
@@ -32,26 +32,24 @@ class PerAxis(Granularity):
     """
     Represents per-axis granularity in quantization.
 
-    This granularity type calcualtes different quantization parameters
+    This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
     For example if the input tensor is shape [8, 16] and axis=0, then
     the quantization parameters are calculated for each row of the tensor.
     Giving a total of 8 quantization parameters.
 
-
     Attributes:
         axis (int): The axis along which reduction is performed.
     """
     axis: int
 
 @dataclass(frozen=True)
-
 class PerGroup(Granularity):
     """
     Represents per-channel group granularity in quantization.
 
-    This granularity type calcualtes different quantization parameters
+    This granularity type calculates different quantization parameters
     for each group of <group_size> elements.
 
     For example if the input tensor is shape [8, 16], and the group size is 4, then
@@ -74,3 +72,19 @@ class PerRow(Granularity):
     is quantized with a block_size of (1, weight.shape[1]).
     """
     pass
+
+class PerToken(Granularity):
+    """
+    Represents per-token granularity in quantization.
+
+    This granularity type calculates a different set of quantization parameters
+    for each token, which is represented as the last dimension of the tensor.
+
+    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
+    with 4 elements each, and we will calculate 6 sets of quantization parameters,
+    one for each token.
+
+    If the input tensor has only two dimensions, e.g. [8, 16], then this is
+    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
+    """
+    pass
```
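
As a quick reference for the granularities above, here is a small sketch (not from the diff) constructing each one; the parameter counts in the comments restate the docstring examples, and `PerGroup`'s `group_size` field name follows its use in the `api.py` diff below.

```python
from torchao.quantization.granularity import PerAxis, PerGroup, PerToken

# PerAxis(axis=0): one set of quantization parameters per row,
# e.g. 8 sets for an [8, 16] tensor.
per_channel = PerAxis(axis=0)

# PerGroup(group_size=4): one set per group of 4 elements,
# e.g. 8 * 16 / 4 = 32 sets for an [8, 16] tensor.
per_group = PerGroup(group_size=4)
assert per_group.group_size == 4

# PerToken(): one set per token, where the last dimension holds the token's elements.
# A [2, 3, 4] tensor has 6 tokens (6 sets); for a 2-D [8, 16] tensor this matches
# PerAxis(axis=0), i.e. 8 sets.
per_token = PerToken()
```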

torchao/quantization/prototype/qat/api.py (+214 -1)

```diff
@@ -4,11 +4,224 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.quantization.granularity import (
+    Granularity,
+    PerAxis,
+    PerGroup,
+    PerToken,
+)
 from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.quant_primitives import (
+    _SUB_BYTE_INT_BOUNDS,
+    _SUB_BYTE_UINT_BOUNDS,
+    MappingType,
+    TorchAODType,
+    ZeroPointDomain,
+)
+
+
+@dataclass
+class FakeQuantizeConfig:
+    """
+    Config for how to fake quantize weights or activations.
+
+    args:
+        dtype: dtype to simulate during fake quantization, e.g. torch.int8.
+            For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
+            torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
+        granularity: granularity of scales and zero points, e.g. PerGroup(32).
+            We also support the following strings:
+                1) 'per_token': equivalent to PerToken()
+                2) 'per_channel': equivalent to PerAxis(0)
+                3) 'per_group': equivalent to PerGroup(group_size), must be combined
+                    with separate `group_size` kwarg. Alternatively, just set the
+                    `group_size` kwarg and leave this field empty.
+        mapping_type: whether to use symmetric (default) or asymmetric quantization.
+            Alternatively, set `is_symmetric` (bool) and leave this field empty.
+        scale_precision: scale dtype (default torch.fp32)
+        zero_point_precision: zero point dtype (default torch.int32)
+        zero_point_domain: whether zero point is in integer (default) or float domain
+        is_dynamic: whether to use dynamic (default) or static scale and zero points
+        range_learning: whether to learn scale and zero points during training (coming soon)
+
+    kwargs (optional):
+        group_size: size of each group in per group fake quantization,
+            can be set instead of `granularity`
+        is_symmetric: whether to use symmetric or asymmetric quantization,
+            can be set instead of `mapping_type`
+
+    Example usage::
+
+        # Per token asymmetric quantization
+        FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
+
+        # Per channel symmetric quantization
+        FakeQuantizeConfig(torch.int4, "per_channel")
+        FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
+
+        # Per group symmetric quantization
+        FakeQuantizeConfig(torch.int4, group_size=32)
+        FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
+    """
+    dtype: Union[torch.dtype, TorchAODType]
+    granularity: Granularity
+    mapping_type: MappingType
+    scale_precision: torch.dtype
+    zero_point_precision: torch.dtype
+    zero_point_domain: ZeroPointDomain
+    is_dynamic: bool = True
+    range_learning: bool = False
+
+    def __init__(
+        self,
+        dtype: Union[torch.dtype, TorchAODType],
+        granularity: Union[Granularity, str, None] = None,
+        mapping_type: Optional[MappingType] = None,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        is_dynamic: bool = True,
+        range_learning: bool = False,
+        *,
+        group_size: Optional[int] = None,
+        is_symmetric: Optional[bool] = None,
+    ):
+        self.dtype = dtype
+        self.granularity = self._get_granularity(granularity, group_size)
+        self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+        self.zero_point_domain = zero_point_domain
+        self.is_dynamic = is_dynamic
+        self.range_learning = range_learning
+
+        # Validate dtype
+        all_dtypes = [torch.int8, torch.uint8]
+        all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
+        all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
+        if dtype not in all_dtypes:
+            raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))
+
+    def _get_granularity(
+        self,
+        granularity: Union[Granularity, str, None],
+        group_size: Optional[int],
+    ) -> Granularity:
+        """
+        Parse the `Granularity` represented in the args.
+
+        Granularity can be specified in one of three ways:
+            1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
+            2) str: one of 'per_token', 'per_channel', and 'per_group'
+            3) None: `group_size` must be set instead, represents per group granularity
+        """
+        # If group_size is set, then granularity must be either "per_group" or None
+        if group_size is not None and granularity != "per_group" and granularity is not None:
+            raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)
+
+        # Case 1: Granularity object
+        if isinstance(granularity, Granularity):
+            if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
+                raise ValueError("Granularity '%s' is not supported" % granularity)
+            if isinstance(granularity, PerAxis) and granularity.axis != 0:
+                raise ValueError("Only axis=0 is supported for PerAxis granularity")
+            return granularity
+
+        # Case 2: str granularity
+        if granularity == "per_token":
+            return PerToken()
+        elif granularity == "per_channel":
+            return PerAxis(axis=0)
+        elif granularity == "per_group":
+            if group_size is None:
+                raise ValueError("Granularity was 'per_group' but no `group_size` was set")
+            return PerGroup(group_size)
+        elif isinstance(granularity, str):
+            raise ValueError(
+                "Unexpected granularity: '%s', must be one of %s" %
+                (granularity, ["per_token", "per_channel", "per_group"])
+            )
+
+        # Case 3: None granularity + group_size was specified
+        if granularity is not None:
+            raise ValueError(
+                "Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
+            )
+        if group_size is None:
+            raise ValueError("At least one of `granularity` or `group_size` must be set")
+        return PerGroup(group_size)
+
+    def _get_mapping_type(
+        self,
+        mapping_type: Optional[MappingType],
+        is_symmetric: Optional[bool],
+    ) -> MappingType:
+        """
+        Parse the `MappingType` represented in the args.
+
+        Mapping type can be specified in one of two ways:
+            1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC
+            2): is_symmetric bool
+        """
+        if mapping_type is not None and is_symmetric is not None:
+            raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")
+
+        # Case 0: Default to symmetric
+        if mapping_type is None and is_symmetric is None:
+            return MappingType.SYMMETRIC
+
+        # Case 1: MappingType object
+        if mapping_type is not None:
+            if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
+                raise ValueError("MappingType '%s' is not supported" % mapping_type)
+            return mapping_type
+
+        # Case 2: is_symmetric flag
+        assert is_symmetric is not None
+        if is_symmetric:
+            return MappingType.SYMMETRIC
+        else:
+            return MappingType.ASYMMETRIC
+
+    @property
+    def group_size(self) -> int:
+        """
+        If this is per group granularity, return the group size.
+        Otherwise, throw an error.
+        """
+        if isinstance(self.granularity, PerGroup):
+            return self.granularity.group_size
+        else:
+            raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)
+
+    @property
+    def is_symmetric(self) -> bool:
+        """
+        Return True if mapping type is symmetric, else False (asymmetric).
+        """
+        return self.mapping_type == MappingType.SYMMETRIC
+
+    def __setattr__(self, name: str, value: Any):
+        """
+        Support setting `group_size` and `is_symmetric`.
+        """
+        if name == "group_size":
+            super().__setattr__("granularity", PerGroup(value))
+        elif name == "is_symmetric":
+            mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
+            super().__setattr__("mapping_type", mapping_type)
+        else:
+            super().__setattr__(name, value)
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
```
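
To make the convenience kwargs and the `__setattr__` overrides concrete, here is a short usage sketch (not part of the commit) against the `FakeQuantizeConfig` definition above; the assertions restate behavior visible in the diff.

```python
import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.quant_primitives import MappingType

# Construct via the convenience kwargs instead of Granularity/MappingType objects.
config = FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
assert isinstance(config.granularity, PerGroup)
assert config.group_size == 32                        # property reads PerGroup.group_size
assert config.mapping_type == MappingType.SYMMETRIC   # derived from is_symmetric=True
assert config.is_symmetric

# __setattr__ maps the convenience names back onto the underlying fields.
config.group_size = 64
config.is_symmetric = False
assert config.granularity == PerGroup(64)
assert config.mapping_type == MappingType.ASYMMETRIC
```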
