
Commit f782422

Allow setting eps in FakeQuantizeConfig
**Summary:** Today, we always use `torch.finfo(x.dtype).eps`, where `x` is the value we are trying to quantize, and there is no way for users to configure this. However, users lowering to XNNPACK may wish to use the following combination of dtypes during training for an end-to-end numerical match:

- input activations: bf16
- input activation scales: fp32
- input activation eps: `torch.finfo(torch.float32).eps`
- weight: bf16
- weight scales: bf16
- weight eps: `torch.finfo(torch.bfloat16).eps`

Adding `eps` to `FakeQuantizeConfig` enables such a use case (see the sketch below).

**Test Plan:** TBD
1 parent c9b9adc commit f782422
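
For illustration, here is a minimal sketch of how the new `eps` field could express the XNNPACK-oriented dtype combination described above. The activation config mirrors `_get_8da4w_activation_config` from this commit; the weight dtype (`torch.int4`) and `group_size=32` are illustrative assumptions rather than values taken from this change.

```python
import torch

from torchao.quantization.qat.api import FakeQuantizeConfig

# Input activations: int8 dynamic per-token fake quantization with fp32
# scales/zero points and an explicit fp32 eps.
activation_config = FakeQuantizeConfig(
    dtype=torch.int8,
    granularity="per_token",
    is_symmetric=False,
    is_dynamic=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
    eps=torch.finfo(torch.float32).eps,
)

# Weights: bf16 scales with a matching bf16 eps. The 4-bit dtype and group
# size are placeholders; substitute whatever your torchao/PyTorch build
# supports for sub-byte dtypes.
weight_config = FakeQuantizeConfig(
    dtype=torch.int4,
    scale_precision=torch.bfloat16,
    zero_point_precision=torch.bfloat16,
    eps=torch.finfo(torch.bfloat16).eps,
    group_size=32,
)
```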

5 files changed, +25 -4 lines changed

torchao/quantization/GPTQ.py

Lines changed: 4 additions & 1 deletion
@@ -938,7 +938,10 @@ def linear_forward_8da4w(
     # TODO: in future add ability to specify activation_scale_dtype to PTQ configs
     # and enable similar change here
     x = per_token_dynamic_quant(
-        x, scale_dtype=torch.float32, zero_point_dtype=torch.float32
+        x,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.float32,
+        eps=torch.finfo(torch.float32).eps,
     )

     # TODO: verify and remove following reshape code

torchao/quantization/qat/api.py

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,7 @@ class FakeQuantizeConfig:
     zero_point_domain: ZeroPointDomain
     is_dynamic: bool = True
     range_learning: bool = False
+    eps: Optional[float] = None

     def __init__(
         self,
@@ -96,6 +97,7 @@ def __init__(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         is_dynamic: bool = True,
         range_learning: bool = False,
+        eps: Optional[float] = None,
         *,
         group_size: Optional[int] = None,
         is_symmetric: Optional[bool] = None,
@@ -110,6 +112,7 @@ def __init__(
         self.zero_point_domain = zero_point_domain
         self.is_dynamic = is_dynamic
         self.range_learning = range_learning
+        self.eps = eps

         # Validate dtype
         all_dtypes = [torch.int8, torch.uint8]

torchao/quantization/qat/fake_quantizer.py

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,7 @@ def _per_token_forward(self, x: torch.Tensor):
             target_dtype=self.config.dtype,
             quant_min=qmin,
             quant_max=qmax,
+            eps=self.config.eps,
             scale_dtype=self.config.scale_precision,
             zero_point_dtype=self.config.zero_point_precision,
         )
@@ -117,13 +118,15 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 bit_width,
                 group_size,
                 scale_precision,
+                eps=self.config.eps,
             )
         else:
             (self.scale, self.zero_point) = get_groupwise_affine_qparams(
                 x,
                 bit_width,
                 group_size,
                 scale_precision,
+                eps=self.config.eps,
             )
         self.zero_point = self.zero_point.to(zero_point_precision)

torchao/quantization/qat/linear.py

Lines changed: 7 additions & 1 deletion
@@ -177,6 +177,8 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
         self.scales_precision: torch.dtype = scales_precision
+        # TODO: generalize this
+        self.activation_scales_precision = torch.float32

     def prepare(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
@@ -247,7 +249,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 self._convert_qat_linear_8da4w(child)

     def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
-        return _get_8da4w_activation_config(self.scales_precision)
+        return _get_8da4w_activation_config(self.activation_scales_precision)

     def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
         return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
@@ -280,6 +282,7 @@ def __init__(
     ) -> None:
         # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
         # which is used in PTQ routines
+        # TODO: generalize this
         activation_config = _get_8da4w_activation_config(torch.float32)
         weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
@@ -320,13 +323,16 @@ def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
     """
     Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
     """
+    # TODO: generalize this
+    assert qparams_precision == torch.float32
     return FakeQuantizeConfig(
         dtype=torch.int8,
         granularity="per_token",
         is_symmetric=False,
         is_dynamic=True,
         scale_precision=qparams_precision,
         zero_point_precision=qparams_precision,
+        eps=torch.finfo(qparams_precision).eps,
     )

torchao/quantization/utils.py

Lines changed: 8 additions & 2 deletions
@@ -324,6 +324,7 @@ def get_groupwise_affine_qparams(
     dtype=torch.bfloat16,
     zero_point_domain=ZeroPointDomain.FLOAT,
     preserve_zero=False,
+    eps=None,
 ):
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
@@ -337,7 +338,8 @@
     block_size = (1, groupsize)
     quant_min = 0
     quant_max = 2**n_bit - 1
-    eps = 1e-6
+    if eps is None:
+        eps = 1e-6
     scale_dtype = dtype
     zero_point_dtype = (
         dtype if zero_point_domain != ZeroPointDomain.INT else torch.int32
@@ -529,6 +531,7 @@ def get_group_qparams_symmetric(
     groupsize=128,
     precision=torch.float32,
     mapping_type=MappingType.SYMMETRIC,
+    eps=None,
 ):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
@@ -539,7 +542,8 @@
     assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

     block_size = (1, groupsize)
-    eps = torch.finfo(w.dtype).eps
+    if eps is None:
+        eps = torch.finfo(w.dtype).eps
     ranges = {}
     ranges[1] = (-1, 0)
     # generating ranges for bit 2 to 8
@@ -590,6 +594,7 @@ def per_token_dynamic_quant(
     input: torch.Tensor,
     scale_dtype: torch.dtype = torch.float32,
     zero_point_dtype: torch.dtype = torch.float32,
+    eps: Optional[float] = None,
 ) -> torch.Tensor:
     mapping_type = MappingType.ASYMMETRIC
     block_size = _get_per_token_block_size(input)
@@ -607,6 +612,7 @@
         quant_max,
         scale_dtype=scale_dtype,
         zero_point_dtype=zero_point_dtype,
+        eps=eps,
     )
     q = quantize_affine(
         input,

0 commit comments
