13 changes: 6 additions & 7 deletions tests/quantization/quanto_integration/test_quanto.py
@@ -22,7 +22,6 @@
     require_optimum_quanto,
     require_read_token,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -181,11 +180,11 @@ def test_generate_quality_cpu(self):
         """
         self.check_inference_correctness(self.quantized_model, "cpu")

-    def test_generate_quality_cuda(self):
+    def test_generate_quality_accelerator(self):
         """
-        Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
+        Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens
         """
-        self.check_inference_correctness(self.quantized_model, "cuda")
+        self.check_inference_correctness(self.quantized_model, torch_device)

     def test_quantized_model_layers(self):
         from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
@@ -215,7 +214,7 @@ def test_quantized_model_layers(self):
         )
         self.quantized_model.to(0)
         self.assertEqual(
-            self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
+            self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, torch_device
         )

     def test_serialization_bin(self):
@@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa
     weights = "int4"


-@require_torch_gpu
+@require_torch_accelerator
 class QuantoQuantizationActivationTest(unittest.TestCase):
     def test_quantize_activation(self):
         quantization_config = QuantoConfig(
@@ -443,7 +442,7 @@ def test_quantize_activation(self):


 @require_optimum_quanto
-@require_torch_gpu
+@require_torch_accelerator
 class QuantoKVCacheQuantizationTest(unittest.TestCase):
     @slow
     @require_read_token
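
Note: the diff above swaps the CUDA-only require_torch_gpu decorator and hardcoded "cuda" strings for the generic require_torch_accelerator decorator and the torch_device helper from transformers.testing_utils, so the tests can run on any supported accelerator. Below is a minimal, hypothetical sketch of that pattern; the test class name, checkpoint, and assertion are illustrative and are not taken from the actual test file.

# Hypothetical sketch of the device-agnostic test pattern adopted by this diff.
# Uses the standard transformers.testing_utils helpers; checkpoint and assertion
# are placeholders for illustration only.
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_torch_accelerator, torch_device


@require_torch_accelerator  # skips when no accelerator (CUDA, XPU, ...) is available
class DeviceAgnosticQuantoTest(unittest.TestCase):
    def test_generate_on_accelerator(self):
        # torch_device resolves to "cuda", "xpu", etc. depending on the host,
        # so the test no longer hardcodes "cuda".
        config = QuantoConfig(weights="int8")
        model = AutoModelForCausalLM.from_pretrained(
            "bigscience/bloom-560m",  # placeholder checkpoint
            quantization_config=config,
            device_map=torch_device,
        )
        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
        inputs = tokenizer("Hello my name is", return_tensors="pt").to(torch_device)
        output = model.generate(**inputs, max_new_tokens=5)
        # The generated sequence should extend the prompt by at least one token.
        self.assertGreater(output.shape[-1], inputs.input_ids.shape[-1])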