diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index b00d0ad40b..2bc3fce360 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -29,6 +29,9 @@ from torchao.quantization.prototype.qat.fake_quantizer import (
     FakeQuantizer,
 )
+from torchao.quantization.prototype.qat.embedding import (
+    FakeQuantizedEmbedding,
+)
 from torchao.quantization.prototype.qat.linear import (
     FakeQuantizedLinear,
 )
@@ -852,6 +855,40 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = linear_forward_4w(x2, fq_linear.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4")
+    def test_fake_quantized_embedding_4w(self):
+        """
+        Test that we can express int4 per group symmetric weight only fake quantization
+        with `FakeQuantizedEmbedding`.
+        """
+        num_embeddings = 64
+        embedding_dim = 128
+        group_size = 32
+        torch.manual_seed(self.SEED)
+        fq_embedding = FakeQuantizedEmbedding(
+            num_embeddings,
+            embedding_dim,
+            weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
+        )
+
+        def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int4 per group symmetric weight only fake quantization.
+            """
+            (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(4)
+            w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+            return F.embedding(x, w_fq)
+
+        # Compare embedding values
+        torch.manual_seed(self.SEED)
+        x = torch.randint(num_embeddings, (5, 10))
+        x2 = copy.deepcopy(x)
+        fq_out = fq_embedding(x)
+        baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
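For readers skimming the test: the baseline leans on torchao helpers (`get_group_qparams_symmetric`, `_fake_quantize_per_channel_group`). Below is a rough standalone sketch of the same idea, with an illustrative helper name and an assumed scale convention rather than torchao's exact internals; it also omits the straight-through gradient handling a real QAT fake quantizer needs.

```python
import torch

def fake_quantize_per_group_symmetric_sketch(
    w: torch.Tensor, bit_width: int, group_size: int
) -> torch.Tensor:
    """Quantize-dequantize each contiguous group of `group_size` weights."""
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1  # -8, 7 for int4
    grouped = w.reshape(-1, group_size)  # requires numel divisible by group_size
    # Symmetric per-group scale; assumed max_abs / ((qmax - qmin) / 2) convention,
    # with the zero point fixed at 0.
    scales = grouped.abs().amax(dim=1, keepdim=True) / ((qmax - qmin) / 2)
    scales = scales.clamp(min=1e-9)  # guard all-zero groups
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax)
    return (q * scales).reshape(w.shape)  # dequantize back to float
```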
diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/quantization/prototype/qat/embedding.py
index fb605de166..1f471fa490 100644
--- a/torchao/quantization/prototype/qat/embedding.py
+++ b/torchao/quantization/prototype/qat/embedding.py
@@ -14,12 +14,73 @@ from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.quant_primitives import TorchAODType
+from .api import FakeQuantizeConfig
+from .fake_quantizer import FakeQuantizer
 from .utils import (
     _fake_quantize_per_channel_group,
     _get_qmin_qmax,
 )
 
 
+class FakeQuantizedEmbedding(torch.nn.Embedding):
+    """
+    General embedding layer with fake quantized weights.
+
+    Specific target dtypes, granularity, schemes etc. are specified
+    through a separate config for the weights.
+
+    Example usage::
+
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=8,
+            is_symmetric=True,
+        )
+        fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config=weight_config)
+        fq_embedding(torch.LongTensor([3]))
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        weight_config: Optional[FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            *args,
+            **kwargs,
+        )
+        if weight_config is not None:
+            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+        else:
+            self.weight_fake_quantizer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.weight_fake_quantizer is not None:
+            w = self.weight_fake_quantizer(self.weight)
+        else:
+            w = self.weight
+        return F.embedding(
+            x, w, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        )
+
+
 # ======================================
 # |   Embedding int4 weight-only QAT   |
 # ======================================
@@ -40,7 +101,7 @@ def __init__(
         self.bit_width = 4
         self.group_size: int = group_size
         self.scale_precision: torch.dtype = scale_precision
-        self.zero_point_precision: torch.dtype = zero_point_precision,
+        self.zero_point_precision: torch.dtype = zero_point_precision
 
     def prepare(
         self,
@@ -56,9 +117,7 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
 
         def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             new_embedding = Int4WeightOnlyQATEmbedding(
-                group_size=self.group_size,
-
-                # other nn.Embedding args
+                # nn.Embedding args
                 num_embeddings=child.num_embeddings,
                 embedding_dim=child.embedding_dim,
                 padding_idx=child.padding_idx,
@@ -66,6 +125,10 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
                 norm_type=child.norm_type,
                 scale_grad_by_freq=child.scale_grad_by_freq,
                 sparse=child.sparse,
+                # quantization args
+                group_size=self.group_size,
+                scale_precision=self.scale_precision,
+                zero_point_precision=self.zero_point_precision,
                 device=child.weight.device,
             )
             # In distributed training, the model may be instantiated
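A hedged usage sketch of the new module, mirroring the test above; it assumes `FakeQuantizeConfig` and `TorchAODType` are importable from the paths the imports in this diff suggest.

```python
import torch
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.quant_primitives import TorchAODType

# int4 per-group symmetric weight-only fake quantization, as in the test above
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
fq_embedding = FakeQuantizedEmbedding(64, 128, weight_config=weight_config)

# Indices flow through F.embedding unchanged; only the weight is fake quantized
tokens = torch.randint(64, (2, 16))
out = fq_embedding(tokens)  # shape (2, 16, 128), same as nn.Embedding
```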
@@ -98,12 +161,11 @@ def _convert_helper(self, module: torch.nn.Module):
         from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
+                group_size = child.weight_fake_quantizer.config.group_size
+                scale_precision = child.weight_fake_quantizer.config.scale_precision
+                zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
                 quantized_embedding = Int4WeightOnlyEmbedding(
-                    group_size=child.group_size,
-                    scale_precision=child.scale_precision,
-                    zero_point_precision=child.zero_point_precision,
-
-                    # other nn.Embedding args
+                    # nn.Embedding args
                     num_embeddings=child.num_embeddings,
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
@@ -111,15 +173,19 @@ def _convert_helper(self, module: torch.nn.Module):
                     norm_type=child.norm_type,
                     scale_grad_by_freq=child.scale_grad_by_freq,
                     sparse=child.sparse,
+                    # quantization args
+                    group_size=group_size,
+                    scale_precision=scale_precision,
+                    zero_point_precision=zero_point_precision,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)
 
                 # Load weights and qparams into quantized embedding
                 (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
+                (s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
                 q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
+                    child.weight, s, zp, qmin, qmax, torch.int8, group_size,
                 )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scales = s
@@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 self._convert_helper(child)
 
 
-class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
+class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
     """
     This module implements an embedding layer with int4 fake quantized
     grouped per channel weights.
@@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
     def __init__(
         self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
-        self._fake_quant_enabled = True
-
-    def forward(self, x):
-        weight = self.weight
-
-        if self._fake_quant_enabled:
-            (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, self.bit_width, self.group_size, self.scale_precision,
-            )
-            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
-            weight_zp = weight_zp.to(self.zero_point_precision)
-            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
-            w_fq = _fake_quantize_per_channel_group(
-                self.weight,
-                weight_scales,
-                weight_zp,
-                weight_qmin,
-                weight_qmax,
-                self.group_size,
-            )
-        else:
-            w_fq = self.weight
-
-        return F.embedding(
-            x, w_fq, self.padding_idx, self.max_norm,
-            self.norm_type, self.scale_grad_by_freq, self.sparse,
+        weight_config = FakeQuantizeConfig(
+            dtype=TorchAODType.INT4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+            zero_point_precision=zero_point_precision,
+        )
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            weight_config,
+            *args,
+            **kwargs,
         )
 
     def enable_fake_quant(self, enabled: bool = True):
-        self._fake_quant_enabled = enabled
+        self.weight_fake_quantizer.enabled = enabled
 
     def disable_fake_quant(self):
         self.enable_fake_quant(False)
@@ -194,11 +255,6 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
     """
     def __init__(
         self,
-        group_size: int,
-        scale_precision: torch.dtype,
-        zero_point_precision: torch.dtype,
-
-        # nn.Embedding args
        num_embeddings: int,
         embedding_dim: int,
         padding_idx: Optional[int] = None,
@@ -206,13 +262,14 @@ def __init__(
         norm_type: float = 2.0,
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
+        group_size: int = 32,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
         device: torch.device = None,
     ):
         super().__init__()
-        self.bit_width = 4
-        self.group_size = group_size
-        self.scale_precision = scale_precision
-        self.zero_point_precision = zero_point_precision
+
+        # nn.Embedding args
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
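One behavioral note on the rewritten class: `enable_fake_quant` now flips `weight_fake_quantizer.enabled` instead of a private flag. A sketch of the expected round trip, assuming `FakeQuantizer` passes the weight through unchanged when disabled (which is what the toggle implies):

```python
import torch
import torch.nn.functional as F
from torchao.quantization.prototype.qat.embedding import Int4WeightOnlyQATEmbedding

qat_emb = Int4WeightOnlyQATEmbedding(num_embeddings=64, embedding_dim=128, group_size=32)
tokens = torch.randint(64, (4, 8))

qat_emb.disable_fake_quant()   # sets weight_fake_quantizer.enabled = False
plain = qat_emb(tokens)        # should match a plain float embedding lookup
torch.testing.assert_close(plain, F.embedding(tokens, qat_emb.weight))

qat_emb.enable_fake_quant()    # restore fake quantized weights for QAT
```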
@@ -221,6 +278,12 @@ def __init__(
         self.scale_grad_by_freq = scale_grad_by_freq
         self.sparse = sparse
 
+        # quantization args
+        self.bit_width = 4
+        self.group_size = group_size
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
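The truncated hunk above registers the converted module's weights as unpacked int8. For intuition, here is a sketch of how such a layout could be dequantized at lookup time; this forward is an illustration under assumed shapes, not code from the patch:

```python
import torch
import torch.nn.functional as F

def int4_embedding_lookup_sketch(
    indices: torch.Tensor,
    q_weight: torch.Tensor,  # (num_embeddings, embedding_dim) int8, values in [-8, 7]
    scales: torch.Tensor,    # (num_embeddings, embedding_dim // group_size)
    zeros: torch.Tensor,     # same shape as scales; all zeros for symmetric quant
    group_size: int,
) -> torch.Tensor:
    n, d = q_weight.shape
    grouped = q_weight.to(scales.dtype).reshape(n, d // group_size, group_size)
    # Dequantize per group: (q - zero_point) * scale, then restore the embedding dim
    w = ((grouped - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)).reshape(n, d)
    return F.embedding(indices, w)
```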