Skip to content

Adding tests for save/load support #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from torchao.quantization.weight_only import (
WeightOnlyInt8QuantLinear
)

import os

torch.manual_seed(0)

Expand Down Expand Up @@ -904,6 +904,63 @@ def test_weight_only_quant_use_mixed_mm(self):
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 43.0)

class TestSaveLoadMeta(unittest.TestCase):
@torch.no_grad()
def _test_handle_save_load_meta_impl(self, api):
m, k, n = 32, 64, 32
class test_model(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(k, n)
self.relu = nn.ReLU()
self.lin2 = nn.Linear(n, n)

def forward(self, x):
x = self.lin1(x)
x = self.relu(x)
x = self.lin2(x)
return x

x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

# get float reference
model = test_model().to(torch.bfloat16).cuda().eval()
ref_f = model(x)

# save quantized state_dict
api(model)
torch.save(model.state_dict(), "test.pth")
# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
ref_q = model_qc(x).detach()

assert SQNR(ref_f, ref_q) > 35

# load model structure
with torch.device('meta'):
model = test_model()
api(model)

# load quantized state_dict
state_dict = torch.load("test.pth", mmap=True)
os.remove("test.pth")
model.load_state_dict(state_dict, assign=True)
model = model.to(torch.bfloat16).cuda().eval()

# get quantized reference
model_qc = torch.compile(model, mode="max-autotune")
test = model_qc(x).detach()

assert SQNR(ref_f, test) > 35
self.assertTrue(torch.equal(ref_q, test))

@torch.no_grad()
def test_save_load_dqtensors(self):
self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)

@torch.no_grad()
def test_save_load_woqtensors(self):
self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)

class TorchCompileUnitTest(unittest.TestCase):
def test_fullgraph(self):
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
# transposed/detached, instead we can just pass the int_data to the
# new instance and alter the transposed flag where needed.
kwargs["device"] = int_data.device
kwargs["dtype"] = kwargs.get("dtype", torch.int8)
kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
size = int_data.shape[::-1] if transposed else int_data.shape
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
Expand Down