Skip to content

Commit b54f235

Browse files
committed
Update on "Move QAT out of prototype"
**Summary:** Move QAT out of prototype so we can provide stronger BC guarantees moving forward. **BC-breaking notes** Before: ``` from torchao.quantization.prototype.qat import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int4WeightOnlyEmbeddingQATQuantizer, Int8DynActInt4WeightQATQuantizer, Int8DynActInt4WeightQATLinear, ) from torchao.quantization.prototype.qat.api import ( FakeQuantizeConfig, ) from torchao.quantization.prototype.qat.fake_quantizer import ( FakeQuantizer, ) ``` After: ``` from torchao.quantization.qat import ( ComposableQATQuantizer, Int4WeightOnlyQATQuantizer, Int4WeightOnlyEmbeddingQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) from torchao.quantization.qat.linear import ( disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, Int8DynActInt4WeightQATLinear, ) from torchao.quantization.qat.api import ( FakeQuantizeConfig, ) from torchao.quantization.qat.fake_quantizer import ( FakeQuantizer, ) ``` **Test Plan:** python test/quantization/test_qat.py Differential Revision: [D64555609](https://our.internmc.facebook.com/intern/diff/D64555609) [ghstack-poisoned]
2 parents 58f402d + add3c16 commit b54f235

File tree

10 files changed

+121
-1
lines changed

10 files changed

+121
-1
lines changed

test/quantization/test_qat.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,59 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
937937
baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
938938
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
939939

940+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is below 2.4")
941+
def test_qat_prototype_bc(self):
942+
"""
943+
Just to make sure we can import all the old prototype paths.
944+
We will remove this test in the near future when we actually break BC.
945+
"""
946+
from torchao.quantization.prototype.qat import (
947+
disable_4w_fake_quant,
948+
disable_8da4w_fake_quant,
949+
enable_4w_fake_quant,
950+
enable_8da4w_fake_quant,
951+
ComposableQATQuantizer,
952+
Int8DynActInt4WeightQATLinear,
953+
Int4WeightOnlyEmbeddingQATQuantizer,
954+
Int4WeightOnlyQATQuantizer,
955+
Int8DynActInt4WeightQATQuantizer,
956+
)
957+
from torchao.quantization.prototype.qat._module_swap_api import (
958+
disable_4w_fake_quant_module_swap,
959+
enable_4w_fake_quant_module_swap,
960+
disable_8da4w_fake_quant_module_swap,
961+
enable_8da4w_fake_quant_module_swap,
962+
Int4WeightOnlyQATQuantizerModuleSwap,
963+
Int8DynActInt4WeightQATQuantizerModuleSwap,
964+
)
965+
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
966+
AffineFakeQuantizedTensor,
967+
to_affine_fake_quantized,
968+
)
969+
from torchao.quantization.prototype.qat.api import (
970+
ComposableQATQuantizer,
971+
FakeQuantizeConfig,
972+
)
973+
from torchao.quantization.prototype.qat.embedding import (
974+
FakeQuantizedEmbedding,
975+
Int4WeightOnlyEmbeddingQATQuantizer,
976+
Int4WeightOnlyEmbedding,
977+
Int4WeightOnlyQATEmbedding,
978+
)
979+
from torchao.quantization.prototype.qat.fake_quantizer import (
980+
FakeQuantizer,
981+
)
982+
from torchao.quantization.prototype.qat.linear import (
983+
disable_4w_fake_quant,
984+
disable_8da4w_fake_quant,
985+
enable_4w_fake_quant,
986+
enable_8da4w_fake_quant,
987+
FakeQuantizedLinear,
988+
Int4WeightOnlyQATLinear,
989+
Int4WeightOnlyQATQuantizer,
990+
Int8DynActInt4WeightQATLinear,
991+
Int8DynActInt4WeightQATQuantizer,
992+
)
940993

941994
if __name__ == "__main__":
942-
unittest.main()
995+
unittest.main()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Note: QAT has been moved to torchao/quantization/qat.
2+
This is a legacy folder only for backward compatibility
3+
and will be removed in the near future.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from torchao.quantization.qat import (
2+
ComposableQATQuantizer,
3+
Int4WeightOnlyEmbeddingQATQuantizer,
4+
Int4WeightOnlyQATQuantizer,
5+
Int8DynActInt4WeightQATQuantizer,
6+
)
7+
from torchao.quantization.qat.linear import (
8+
disable_4w_fake_quant,
9+
disable_8da4w_fake_quant,
10+
enable_4w_fake_quant,
11+
enable_8da4w_fake_quant,
12+
Int8DynActInt4WeightQATLinear,
13+
)
14+
15+
__all__ = [
16+
"disable_4w_fake_quant",
17+
"disable_8da4w_fake_quant",
18+
"enable_4w_fake_quant",
19+
"enable_8da4w_fake_quant",
20+
"ComposableQATQuantizer",
21+
"Int4WeightOnlyQATQuantizer",
22+
"Int4WeightOnlyEmbeddingQATQuantizer",
23+
"Int8DynActInt4WeightQATQuantizer",
24+
"Int8DynActInt4WeightQATLinear",
25+
]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# For backward compatibility only
2+
# These will be removed in the future
3+
4+
from torchao.quantization.qat.linear import (
5+
Int8DynActInt4WeightQATQuantizer as Int8DynActInt4WeightQATQuantizerModuleSwap,
6+
Int4WeightOnlyQATQuantizer as Int4WeightOnlyQATQuantizerModuleSwap,
7+
enable_8da4w_fake_quant as enable_8da4w_fake_quant_module_swap,
8+
disable_8da4w_fake_quant as disable_8da4w_fake_quant_module_swap,
9+
enable_4w_fake_quant as enable_4w_fake_quant_module_swap,
10+
disable_4w_fake_quant as disable_4w_fake_quant_module_swap,
11+
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from torchao.quantization.qat.affine_fake_quantized_tensor import (
2+
AffineFakeQuantizedTensor,
3+
to_affine_fake_quantized,
4+
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from torchao.quantization.qat.api import (
2+
ComposableQATQuantizer,
3+
FakeQuantizeConfig,
4+
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from torchao.quantization.qat.embedding import (
2+
FakeQuantizedEmbedding,
3+
Int4WeightOnlyEmbeddingQATQuantizer,
4+
Int4WeightOnlyEmbedding,
5+
Int4WeightOnlyQATEmbedding,
6+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from torchao.quantization.qat.fake_quantizer import (
2+
FakeQuantizer,
3+
)
Loading
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from torchao.quantization.qat.linear import (
2+
disable_4w_fake_quant,
3+
disable_8da4w_fake_quant,
4+
enable_4w_fake_quant,
5+
enable_8da4w_fake_quant,
6+
FakeQuantizedLinear,
7+
Int4WeightOnlyQATLinear,
8+
Int4WeightOnlyQATQuantizer,
9+
Int8DynActInt4WeightQATLinear,
10+
Int8DynActInt4WeightQATQuantizer,
11+
)

0 commit comments

Comments
 (0)