2222 require_optimum_quanto ,
2323 require_read_token ,
2424 require_torch_accelerator ,
25- require_torch_gpu ,
2625 slow ,
2726 torch_device ,
2827)
@@ -181,11 +180,11 @@ def test_generate_quality_cpu(self):
181180 """
182181 self .check_inference_correctness (self .quantized_model , "cpu" )
183182
184- def test_generate_quality_cuda (self ):
183+ def test_generate_quality_accelerator (self ):
185184 """
186- Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
185+ Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens
187186 """
188- self .check_inference_correctness (self .quantized_model , "cuda" )
187+ self .check_inference_correctness (self .quantized_model , torch_device )
189188
190189 def test_quantized_model_layers (self ):
191190 from optimum .quanto import QBitsTensor , QModuleMixin , QTensor
@@ -215,7 +214,7 @@ def test_quantized_model_layers(self):
215214 )
216215 self .quantized_model .to (0 )
217216 self .assertEqual (
218- self .quantized_model .transformer .h [0 ].self_attention .query_key_value .weight ._data .device .type , "cuda"
217+ self .quantized_model .transformer .h [0 ].self_attention .query_key_value .weight ._data .device .type , torch_device
219218 )
220219
221220 def test_serialization_bin (self ):
@@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa
430429 weights = "int4"
431430
432431
433- @require_torch_gpu
432+ @require_torch_accelerator
434433class QuantoQuantizationActivationTest (unittest .TestCase ):
435434 def test_quantize_activation (self ):
436435 quantization_config = QuantoConfig (
@@ -443,7 +442,7 @@ def test_quantize_activation(self):
443442
444443
445444@require_optimum_quanto
446- @require_torch_gpu
445+ @require_torch_accelerator
447446class QuantoKVCacheQuantizationTest (unittest .TestCase ):
448447 @slow
449448 @require_read_token
0 commit comments