@@ -85,7 +85,11 @@ def _has_nested_attr(obj, attr_path):
             )
         else:
             self.assertTrue(
-                torch.allclose(submodule.weight, comp_decomp_obj.weight, atol=0.2),
+                torch.allclose(
+                    submodule.weight.to(torch_device),
+                    comp_decomp_obj.weight.to(torch_device),
+                    atol=0.2,
+                ),
                 f"Weight mismatch for module '{name}' in quantized-only or stacked model.",
             )
 
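For context on the hunk above: `torch.allclose` raises a RuntimeError when its operands live on different devices, so the assertion now moves both weights onto `torch_device` before comparing. A minimal sketch of the failure mode, using hypothetical tensors rather than anything from the PR:

```python
import torch

# Hypothetical illustration: torch.allclose requires both tensors
# to be on the same device, otherwise it raises a RuntimeError.
a = torch.randn(4, 4)                           # CPU tensor
device = "cuda" if torch.cuda.is_available() else "cpu"
b = a.to(device) + 0.01                         # possibly-GPU copy with small noise

# When device == "cuda", the next call would raise
# "Expected all tensors to be on the same device":
# torch.allclose(a, b, atol=0.2)

# Moving both operands to a common device first makes the comparison safe,
# mirroring the .to(torch_device) calls added in the diff above.
print(torch.allclose(a.to(device), b.to(device), atol=0.2))  # True
```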
@@ -47,7 +47,7 @@ def test_config_to_from_dict(self):
         self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)
 
     def test_tinyllama_w8a8(self):
-        expected_out = "<s> Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n"
+        expected_out = "<s> Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1"
         self._test_quantized_model(self.tinyllama_w8a8, expected_out)
 
     def test_tinyllama_w4a16(self):
@@ -59,7 +59,7 @@ def test_tinyllama_w8a16(self):
         self._test_quantized_model(self.tinyllama_w8a16, expected_out)
 
     def test_llama_8b_fp8(self):
-        expected_out = "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous museum in Paris that is home to the Mona Lisa? The Louvre\nWhat is the name of the famous bridge in Paris that is often associated with the city"
+        expected_out = "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous opera house in Paris? Palais Garnier\nWhat is the name of the"
         self._test_quantized_model(self.llama3_8b_fp8, expected_out)
 
     def _test_quantized_model(self, model_name: str, expected_output: str):
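The second file only refreshes golden generation strings. Below is a sketch of the kind of generate-and-compare helper these tests rely on, assuming a standard transformers load/generate/decode flow; the real `_test_quantized_model` body is not shown in this diff and may differ:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def _test_quantized_model_sketch(model_name: str, expected_output: str):
    # Hypothetical reconstruction of the helper, for illustration only.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    inputs = tokenizer(
        "Paris is the capital of which country?", return_tensors="pt"
    ).to(model.device)

    # Greedy decoding keeps the completion deterministic so it can be
    # compared against a fixed golden string.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)

    # Special tokens are kept, matching the "<s>"/"<|begin_of_text|>"
    # prefixes visible in the expected strings above.
    generated = tokenizer.decode(output_ids[0])
    assert generated == expected_output, f"unexpected completion: {generated!r}"
```

Greedy decoding (`do_sample=False`) is what makes a fixed expected string viable at all; any sampling would make the assertion flaky, which is why only numeric changes to the quantized weights should shift these golden outputs.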