Merged
tests/quantization/torchao_integration/test_torchao.py (40 additions, 27 deletions)
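The recurring change in this diff is a migration from string quant-type aliases (e.g. `"int4_weight_only"` plus loose kwargs) to explicit torchao config objects wrapped by `TorchAoConfig`. A minimal sketch of the before/after pattern, using only names already imported in this test file; the model name and group size are illustrative, and running it assumes a torchao install recent enough to expose these config classes:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

# Before: string alias plus loose kwargs forwarded by TorchAoConfig.
# quant_config = TorchAoConfig("int4_weight_only", group_size=32)

# After: construct the torchao config object directly and wrap it.
quant_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=32))

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # the model these tests use
    dtype=torch.bfloat16,  # the tests quantize a bf16 model on the fly
    device_map="auto",
    quantization_config=quant_config,
)
```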
@@ -48,6 +48,8 @@
 from torchao.quantization import (
     Float8Tensor,
     Float8WeightOnlyConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     MappingType,
@@ -118,17 +120,18 @@ def test_repr(self):
         """
         Check that there is no error in the repr
         """
-        quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
+        config = Int4WeightOnlyConfig(group_size=8, layout=TensorCoreTiledLayout())
+        quantization_config = TorchAoConfig(config, modules_to_not_convert=["conv"])
         repr(quantization_config)

     def test_json_serializable(self):
         """
         Check that the config dict can be JSON serialized.
         """
-        quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
+        config = Int4WeightOnlyConfig(group_size=32, layout=TensorCoreTiledLayout())
+        quantization_config = TorchAoConfig(config)
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
+        self.assertTrue("inner_k_tiles" in d["quant_type"]["default"]["_data"]["layout"]["_data"])
         quantization_config.to_json_string(use_diff=False)
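The updated assertion implies that `to_dict()` now serializes the config object into a nested payload keyed by module pattern, rather than flat `quant_type_kwargs`. A rough sketch of the slice the test touches; only the key path asserted above is certain, the surrounding keys and the `inner_k_tiles` value are assumptions:

```python
# Illustrative only: just the asserted path is guaranteed by the test.
d = {
    "quant_type": {
        "default": {  # per-module-pattern entry; assumed shape
            "_data": {
                "group_size": 32,
                "layout": {
                    # TensorCoreTiledLayout payload; inner_k_tiles value assumed
                    "_data": {"inner_k_tiles": 8},
                },
            },
        },
    },
}
assert "inner_k_tiles" in d["quant_type"]["default"]["_data"]["layout"]["_data"]
```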


@@ -159,7 +162,8 @@ def test_int4wo_quant(self):
         """
         Simple LLM model testing int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
@@ -181,7 +185,8 @@ def test_int4wo_quant_bfloat16_conversion(self):
         """
         Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
@@ -203,7 +208,8 @@ def test_int8_dynamic_activation_int8_weight_quant(self):
         """
         Simple LLM model testing int8_dynamic_activation_int8_weight
         """
-        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+        config = Int8DynamicActivationInt8WeightConfig()
+        quant_config = TorchAoConfig(config)

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -426,7 +432,8 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }

-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -461,7 +468,8 @@ def test_int4wo_quant_multi_accelerator(self):
         set ZE_AFFINITY_MASK=0,1 if you have more than 2 Intel XPUs
         """

-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             dtype=torch.bfloat16,
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


@require_torchao
@require_torchao_version_greater_or_equal("0.8.0")
@require_torchao_version_greater_or_equal("0.11.0")
class TorchAoSerializationTest(unittest.TestCase):
input_text = "What are we having for dinner?"
max_new_tokens = 10
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_scheme = "int4_weight_only"
quant_scheme_kwargs = (
{"group_size": 32, "layout": Int4CPULayout(), "version": 1}
if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
else {"group_size": 32}
)

device = "cpu"

# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.quant_scheme_kwargs = (
{"group_size": 32, "layout": Int4CPULayout(), "version": 1}
if is_torchao_available()
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
else {"group_size": 32}
)
cls.quant_scheme = Int4WeightOnlyConfig(**cls.quant_scheme_kwargs)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"

def setUp(self):
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
self.quant_config = TorchAoConfig(self.quant_scheme)
dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=dtype,
@@ -550,7 +559,7 @@ def check_serialization_expected_output(self, device, expected_output, safe_serialization
         """
         Test if we can serialize and load/infer the model again on the same device
         """
-        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
+        dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
@@ -606,12 +615,12 @@ def test_serialization_expected_output(self, config, expected_output):


 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        cls.quant_scheme = Int8DynamicActivationInt8WeightConfig()
+        cls.quant_scheme_kwargs = {}
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_accelerator
@@ -623,12 +632,12 @@ def test_serialization_expected_output_on_accelerator(self):


 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        cls.quant_scheme = Int8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_accelerator
@@ -640,15 +649,17 @@ def test_serialization_expected_output_on_accelerator(self):


 @require_torch_accelerator
 @require_torchao
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
     device = f"{torch_device}:0"

     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        # fmt: off
+        cls.quant_scheme = Int4WeightOnlyConfig(**{"group_size": 32, "version": 1})
+        cls.quant_scheme_kwargs = {}
         EXPECTED_OUTPUTS = Expectations(
             {
                 ("xpu", 3): "What are we having for dinner?\n\nJessica: (smiling)",
@@ -661,25 +672,27 @@ def setUpClass(cls):

 @require_torch_accelerator
 class TorchAoSerializationW8A8AcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
     device = f"{torch_device}:0"

     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        cls.quant_scheme = Int8DynamicActivationInt8WeightConfig()
+        cls.quant_scheme_kwargs = {}
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"


 @require_torch_accelerator
 class TorchAoSerializationW8AcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
     device = f"{torch_device}:0"

     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        cls.quant_scheme = Int8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
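The serialization subclasses above all exercise the same save/reload round trip. A condensed sketch of that flow, assuming `Int4CPULayout` is importable from `torchao.dtypes` (its import sits outside this diff) and using `safe_serialization=False` as one of the settings the tests parametrize over:

```python
import tempfile

import torch
from torchao.dtypes import Int4CPULayout  # assumed import path
from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig(
    Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1)
)

model = AutoModelForCausalLM.from_pretrained(
    model_name, dtype=torch.bfloat16, device_map="cpu", quantization_config=quant_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

with tempfile.TemporaryDirectory() as tmpdir:
    # Save the quantized checkpoint, then reload and infer on the same device.
    model.save_pretrained(tmpdir, safe_serialization=False)
    reloaded = AutoModelForCausalLM.from_pretrained(
        tmpdir, dtype=torch.bfloat16, device_map="cpu"
    )
    inputs = tokenizer("What are we having for dinner?", return_tensors="pt")
    output = reloaded.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```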