|
from torchao.quantization.weight_only import (
    WeightOnlyInt8QuantLinear
)

import os
import tempfile
60 | 60 |
|
61 | 61 | torch.manual_seed(0)
|
62 | 62 |
|
@@ -904,6 +904,63 @@ def test_weight_only_quant_use_mixed_mm(self):
|
904 | 904 | sqnr = compute_error(y_ref, y_wo)
|
905 | 905 | self.assertGreater(sqnr, 43.0)
|
906 | 906 |
|
class TestSaveLoadMeta(unittest.TestCase):
    """Round-trips quantized models through save / meta-device load.

    Workflow exercised: quantize -> save state_dict -> rebuild the model on
    the 'meta' device -> load the checkpoint with ``assign=True`` -> verify
    the reloaded model's outputs are bit-identical to the pre-save model.
    """

    @torch.no_grad()
    def _test_handle_save_load_meta_impl(self, api):
        """Run the save/load round trip for one quantization API.

        Args:
            api: callable that quantizes an ``nn.Module`` in place
                (e.g. ``change_linear_weights_to_dqtensors``).
        """
        m, k, n = 32, 64, 32

        class test_model(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(k, n)
                self.relu = nn.ReLU()
                self.lin2 = nn.Linear(n, n)

            def forward(self, x):
                x = self.lin1(x)
                x = self.relu(x)
                x = self.lin2(x)
                return x

        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

        # get float reference
        model = test_model().to(torch.bfloat16).cuda().eval()
        ref_f = model(x)

        api(model)

        # Save the quantized state_dict into a unique temporary directory so
        # concurrent test runs cannot clobber each other's checkpoints, and
        # so the file is cleaned up even if an assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt = os.path.join(tmp_dir, "test.pth")
            torch.save(model.state_dict(), ckpt)

            # get quantized reference
            model_qc = torch.compile(model, mode="max-autotune")
            ref_q = model_qc(x).detach()

            self.assertGreater(SQNR(ref_f, ref_q), 35)

            # load model structure on the meta device: no real storage is
            # allocated until the checkpoint tensors are assigned below
            with torch.device('meta'):
                model = test_model()
            api(model)

            # mmap=True + assign=True swaps the meta tensors for the
            # memory-mapped checkpoint tensors instead of copying into them
            state_dict = torch.load(ckpt, mmap=True)
            model.load_state_dict(state_dict, assign=True)
            # materialize on GPU before the temp dir (and mmap'd file) goes away
            model = model.to(torch.bfloat16).cuda().eval()

        # get quantized reference from the reloaded model
        model_qc = torch.compile(model, mode="max-autotune")
        test = model_qc(x).detach()

        self.assertGreater(SQNR(ref_f, test), 35)
        self.assertTrue(torch.equal(ref_q, test))

    @torch.no_grad()
    def test_save_load_dqtensors(self):
        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)

    @torch.no_grad()
    def test_save_load_woqtensors(self):
        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
907 | 964 |
|
908 | 965 | class TorchCompileUnitTest(unittest.TestCase):
|
909 | 966 | def test_fullgraph(self):
|
|
0 commit comments