Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module
to replace the existing, more specific QAT linear classes.
For example, `Int8DynActInt4WeightQATLinear` can be expressed
as follows:
```python
import torch

from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 asymmetric per-token fake quantization for activations
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
# int4 per-group fake quantization for weights, group size 8
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```
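The resulting module behaves like a standard `torch.nn.Linear` during training, with fake quantization applied in the forward pass. A minimal usage sketch, continuing from the snippet above (the input shape is illustrative):

```python
# Forward pass fake-quantizes activations and weights, then applies the linear op
x = torch.randn(4, 16)   # batch of 4, in_features=16
y = fq_linear(x)         # shape: (4, 32)
```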
The main motivation is to provide a more flexible way to perform
QAT on models with linear layers. Previously, we had to create a
new linear class every time we wished to experiment with different
fake quantization settings, e.g. a different group size or bit
width. Now we can express all of these variants through a single
linear module, as in the sketch below.
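For instance, trying a larger group size no longer requires a new class, only a new config. A hedged sketch reusing the names from the snippet above (the specific shapes and group size are illustrative):

```python
# Same int4 weight fake quantization as before, but with group size 32
weight_config_g32 = FakeQuantizeConfig(torch.int4, group_size=32)
fq_linear_g32 = FakeQuantizedLinear(64, 128, False, activation_config, weight_config_g32)
```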
**Test Plan:**
```
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```
[ghstack-poisoned]