@@ -69,8 +69,6 @@ def new_test(self, value=value):
69
69
70
70
71
71
class TorchAOBasicTestCase (common_utils .TestCase ):
72
- """Basic test case for tensor subclasses
73
- """
74
72
COMMON_DEVICES = ["cpu" ] + (["cuda" ] if torch .cuda .is_available () else [])
75
73
COMMON_DTYPES = [torch .float32 , torch .float16 , torch .bfloat16 ]
76
74
@@ -142,6 +140,43 @@ def test_linear(self, device, dtype):
142
140
lp_res = torch .nn .functional .linear (hp_act_tensor , lp_tensor )
143
141
self .assertGreater (torchao .quantization .utils .compute_error (hp_res , lp_res ), self .LINEAR_MIN_SQNR )
144
142
143
+
144
+ class TorchAOCompileTestCase (common_utils .TestCase ):
145
+ COMMON_DEVICES = ["cpu" ] + (["cuda" ] if torch .cuda .is_available () else [])
146
+ COMMON_DTYPES = [torch .float32 , torch .float16 , torch .bfloat16 ]
147
+
148
+ TENSOR_SUBCLASS = AffineQuantizedTensor
149
+ FACTORY_FN = to_affine_quantized_intx
150
+ kwargs = {
151
+ "mapping_type" : MappingType .ASYMMETRIC ,
152
+ "block_size" : (1 , 32 ),
153
+ "target_dtype" : torch .uint8 ,
154
+ }
155
+ # minimum sqnr for linear operation when the weight is quantized to low precision
156
+ # with the above setting
157
+ LINEAR_MIN_SQNR = 40
158
+
159
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_input_output(self, device, dtype):
    """Feed a FACTORY_FN-built tensor to a compiled transpose and check the result type.

    Args:
        device: device string taken from COMMON_DEVICES.
        dtype: torch dtype taken from COMMON_DTYPES.
    """
    @torch.compile
    def transpose(tensor):
        return tensor.t()

    source = torch.randn(4, 128, device=device, dtype=dtype)
    quantized = self.FACTORY_FN(source, **self.kwargs)
    # Both the argument and the return value of the compiled function are
    # expected to be TENSOR_SUBCLASS instances.
    result = transpose(quantized)
    self.assertTrue(isinstance(result, self.TENSOR_SUBCLASS))
169
+
170
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_output_tensor_subclass(self, device, dtype):
    """Check that FACTORY_FN called inside a compiled function returns TENSOR_SUBCLASS.

    This method used to be named ``test_input_output``, the same name as an
    earlier method of this class. The later definition replaced the earlier
    one in the class namespace, so only one of the two tests was ever
    registered. A distinct name lets both run.

    Args:
        device: device string taken from COMMON_DEVICES.
        dtype: torch dtype taken from COMMON_DTYPES.
    """
    hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)

    def quantize(tensor):
        # The factory call itself runs inside the compiled function; the
        # input is a plain tensor and only the output is the subclass.
        return self.FACTORY_FN(tensor, **self.kwargs)

    compiled_quantize = torch.compile(quantize)
    self.assertIsInstance(compiled_quantize(hp_tensor), self.TENSOR_SUBCLASS)
179
+
145
180
@common_utils .parametrize ("device" , COMMON_DEVICES )
146
181
@common_utils .parametrize ("dtype" , COMMON_DTYPES )
147
182
def test_linear_compile (self , device , dtype ):
@@ -155,7 +190,10 @@ def test_linear_compile(self, device, dtype):
155
190
lp_res = torch .compile (l )(hp_act_tensor )
156
191
self .assertGreater (torchao .quantization .utils .compute_error (hp_res , lp_res ), self .LINEAR_MIN_SQNR )
157
192
193
+
194
+
158
195
common_utils .instantiate_parametrized_tests (TorchAOBasicTestCase )
196
+ common_utils .instantiate_parametrized_tests (TorchAOCompileTestCase )
159
197
160
198
if __name__ == "__main__" :
161
199
unittest .main ()
0 commit comments