Commit 643ffca

Adding tests for save/load support
Summary: We can now save a model quantized with a tensor subclass: save the state dict, then later load the model as a meta tensor (i.e. load only the tensor metadata, not the actual parameters), apply the quantization API, and finally load the quantized state dict into it. We also change the dtype of the subclass to match the dtype of its dequantized form, both to align with the subclass design guidelines and to make this flow work.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent a9e0596 commit 643ffca
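In outline, the round trip the new test exercises (the torchao calls and flags are taken from the test below; the model class, import path, and checkpoint path are placeholders):

import torch
from torchao.quantization.quant_api import change_linear_weights_to_dqtensors  # import path assumed

# quantize a materialized model and save its (quantized) state dict
model = MyModel().to(torch.bfloat16).cuda().eval()  # MyModel: any nn.Module with Linear layers
change_linear_weights_to_dqtensors(model)
torch.save(model.state_dict(), "checkpoint.pth")

# later: rebuild the structure on the meta device (tensor metadata only, no
# parameter storage), apply the same quantization api, then load the weights
with torch.device("meta"):
    model = MyModel()
change_linear_weights_to_dqtensors(model)
state_dict = torch.load("checkpoint.pth", mmap=True)
model.load_state_dict(state_dict, assign=True)  # assign: replace meta tensors outright
model = model.to(torch.bfloat16).cuda().eval()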

File tree

2 files changed: +59 -2 lines changed

test/test.py

Lines changed: 58 additions & 1 deletion

@@ -56,7 +56,7 @@
 from torchao.quantization.weight_only import (
     WeightOnlyInt8QuantLinear
 )
-
+import os
 
 torch.manual_seed(0)
 
@@ -904,6 +904,63 @@ def test_weight_only_quant_use_mixed_mm(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 43.0)
 
+class TestSaveLoadMeta(unittest.TestCase):
+    @torch.no_grad()
+    def _test_handle_save_load_meta_impl(self, api):
+        m, k, n = 32, 64, 32
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+        # get float reference
+        model = test_model().to(torch.bfloat16).cuda().eval()
+        ref_f = model(x)
+
+        # save quantized state_dict
+        api(model)
+        torch.save(model.state_dict(), "test.pth")
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        ref_q = model_qc(x).detach()
+
+        assert SQNR(ref_f, ref_q) > 35
+
+        # load model structure
+        with torch.device('meta'):
+            model = test_model()
+        api(model)
+
+        # load quantized state_dict
+        state_dict = torch.load("test.pth", mmap=True)
+        os.remove("test.pth")
+        model.load_state_dict(state_dict, assign=True)
+        model = model.to(torch.bfloat16).cuda().eval()
+
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        test = model_qc(x).detach()
+
+        assert SQNR(ref_f, test) > 35
+        self.assertTrue(torch.equal(ref_q, test))
+
+    @torch.no_grad()
+    def test_save_load_dqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
+
+    @torch.no_grad()
+    def test_save_load_woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):
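A note on load_state_dict(..., assign=True) above: parameters created under torch.device("meta") carry shape and dtype but no storage, so the default load path (an in-place copy into the existing parameters) cannot work; assign=True instead replaces the module's tensors with the loaded ones. A minimal standalone sketch with a plain nn.Linear, independent of any quantization:

import torch
import torch.nn as nn

# structure only: the weight is a meta tensor (metadata, no data)
with torch.device("meta"):
    skeleton = nn.Linear(4, 4)
assert skeleton.weight.is_meta

# a state dict previously saved from a real, materialized module
sd = nn.Linear(4, 4).state_dict()

# assign=True swaps the meta parameters for the loaded tensors;
# the default copy-into behavior would fail on meta tensors
skeleton.load_state_dict(sd, assign=True)
assert not skeleton.weight.is_meta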

torchao/quantization/subclass.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
         # transposed/detached, instead we can just pass the int_data to the
         # new instance and alter the transposed flag where needed.
         kwargs["device"] = int_data.device
-        kwargs["dtype"] = kwargs.get("dtype", torch.int8)
+        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
         size = int_data.shape[::-1] if transposed else int_data.shape
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout