Skip to content

Commit 64719d5

Browse files
authored
Add compile tests to test suite (#906)
* Add compile tests to test suite Summary: This is a follow up PR addressing #839 (comment) We can add more compiler related tests in the future. Next * refactor a bit to use quantize_ API directly * use the test suite in existing API tests Test Plan: python torchao/testing/utils.py Reviewers: Subscribers: Tasks: Tags: * rename * add result check
1 parent d267622 commit 64719d5

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

torchao/testing/utils.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def new_test(self, value=value):
6969

7070

7171
class TorchAOBasicTestCase(common_utils.TestCase):
72-
"""Basic test case for tensor subclasses
73-
"""
7472
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
7573
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
7674

@@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
142140
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
143141
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
144142

143+
144+
class TorchAOCompileTestCase(common_utils.TestCase):
145+
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
146+
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
147+
148+
TENSOR_SUBCLASS = AffineQuantizedTensor
149+
FACTORY_FN = to_affine_quantized_intx
150+
kwargs = {
151+
"mapping_type": MappingType.ASYMMETRIC,
152+
"block_size": (1, 32),
153+
"target_dtype": torch.uint8,
154+
}
155+
# minimum sqnr for linear operation when the weight is quantized to low precision
156+
# with the above setting
157+
LINEAR_MIN_SQNR = 40
158+
COMPILE_MIN_SQNR = 50
159+
160+
@common_utils.parametrize("device", COMMON_DEVICES)
161+
@common_utils.parametrize("dtype", COMMON_DTYPES)
162+
def test_input_output_tensor_subclass(self, device, dtype):
163+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
164+
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
165+
def f(tensor):
166+
return tensor
167+
168+
ref = f(lp_tensor)
169+
f = torch.compile(f)
170+
compiled = f(lp_tensor)
171+
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
172+
self.assertEqual(ref.dequantize(), compiled.dequantize())
173+
174+
@common_utils.parametrize("device", COMMON_DEVICES)
175+
@common_utils.parametrize("dtype", COMMON_DTYPES)
176+
def test_input_tensor_subclass(self, device, dtype):
177+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
178+
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
179+
def f(tensor):
180+
return tensor.dequantize()
181+
182+
ref = f(lp_tensor)
183+
f = torch.compile(f)
184+
compiled = f(lp_tensor)
185+
self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
186+
self.assertEqual(ref, compiled)
187+
188+
@common_utils.parametrize("device", COMMON_DEVICES)
189+
@common_utils.parametrize("dtype", COMMON_DTYPES)
190+
def test_output_tensor_subclass(self, device, dtype):
191+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
192+
def f(hp_tensor):
193+
return self.FACTORY_FN(hp_tensor, **self.kwargs)
194+
195+
ref = f(hp_tensor)
196+
f = torch.compile(f)
197+
compiled = f(hp_tensor)
198+
self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
199+
# bfloat16 seems to result in much larger numerical differences
200+
if dtype != torch.bfloat16:
201+
self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
202+
145203
@common_utils.parametrize("device", COMMON_DEVICES)
146204
@common_utils.parametrize("dtype", COMMON_DTYPES)
147205
def test_linear_compile(self, device, dtype):
@@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
155213
lp_res = torch.compile(l)(hp_act_tensor)
156214
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
157215

216+
217+
158218
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
219+
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
159220

160221
if __name__ == "__main__":
161222
unittest.main()

0 commit comments

Comments
 (0)