@@ -209,53 +209,6 @@ def test_print_quantized_module(self, apply_quant):
209209 ql = apply_quant (linear )
210210 assert "AffineQuantizedTensor" in str (ql )
211211
212- @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
213- @common_utils .parametrize (
214- "apply_quant" , get_quantization_functions (False , True , "cuda" , False )
215- )
216- def test_copy_ (self , apply_quant ):
217- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 , device = "cuda" )
218- linear2 = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 , device = "cuda" )
219-
220- if isinstance (apply_quant , AOBaseConfig ):
221- quantize_ (linear , apply_quant )
222- ql = linear
223- quantize_ (linear2 , apply_quant )
224- ql2 = linear2
225- else :
226- ql = apply_quant (linear )
227- ql2 = apply_quant (linear2 )
228-
229- example_input = torch .randn (1 , 128 , dtype = torch .bfloat16 , device = "cuda" )
230- output = ql (example_input )
231- ql2 .weight .copy_ (ql .weight )
232- ql2 .bias = ql .bias
233- output2 = ql2 (example_input )
234- self .assertEqual (output , output2 )
235-
236- @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
237- @common_utils .parametrize (
238- "apply_quant" , get_quantization_functions (False , True , "cuda" , False )
239- )
240- def test_copy__mismatch_metadata (self , apply_quant ):
241- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 , device = "cuda" )
242- linear2 = torch .nn .Linear (128 , 512 , dtype = torch .bfloat16 , device = "cuda" )
243-
244- if isinstance (apply_quant , AOBaseConfig ):
245- quantize_ (linear , apply_quant )
246- ql = linear
247- quantize_ (linear2 , apply_quant )
248- ql2 = linear2
249- else :
250- ql = apply_quant (linear )
251- ql2 = apply_quant (linear2 )
252-
253- # copy should fail due to shape mismatch
254- with self .assertRaisesRegex (
255- ValueError , "Not supported args for copy_ due to metadata mistach:"
256- ):
257- ql2 .weight .copy_ (ql .weight )
258-
259212
260213class TestAffineQuantizedBasic (TestCase ):
261214 COMMON_DEVICES = ["cpu" ] + (["cuda" ] if torch .cuda .is_available () else [])
0 commit comments