Merged: tests/quantization/torchao_integration/test_torchao.py (27 additions, 18 deletions)

--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -48,6 +48,8 @@
 from torchao.quantization import (
     Float8Tensor,
     Float8WeightOnlyConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     MappingType,
@@ -118,17 +120,18 @@ def test_repr(self):
         """
         Check that there is no error in the repr
         """
-        quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
+        config = Int4WeightOnlyConfig(group_size=8, layout=TensorCoreTiledLayout())
+        quantization_config = TorchAoConfig(config, modules_to_not_convert=["conv"])
         repr(quantization_config)

     def test_json_serializable(self):
         """
         Check that the config dict can be JSON serialized.
         """
-        quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
+        config = Int4WeightOnlyConfig(group_size=32, layout=TensorCoreTiledLayout())
+        quantization_config = TorchAoConfig(config)
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
+        self.assertTrue("inner_k_tiles" in d["quant_type"]["default"]["_data"]["layout"]["_data"])
         quantization_config.to_json_string(use_diff=False)


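Note (editorial): the hunks above and below migrate from string scheme identifiers to torchao config objects. A minimal sketch of the new pattern, mirroring the updated `test_json_serializable` assertion; imports are assumed to match the ones added at the top of this file:

```python
from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization import Int4WeightOnlyConfig

from transformers import TorchAoConfig

# Build the torchao config object first, then wrap it in TorchAoConfig;
# this replaces the old string form TorchAoConfig("int4_weight_only", group_size=32, ...).
config = Int4WeightOnlyConfig(group_size=32, layout=TensorCoreTiledLayout())
quantization_config = TorchAoConfig(config)

# The serialized dict now nests the torchao config under "quant_type";
# layout kwargs such as inner_k_tiles sit under its "_data" payload.
d = quantization_config.to_dict()
assert "inner_k_tiles" in d["quant_type"]["default"]["_data"]["layout"]["_data"]
```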
@@ -159,7 +162,8 @@ def test_int4wo_quant(self):
         """
         Simple LLM model testing int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
@@ -181,7 +185,8 @@ def test_int4wo_quant_bfloat16_conversion(self):
         """
         Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization
         """
-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
@@ -203,7 +208,8 @@ def test_int8_dynamic_activation_int8_weight_quant(self):
         """
         Simple LLM model testing int8_dynamic_activation_int8_weight
         """
-        quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+        config = Int8DynamicActivationInt8WeightConfig()
+        quant_config = TorchAoConfig(config)

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -426,7 +432,8 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }

-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -461,7 +468,8 @@ def test_int4wo_quant_multi_accelerator(self):
         set ZE_AFFINITY_MASK=0,1 if you have more than 2 Intel XPUs
         """

-        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
+        config = Int4WeightOnlyConfig(**self.quant_scheme_kwargs)
+        quant_config = TorchAoConfig(config)
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             dtype=torch.bfloat16,
@@ -511,12 +519,13 @@ class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
         {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
+    quant_scheme = Int4WeightOnlyConfig(**quant_scheme_kwargs)

     device = "cpu"

     # called only once for all test in this class
@@ -526,8 +535,8 @@ def setUpClass(cls):
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"

     def setUp(self):
-        self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
-        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
+        self.quant_config = TorchAoConfig(self.quant_scheme)
+        dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             dtype=dtype,
@@ -550,7 +559,7 @@ def check_serialization_expected_output(self, device, expected_output, safe_seri
         """
         Test if we can serialize and load/infer the model again on the same device
         """
-        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
+        dtype = torch.bfloat16 if isinstance(self.quant_scheme, Int4WeightOnlyConfig) else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
@@ -606,7 +615,7 @@ def test_serialization_expected_output(self, config, expected_output):


 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
+    quant_scheme, quant_scheme_kwargs = Int8DynamicActivationInt8WeightConfig(), {}

     # called only once for all test in this class
     @classmethod
@@ -623,7 +632,7 @@ def test_serialization_expected_output_on_accelerator(self):


 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
+    quant_scheme, quant_scheme_kwargs = Int8WeightOnlyConfig(), {}

     # called only once for all test in this class
     @classmethod
@@ -641,7 +650,7 @@ def test_serialization_expected_output_on_accelerator(self):

 @require_torch_accelerator
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
+    quant_scheme, quant_scheme_kwargs = Int4WeightOnlyConfig(**{"group_size": 32, "version": 1}), {}
     device = f"{torch_device}:0"

     # called only once for all test in this class
@@ -661,7 +670,7 @@ def setUpClass(cls):

 @require_torch_accelerator
 class TorchAoSerializationW8A8AcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
+    quant_scheme, quant_scheme_kwargs = Int8DynamicActivationInt8WeightConfig(), {}
     device = f"{torch_device}:0"

     # called only once for all test in this class
@@ -673,7 +682,7 @@ def setUpClass(cls):

 @require_torch_accelerator
 class TorchAoSerializationW8AcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
+    quant_scheme, quant_scheme_kwargs = Int8WeightOnlyConfig(), {}
     device = f"{torch_device}:0"

     # called only once for all test in this class
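Note (editorial): taken end to end, the migration these hunks implement looks roughly like the sketch below. This is not the canonical API surface, just a summary under the assumption that the installed transformers/torchao versions match what the tests target; the model name is copied from the test class, the other settings are illustrative:

```python
import tempfile

import torch
from torchao.quantization import Int4WeightOnlyConfig

from transformers import AutoModelForCausalLM, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# New style: pass a torchao config object instead of a scheme string.
quant_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=32))

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,  # int4 weight-only quantizes bfloat16 weights
    device_map="cpu",
    quantization_config=quant_config,
)

# Round-trip serialization, as the TorchAoSerializationTest classes check;
# torchao tensor subclasses generally require safe_serialization=False.
with tempfile.TemporaryDirectory() as tmpdirname:
    quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
    reloaded = AutoModelForCausalLM.from_pretrained(
        tmpdirname, dtype=torch.bfloat16, device_map="cpu"
    )
```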