Skip to content

Commit 72c4a43

Browse files
committed
Add compile tests to test suite
Summary: This is a follow up PR addressing pytorch#839 (comment) We can add more compiler related tests in the future. Next * refactor a bit to use quantize_ API directly * use the test suite in existing API tests Test Plan: python torchao/testing/utils.py Reviewers: Subscribers: Tasks: Tags:
1 parent 53b6b78 commit 72c4a43

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

torchao/testing/utils.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def new_test(self, value=value):
6969

7070

7171
class TorchAOBasicTestCase(common_utils.TestCase):
72-
"""Basic test case for tensor subclasses
73-
"""
7472
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
7573
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
7674

@@ -142,6 +140,43 @@ def test_linear(self, device, dtype):
142140
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
143141
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
144142

143+
144+
class TorchAOCompileTestCase(common_utils.TestCase):
145+
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
146+
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
147+
148+
TENSOR_SUBCLASS = AffineQuantizedTensor
149+
FACTORY_FN = to_affine_quantized_intx
150+
kwargs = {
151+
"mapping_type": MappingType.ASYMMETRIC,
152+
"block_size": (1, 32),
153+
"target_dtype": torch.uint8,
154+
}
155+
# minimum sqnr for linear operation when the weight is quantized to low precision
156+
# with the above setting
157+
LINEAR_MIN_SQNR = 40
158+
159+
@common_utils.parametrize("device", COMMON_DEVICES)
160+
@common_utils.parametrize("dtype", COMMON_DTYPES)
161+
def test_input_output(self, device, dtype):
162+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
163+
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
164+
def f(tensor):
165+
return tensor.t()
166+
167+
f = torch.compile(f)
168+
self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
169+
170+
@common_utils.parametrize("device", COMMON_DEVICES)
171+
@common_utils.parametrize("dtype", COMMON_DTYPES)
172+
def test_input_output(self, device, dtype):
173+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
174+
def f(hp_tensor):
175+
return self.FACTORY_FN(hp_tensor, **self.kwargs)
176+
177+
f = torch.compile(f)
178+
self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
179+
145180
@common_utils.parametrize("device", COMMON_DEVICES)
146181
@common_utils.parametrize("dtype", COMMON_DTYPES)
147182
def test_linear_compile(self, device, dtype):
@@ -155,7 +190,10 @@ def test_linear_compile(self, device, dtype):
155190
lp_res = torch.compile(l)(hp_act_tensor)
156191
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
157192

193+
194+
158195
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
196+
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
159197

160198
if __name__ == "__main__":
161199
unittest.main()

0 commit comments

Comments
 (0)