Commit 0bb73d8

Adding int4 tensor subclass

Summary: Adds int4 tensor subclass support and refactors the tensor subclass code to be easier to use with multiple subclasses. The new subclass uses the tinygemm int4 mixed-dtype gemm that was added to pytorch as _weight_int4pack_mm and _convert_weight_to_int4pack. Also adds .to support for tensor subclasses so that saving/loading via meta tensors works for int4.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 591d99c
Pull Request resolved: #15

1 parent f2c7762 commit 0bb73d8
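For context, the two ATen ops named in the summary are the int4 path this subclass wraps. Below is a minimal sketch of how they fit together; the op names come from the commit message, while the sizes, dtypes, and scale/zero-point layout are assumptions for illustration (the subclass's actual packing logic lives in torchao/quantization/subclass.py):

    import torch

    # hypothetical sizes, chosen only for illustration
    n, k = 32, 1024                    # out_features, in_features
    groupsize, inner_k_tiles = 128, 8

    # int4 weight values in [0, 15] plus per-group scales and zero points
    # (this layout is an assumption, not necessarily the subclass's exact format)
    w_int4 = torch.randint(0, 16, (n, k), dtype=torch.int32, device="cuda")
    scales_and_zeros = torch.rand(k // groupsize, n, 2, dtype=torch.bfloat16, device="cuda")

    # pack the weight into tinygemm's layout, then run the mixed-dtype gemm
    w_packed = torch._convert_weight_to_int4pack(w_int4, inner_k_tiles)
    x = torch.rand(4, k, dtype=torch.bfloat16, device="cuda")
    y = torch._weight_int4pack_mm(x, w_packed, groupsize, scales_and_zeros)  # shape (4, n)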

File tree

6 files changed: +493, -139 lines

test/test.py

Lines changed: 98 additions & 45 deletions

@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 from torch._inductor.utils import run_and_get_code
-
+from torch._dynamo import config
 from torch.ao.quantization import MinMaxObserver, QConfigMapping
 
 from torchao.quantization.dynamic_quant import (
@@ -21,7 +21,8 @@
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     change_linear_weights_to_dqtensors,
-    change_linear_weights_to_woqtensors,
+    change_linear_weights_to_int8woqtensors,
+    change_linear_weights_to_int4woqtensors,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -42,8 +43,9 @@
     swap_linear_with_smooth_fq_linear,
 )
 from torchao.quantization.subclass import (
-    DynamicallyQuantizedLinearWeight,
-    WeightOnlyQuantizedLinearWeight
+    Int8DynamicallyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
+    Int4WeightOnlyQuantizedLinearWeight
 )
 from torchao.quantization.utils import (
     apply_logging_hook,
@@ -59,6 +61,7 @@
 import os
 
 torch.manual_seed(0)
+config.cache_size_limit = 100
 
 
 class SmoothquantUnitTest(unittest.TestCase):
@@ -788,62 +791,108 @@ def test_qlinear_per_channel_numerics_cuda(self):
 
 
 class TestSubclass(unittest.TestCase):
+    def _test_dequantize_impl(
+        self,
+        test_subclass_from_float,
+        min_sqnr=35,
+        test_dtype=torch.bfloat16,
+        test_shape=[32, 64, 64],
+    ):
+        m, k, n = test_shape
+        lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
+        w = lin.weight.detach()
+        lin.weight = torch.nn.Parameter(
+            test_subclass_from_float(lin.weight), requires_grad=False
+        )
+        self.assertGreater(SQNR(w, lin.weight.dequantize()), min_sqnr, f"{lin.weight.__class__.__name__} failed dtype={test_dtype}")
+        self.assertGreater(SQNR(w.t(), lin.weight.t().dequantize()), min_sqnr, f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}")
+
+    def test_dequantize_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_dequantize_impl(Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype)
+
+    def test_dequantize_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_dequantize_impl(Int8WeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype)
+
+    def test_dequantize_int4_weight_only_quant_subclass(self):
+        self._test_dequantize_impl(Int4WeightOnlyQuantizedLinearWeight.from_float, 15, test_shape=[1, 1024, 8])
+        for groupsize in [256, 128]:
+            for inner_k_tiles in [8, 2]:
+                for m in [1, 256]:
+                    self._test_dequantize_impl(lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), 15, test_shape=[m, 256, 8])
+
     def _test_lin_weight_subclass_impl(self,
-        test_subclass,
+        test_subclass_from_float,
         min_sqnr=35,
-        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
-        test_shape=[32, 64, 32]
+        test_dtype=torch.bfloat16,
+        test_shape=[32, 64, 32],
     ):
-        for test_dtype in test_dtypes:
-            m, k, n = test_shape
-            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
-            lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
-            ref_f = lin(x)
-
-            lin.weight = torch.nn.Parameter(
-                test_subclass.from_float(lin.weight), requires_grad=False
-            )
-            test = lin(x)
-            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{test_subclass.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}")
-            lin_comp = torch.compile(lin, mode='max-autotune')
-            test_comp = lin_comp(x)
-            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{test_subclass.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}")
+        m, k, n = test_shape
+        x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+        lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype)
+        ref_f = lin(x)
+
+        lin.weight = torch.nn.Parameter(
+            test_subclass_from_float(lin.weight), requires_grad=False
+        )
+        test = lin(x)
+        self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}")
+        lin_comp = torch.compile(lin, mode='max-autotune')
+        test_comp = lin_comp(x)
+        self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}")
 
     def test_int8_dynamic_quant_subclass(self):
-        self._test_lin_weight_subclass_impl(DynamicallyQuantizedLinearWeight, 35)
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype)
 
     def test_int8_weight_only_quant_subclass(self):
-        self._test_lin_weight_subclass_impl(WeightOnlyQuantizedLinearWeight, 40)
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype)
+
+    def test_int4_weight_only_quant_subclass(self):
+        self._test_lin_weight_subclass_impl(Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8])
+        for groupsize in [128, 64]:
+            for inner_k_tiles in [4, 2]:
+                for m in [1, 256]:
+                    self._test_lin_weight_subclass_impl(lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), 10, test_shape=[m, 256, 8])
 
     @torch.no_grad()
     def _test_lin_weight_subclass_api_impl(
         self,
         api,
         min_sqnr=35,
-        test_dtypes=[torch.float32, torch.float16, torch.bfloat16],
+        test_dtype=torch.bfloat16,
         test_shape=[32, 64, 32]
     ):
-        for test_dtype in test_dtypes:
-            m, k, n = test_shape
-            x = torch.randn(m, k, device="cuda", dtype=test_dtype)
-            mod = nn.Sequential(
-                nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
-            ).to(test_dtype)
-            ref_f = mod(x)
-            api(mod)
-            test = mod(x)
-            self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}")
-
-            mod_qc = torch.compile(mod, mode="max-autotune")
-            test_comp = mod_qc(x)
-            self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}")
+        m, k, n = test_shape
+        x = torch.randn(m, k, device="cuda", dtype=test_dtype)
+        mod = nn.Sequential(
+            nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda")
+        ).to(test_dtype)
+        ref_f = mod(x)
+        api(mod)
+        test = mod(x)
+        self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}")
+        mod_qc = torch.compile(mod, mode="max-autotune")
+        test_comp = mod_qc(x)
+        self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}")
 
 
     def test_int8_dynamic_quant_subclass_api(self):
-        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35)
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35)
 
     def test_int8_weight_only_quant_subclass_api(self):
-        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_woqtensors, 40)
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_api_impl(change_linear_weights_to_int8woqtensors, 40)
+
+    def test_int4_weight_only_quant_subclass_api(self):
+        self._test_lin_weight_subclass_api_impl(change_linear_weights_to_int4woqtensors, 15, test_shape=[1, 1024, 256])
+        for groupsize in [64, 32]:
+            for inner_k_tiles in [4, 2]:
+                kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+                self._test_lin_weight_subclass_api_impl(lambda mod: change_linear_weights_to_int4woqtensors(mod, **kwargs), 15, test_shape=[256, 256, 8])
 
 class TestDynamicQuant(unittest.TestCase):
     def test_dynamic_quant(self):
@@ -906,7 +955,7 @@ def test_weight_only_quant_use_mixed_mm(self):
 
 class TestSaveLoadMeta(unittest.TestCase):
     @torch.no_grad()
-    def _test_handle_save_load_meta_impl(self, api):
+    def _test_handle_save_load_meta_impl(self, api, min_sqnr=35):
         m, k, n = 32, 64, 32
         class test_model(nn.Module):
             def __init__(self):
@@ -934,7 +983,7 @@ def forward(self, x):
         model_qc = torch.compile(model, mode="max-autotune")
         ref_q = model_qc(x).detach()
 
-        assert SQNR(ref_f, ref_q) > 35
+        assert SQNR(ref_f, ref_q) > min_sqnr
 
         # load model structure
         with torch.device('meta'):
@@ -951,16 +1000,20 @@ def forward(self, x):
         model_qc = torch.compile(model, mode="max-autotune")
         test = model_qc(x).detach()
 
-        assert SQNR(ref_f, test) > 35
+        assert SQNR(ref_f, test) > min_sqnr
         self.assertTrue(torch.equal(ref_q, test))
 
     @torch.no_grad()
     def test_save_load_dqtensors(self):
        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
 
     @torch.no_grad()
-    def test_save_load_woqtensors(self):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
+    def test_save_load_int8woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8woqtensors)
+
+    @torch.no_grad()
+    def test_save_load_int4woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4woqtensors, 20)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):
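A note on the thresholds used in these tests: SQNR gates how much error each scheme may introduce, with tighter bounds for int8 (35 to 40 dB) than for the coarser int4 kernels (10 to 20 dB). As a point of reference, here is a minimal sketch of such a metric, assuming the usual signal-to-quantization-noise-ratio definition (the SQNR used by test.py is presumably torchao's compute_error helper; this standalone version is illustrative only):

    import torch

    def sqnr(signal: torch.Tensor, noisy: torch.Tensor) -> float:
        # 20 * log10(||x|| / ||x - y||) in dB; higher means less quantization error
        noise = torch.linalg.norm(signal.float() - noisy.float())
        return (20 * torch.log10(torch.linalg.norm(signal.float()) / noise)).item()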

torchao/quantization/__init__.py

Lines changed: 5 additions & 3 deletions

@@ -16,7 +16,8 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",
-    "change_linear_weights_to_woqtensors",
+    "change_linear_weights_to_int8woqtensors",
+    "change_linear_weights_to_int4woqtensors",
     "insert_subclass",
     "safe_int_mm",
     "dynamically_quantize_per_tensor",
@@ -34,8 +35,9 @@
     "swap_linear_with_smooth_fq_linear",
     "smooth_fq_linear_to_inference",
     "set_smooth_fq_attribute",
-    "DynamicallyQuantizedLinearWeight",
-    "WeightOnlyQuantizedLinearWeight",
+    "Int8DynamicallyQuantizedLinearWeight",
+    "Int8WeightOnlyQuantizedLinearWeight",
+    "Int4WeightOnlyQuantizedLinearWeight",
     "log_with_rank",
     "clear_logs",
     "compute_error",

torchao/quantization/quant_api.py

Lines changed: 24 additions & 10 deletions

@@ -20,8 +20,9 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from .subclass import (
-    DynamicallyQuantizedLinearWeight,
-    WeightOnlyQuantizedLinearWeight,
+    Int8DynamicallyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
+    Int4WeightOnlyQuantizedLinearWeight,
 )
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
@@ -31,7 +32,8 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",
-    "change_linear_weights_to_woqtensors",
+    "change_linear_weights_to_int8woqtensors",
+    "change_linear_weights_to_int4woqtensors",
 ]
 
 
@@ -77,34 +79,46 @@ def apply_dynamic_quant(model):
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
 
-def _get_subclass_inserter(cls):
+def _get_subclass_inserter(cls, **kwargs):
     def insert_subclass(lin):
         lin.weight = torch.nn.Parameter(
-            cls.from_float(lin.weight), requires_grad=False
+            cls.from_float(lin.weight, **kwargs), requires_grad=False
         )
         return lin
     return insert_subclass
 
 def change_linear_weights_to_dqtensors(model):
     """
-    Converts all linear weight tensors to the `DynamicallyQuantizedLinearWeight`
+    Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
     Tensor subclass, effectively applying the same form of quantization
     as apply_dynamic_quant while not modifying the linear modules.
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_subclass_inserter(DynamicallyQuantizedLinearWeight),
+        _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear)
     )
 
-def change_linear_weights_to_woqtensors(model):
+def change_linear_weights_to_int8woqtensors(model):
     """
-    Converts all linear weight tensors to the `WeightOnlyQuantizedLinearWeight`
+    Converts all linear weight tensors to the `Int8WeightOnlyQuantizedLinearWeight`
     Tensor subclass, effectively applying the same form of quantization
     as apply_dynamic_quant while not modifying the linear modules.
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_subclass_inserter(WeightOnlyQuantizedLinearWeight),
+        _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight),
+        lambda mod, fqn: isinstance(mod, torch.nn.Linear)
+    )
+
+def change_linear_weights_to_int4woqtensors(model, **kwargs):
+    """
+    Converts all linear weight tensors to the `Int4WeightOnlyQuantizedLinearWeight`
+    Tensor subclass, effectively applying the same form of quantization
+    as apply_dynamic_quant while not modifying the linear modules.
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear)
     )
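Putting the new API together, a minimal end-to-end sketch; the model and sizes here are made up, and the groupsize/inner_k_tiles kwargs mirror the values exercised in the tests:

    import torch
    import torch.nn as nn
    from torchao.quantization import change_linear_weights_to_int4woqtensors

    model = nn.Sequential(
        nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 8)
    ).cuda().to(torch.bfloat16)

    # swaps each nn.Linear weight for an Int4WeightOnlyQuantizedLinearWeight in place
    change_linear_weights_to_int4woqtensors(model, groupsize=128, inner_k_tiles=8)

    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    y_eager = model(x)                                     # eager still works
    y_comp = torch.compile(model, mode="max-autotune")(x)  # and under torch.compile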
