@@ -69,8 +69,6 @@ def new_test(self, value=value):
 class TorchAOBasicTestCase(common_utils.TestCase):
-    """Basic test case for tensor subclasses
-    """
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
@@ -142,6 +140,66 @@ def test_linear(self, device, dtype):
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
 
+
+class TorchAOCompileTestCase(common_utils.TestCase):
+    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+    TENSOR_SUBCLASS = AffineQuantizedTensor
+    FACTORY_FN = to_affine_quantized_intx
+    kwargs = {
+        "mapping_type": MappingType.ASYMMETRIC,
+        "block_size": (1, 32),
+        "target_dtype": torch.uint8,
+    }
+    # minimum sqnr for linear operation when the weight is quantized to low precision
+    # with the above setting
+    LINEAR_MIN_SQNR = 40
+    COMPILE_MIN_SQNR = 50
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertTrue(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
+        self.assertEqual(ref.dequantize(), compiled.dequantize())
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_input_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        def f(tensor):
+            return tensor.dequantize()
+
+        ref = f(lp_tensor)
+        f = torch.compile(f)
+        compiled = f(lp_tensor)
+        self.assertFalse(isinstance(f(lp_tensor), self.TENSOR_SUBCLASS))
+        self.assertEqual(ref, compiled)
+
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_output_tensor_subclass(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        def f(hp_tensor):
+            return self.FACTORY_FN(hp_tensor, **self.kwargs)
+
+        ref = f(hp_tensor)
+        f = torch.compile(f)
+        compiled = f(hp_tensor)
+        self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
+        # bfloat16 seems to result in much larger numerical differences
+        if dtype != torch.bfloat16:
+            self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
+
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_linear_compile(self, device, dtype):
@@ -155,7 +213,10 @@ def test_linear_compile(self, device, dtype):
         lp_res = torch.compile(l)(hp_act_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
 
+
+
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
 
 if __name__ == "__main__":
     unittest.main()
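For reference, a minimal usage sketch (not part of this diff) of how the new TorchAOCompileTestCase could be reused from a downstream test file by overriding the class attributes it exposes. The import paths and the block_size value below are assumptions based on the identifiers in the diff, not taken from it verbatim.

# Hypothetical usage sketch: run the compile tests against a custom
# quantization config by subclassing TorchAOCompileTestCase.
# Import paths are assumed based on the identifiers used in the diff.
import unittest

import torch
from torch.testing._internal import common_utils

from torchao.testing.utils import TorchAOCompileTestCase  # assumed module path
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType


class MyCompileTest(TorchAOCompileTestCase):
    # same tensor subclass and factory as the defaults in the diff;
    # block_size is an illustrative override
    TENSOR_SUBCLASS = AffineQuantizedTensor
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 64),
        "target_dtype": torch.uint8,
    }
    COMPILE_MIN_SQNR = 50


common_utils.instantiate_parametrized_tests(MyCompileTest)

if __name__ == "__main__":
    unittest.main()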