Commit 4e4dd12

Add cachemask variant for fake_quantize_affine
Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
grad = grad * mask
```

The existing `fake_quantize_affine` returns only the dequantized values, so callers do not have access to this mask. This commit adds a variant of that op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
1 parent 1029df3 commit 4e4dd12
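For context, here is a minimal sketch (not part of this commit) of how a caller might consume the returned mask to filter gradients in a QAT straight-through estimator. The `_FakeQuantizeSTE` wrapper below is hypothetical, not an op defined by torchao:

```python
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine_cachemask

class _FakeQuantizeSTE(torch.autograd.Function):
    """Straight-through estimator that drops gradients for outlier values."""

    @staticmethod
    def forward(ctx, input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max):
        fq, mask = fake_quantize_affine_cachemask(
            input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max,
        )
        ctx.save_for_backward(mask)
        return fq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Zero out gradients wherever the intermediate quantized value fell
        # outside [quant_min, quant_max]; only `input` receives a gradient.
        return grad_output * mask, None, None, None, None, None, None
```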

2 files changed: 107 additions & 1 deletion
test/quantization/test_quant_primitives.py

Lines changed: 24 additions & 0 deletions
```
@@ -10,6 +10,7 @@
 import torch
 from torchao.quantization.quant_primitives import (
     fake_quantize_affine,
+    fake_quantize_affine_cachemask,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -523,5 +524,28 @@ def test_fake_quantize_affine(self):
         fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
         torch.testing.assert_close(dequantized, fake_quantized)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantize_affine_cachemask(self):
+        input = torch.randn(10, 10)
+
+        mapping_type = MappingType.SYMMETRIC
+        block_size = list(input.shape)
+        for i in range(len(block_size) - 1):
+            block_size[i] = 1
+        dtype = torch.int8
+        eps = 1e-5
+        quant_min = -127
+        quant_max = 127
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        (fake_quantized, mask) = fake_quantize_affine_cachemask(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max,
+        )
+        expected_mask = torch.full(input.shape, True)
+        torch.testing.assert_close(dequantized, fake_quantized)
+        torch.testing.assert_close(expected_mask, mask)
+
 if __name__ == "__main__":
     unittest.main()
```

torchao/quantization/quant_primitives.py

Lines changed: 83 additions & 1 deletion
```
@@ -24,6 +24,7 @@
     "quantize_affine",
     "dequantize_affine",
     "fake_quantize_affine",
+    "fake_quantize_affine_cachemask",
 ]

 class MappingType(Enum):
@@ -411,6 +412,87 @@ def fake_quantize_affine(
         value during quantization
         default is ZeroPointDomain.INT
     """
+    (_, fq) = _do_fake_quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    )
+    return fq
+
+
+def fake_quantize_affine_cachemask(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    General fake quantize op for quantization-aware training (QAT).
+    This is equivalent to calling `quantize_affine` + `dequantize_affine`
+    but without the dtype casts.
+
+    Note: Compared to `fake_quantize_affine`, this consumes more memory and
+    returns an additional outlier mask for intermediate quantized values.
+
+    Returns:
+        A 2-tuple of (
+            final fake quantized values,
+            outlier mask for intermediate quantized values
+        )
+
+    Args:
+        input (torch.Tensor): original float32, float16 or bfloat16 Tensor
+        block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that share the same qparam
+            e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
+        scale (float): quantization parameter for affine quantization
+        zero_point (int): quantization parameter for affine quantization
+        quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
+        quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
+        quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+            if zero_point is in integer domain, zero point is added to the quantized integer value during
+            quantization
+            if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+            value during quantization
+            default is ZeroPointDomain.INT
+    """
+    (q, dq) = _do_fake_quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    )
+    mask = torch.logical_and((q >= quant_min), (q <= quant_max))
+    return (dq, mask)
+
+
+def _do_fake_quantize_affine(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function for `fake_quantize_affine` that returns both the
+    intermediate quantized values and the final dequantized values.
+    """
     input_dtype = input.dtype
     quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
     q = _quantize_affine_no_dtype_cast(
@@ -432,7 +514,7 @@ def fake_quantize_affine(
         zero_point_domain.name,
         output_dtype=input_dtype,
     )
-    return dq
+    return (q, dq)


 def choose_qparams_affine(
```
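Since both public ops delegate to `_do_fake_quantize_affine` with the same arguments, their fake quantized outputs are identical; the cachemask variant only adds the mask. A small sanity sketch under that assumption (the inputs and qparams here are illustrative, not taken from the commit):

```python
import torch
from torchao.quantization.quant_primitives import (
    fake_quantize_affine,
    fake_quantize_affine_cachemask,
)

input = torch.randn(10, 10)
block_size = (10, 10)  # per-tensor: one qparam shared across the whole tensor
scale = torch.tensor(0.02)
zero_point = torch.tensor(0)

fq = fake_quantize_affine(input, block_size, scale, zero_point, torch.int8)
fq_cached, mask = fake_quantize_affine_cachemask(
    input, block_size, scale, zero_point, torch.int8
)
torch.testing.assert_close(fq, fq_cached)  # same fake quantized values
assert mask.shape == input.shape and mask.dtype == torch.bool
```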
