
Add generic fake quantized embedding for QAT #1085


Merged · 1 commit · Oct 16, 2024
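
This PR adds a generic `FakeQuantizedEmbedding` module (analogous to the existing `FakeQuantizedLinear`) and reworks the int4 weight-only QAT embedding classes on top of it. As a rough usage sketch, based only on the constructor and the int4 per-group symmetric config exercised by the new test (the prototype import paths are taken from the diff and may change as the API stabilizes):

import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.quant_primitives import TorchAODType

# Int4 per-group symmetric weight-only fake quantization, matching the new test.
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)

# Embedding whose weight is fake quantized on every forward pass.
fq_embedding = FakeQuantizedEmbedding(
    num_embeddings=64,
    embedding_dim=128,
    weight_config=weight_config,
)

# Token indices in, fake quantized embedding vectors out.
tokens = torch.randint(64, (5, 10))
out = fq_embedding(tokens)  # shape: (5, 10, 128)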
37 changes: 37 additions & 0 deletions test/quantization/test_qat.py
@@ -29,6 +29,9 @@
from torchao.quantization.prototype.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.prototype.qat.embedding import (
FakeQuantizedEmbedding,
)
from torchao.quantization.prototype.qat.linear import (
FakeQuantizedLinear,
)
@@ -852,6 +855,40 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
def test_fake_quantized_embedding_4w(self):
"""
Test that we can express int4 per group symmetric weight only fake quantization
with `FakeQuantizedEmbedding`.
"""
num_embeddings = 64
embedding_dim = 128
group_size = 32
torch.manual_seed(self.SEED)
fq_embedding = FakeQuantizedEmbedding(
num_embeddings,
embedding_dim,
weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size),
)

def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Baseline for int4 per group symmetric weight only fake quantization.
"""
(s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
zp = zp.to(torch.int32)
(qmin, qmax) = _get_qmin_qmax(4)
w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
return F.embedding(x, w_fq)

# Compare embedding values
torch.manual_seed(self.SEED)
x = torch.randint(num_embeddings, (5, 10))
x2 = copy.deepcopy(x)
fq_out = fq_embedding(x)
baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
169 changes: 116 additions & 53 deletions torchao/quantization/prototype/qat/embedding.py
@@ -14,12 +14,73 @@
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import TorchAODType
from .api import FakeQuantizeConfig
from .fake_quantizer import FakeQuantizer
from .utils import (
_fake_quantize_per_channel_group,
_get_qmin_qmax,
)


class FakeQuantizedEmbedding(torch.nn.Embedding):
"""
General embedding layer with fake quantized weights.

Specific target dtype, granularity, scheme etc. are specified
through a separate config for the weights.

Example usage::

weight_config = FakeQuantizeConfig(
dtype=torch.int4,
group_size=8,
symmetric=True,
)
fq_embedding = FakeQuantizedEmbedding(5, 10, weight_config=weight_config)
fq_embedding(torch.LongTensor([3]))
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
weight_config: Optional[FakeQuantizeConfig] = None,
*args,
**kwargs,
) -> None:
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
*args,
**kwargs,
)
if weight_config is not None:
self.weight_fake_quantizer = FakeQuantizer(weight_config)
else:
self.weight_fake_quantizer = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.weight_fake_quantizer is not None:
w = self.weight_fake_quantizer(self.weight)
else:
w = self.weight
return F.embedding(
x, w, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse,
)


# ======================================
# | Embedding int4 weight-only QAT |
# ======================================
@@ -40,7 +101,7 @@ def __init__(
self.bit_width = 4
self.group_size: int = group_size
self.scale_precision: torch.dtype = scale_precision
self.zero_point_precision: torch.dtype = zero_point_precision,
self.zero_point_precision: torch.dtype = zero_point_precision

def prepare(
self,
@@ -56,16 +117,18 @@ def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = Int4WeightOnlyQATEmbedding(
group_size=self.group_size,

# other nn.Embedding args
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=self.group_size,
scale_precision=self.scale_precision,
zero_point_precision=self.zero_point_precision,
device=child.weight.device,
)
# In distributed training, the model may be instantiated
@@ -98,28 +161,31 @@ def _convert_helper(self, module: torch.nn.Module):
from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
for name, child in module.named_children():
if isinstance(child, Int4WeightOnlyQATEmbedding):
group_size = child.weight_fake_quantizer.config.group_size
scale_precision = child.weight_fake_quantizer.config.scale_precision
zero_point_precision = child.weight_fake_quantizer.config.zero_point_precision
quantized_embedding = Int4WeightOnlyEmbedding(
group_size=child.group_size,
scale_precision=child.scale_precision,
zero_point_precision=child.zero_point_precision,

# other nn.Embedding args
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=group_size,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
device=child.weight.device,
)
setattr(module, name, quantized_embedding)

# Load weights and qparams into quantized embedding
(qmin, qmax) = _get_qmin_qmax(self.bit_width)
(s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, child.group_size)
(s, zp) = get_group_qparams_symmetric(child.weight, self.bit_width, group_size)
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
child.weight, s, zp, qmin, qmax, torch.int8, child.group_size,
child.weight, s, zp, qmin, qmax, torch.int8, group_size,
)
quantized_embedding.weight = q_weight
quantized_embedding.scales = s
@@ -128,7 +194,7 @@ def _convert_helper(self, module: torch.nn.Module):
self._convert_helper(child)


class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):
class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
"""
This module implements an embedding layer with int4 fake quantized
grouped per channel weights.
@@ -141,47 +207,42 @@ class Int4WeightOnlyQATEmbedding(torch.nn.Embedding):

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self._fake_quant_enabled = True

def forward(self, x):
weight = self.weight

if self._fake_quant_enabled:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, self.bit_width, self.group_size, self.scale_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_point_precision)
(weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
weight_qmin,
weight_qmax,
self.group_size,
)
else:
w_fq = self.weight

return F.embedding(
x, w_fq, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse,
weight_config = FakeQuantizeConfig(
dtype=TorchAODType.INT4,
group_size=group_size,
is_symmetric=True,
is_dynamic=True,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
)
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
weight_config,
*args,
**kwargs,
)

def enable_fake_quant(self, enabled: bool = True):
self._fake_quant_enabled = enabled
self.weight_fake_quantizer.enabled = enabled

def disable_fake_quant(self):
self.enable_fake_quant(False)
@@ -194,25 +255,21 @@ class Int4WeightOnlyEmbedding(torch.nn.Module):
"""
def __init__(
self,
group_size: int,
scale_precision: torch.dtype,
zero_point_precision: torch.dtype,

# nn.Embedding args
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
device: torch.device = None,
):
super().__init__()
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision

# nn.Embedding args
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
@@ -221,6 +278,12 @@ def __init__(
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse

# quantization args
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision

# currently storing unpacked int8 weights
self.register_buffer(
"weight",