Skip to content

Commit 7c5bd24

Browse files
authored
[tests] make quanto tests device-agnostic (#36328)
* make device-agnostic
* name change
1 parent 678885b commit 7c5bd24

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tests/quantization/quanto_integration/test_quanto.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
require_optimum_quanto,
2323
require_read_token,
2424
require_torch_accelerator,
25-
require_torch_gpu,
2625
slow,
2726
torch_device,
2827
)
@@ -181,11 +180,11 @@ def test_generate_quality_cpu(self):
181180
"""
182181
self.check_inference_correctness(self.quantized_model, "cpu")
183182

184-
def test_generate_quality_cuda(self):
183+
def test_generate_quality_accelerator(self):
185184
"""
186-
Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
185+
Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens
187186
"""
188-
self.check_inference_correctness(self.quantized_model, "cuda")
187+
self.check_inference_correctness(self.quantized_model, torch_device)
189188

190189
def test_quantized_model_layers(self):
191190
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
@@ -215,7 +214,7 @@ def test_quantized_model_layers(self):
215214
)
216215
self.quantized_model.to(0)
217216
self.assertEqual(
218-
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
217+
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, torch_device
219218
)
220219

221220
def test_serialization_bin(self):
@@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa
430429
weights = "int4"
431430

432431

433-
@require_torch_gpu
432+
@require_torch_accelerator
434433
class QuantoQuantizationActivationTest(unittest.TestCase):
435434
def test_quantize_activation(self):
436435
quantization_config = QuantoConfig(
@@ -443,7 +442,7 @@ def test_quantize_activation(self):
443442

444443

445444
@require_optimum_quanto
446-
@require_torch_gpu
445+
@require_torch_accelerator
447446
class QuantoKVCacheQuantizationTest(unittest.TestCase):
448447
@slow
449448
@require_read_token

0 commit comments

Comments (0)