Merged
optimum/intel/openvino/quantization.py (14 additions & 4 deletions)

@@ -167,10 +167,6 @@ def quantize(
             raise ValueError("`save_directory` needs to be specified")

         if weights_only:
-            if isinstance(self.model, OVBaseModel):
-                raise ValueError(
-                    "`weights_only` currently not supported for `OVModels`, only available for torch.nn.Module."
-                )
             if calibration_dataset is not None:
                 logger.warning(
                     "`calibration_dataset` was provided but will not be used as `weights_only` is set to `True`."
@@ -189,6 +185,7 @@ def quantize(
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, OVBaseModel):
@@ -198,6 +195,7 @@ def quantize(
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, torch.nn.Module):
@@ -221,11 +219,17 @@ def _quantize_ovbasemodel(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)

+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
@@ -251,11 +255,17 @@ def _quantize_ovcausallm(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)

+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
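Taken together, these hunks let `weights_only=True` short-circuit the calibration pipeline for OpenVINO models: the graph weights are compressed with `nncf.compress_weights` and the model is saved immediately. A minimal usage sketch, not part of the PR itself; the import paths and the tiny checkpoint are borrowed from the test suite below:

    from optimum.intel import OVModelForCausalLM, OVQuantizer

    # Export the tiny test checkpoint to OpenVINO IR.
    model = OVModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True)
    quantizer = OVQuantizer.from_pretrained(model, task=OVModelForCausalLM.export_feature)

    # No calibration_dataset is needed: weights_only=True now compresses the
    # OpenVINO graph weights to int8, saves the model, and returns early.
    quantizer.quantize(save_directory="tiny-gpt2-int8-weights", weights_only=True)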
tests/openvino/test_quantization.py (25 additions & 4 deletions)

@@ -146,12 +146,12 @@ def preprocess_function(examples, tokenizer):
 class OVWeightCompressionTest(unittest.TestCase):
     # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
-        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45),
+        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22),
[Collaborator review comment on lines +149 to +150]
What is the difference in the quantization applied on the two models (depending on whether this is a pytorch or an openvino model)?

     )

     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
-    def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):
+    def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature

         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -166,7 +166,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):

             # TODO: uncomment once move to a newer version of NNCF which has some fixes
             _, num_int8 = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_int8)
+            self.assertEqual(expected_pt_int8, num_int8)

             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
@@ -176,6 +176,27 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):
             loaded_config = OVConfig.from_pretrained(tmp_dir)
             self.assertIsNotNone(loaded_config)

+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
+    def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            transformers_model = model_cls.from_pretrained(model_name, export=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer.quantize(save_directory=tmp_dir, weights_only=True)
+            model = model_cls.from_pretrained(tmp_dir)
+
+            _, num_int8 = get_num_quantized_nodes(model)
+            self.assertEqual(expected_ov_int8, num_int8)
+
+            tokens = tokenizer("This is a sample input", return_tensors="pt")
+            outputs = model(**tokens)
+            self.assertTrue("logits" in outputs)
+

 class OVQuantizerQATest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),)
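To exercise only the new OVModel test, standard pytest selection should work (the file path is taken from the diff above):

    python -m pytest tests/openvino/test_quantization.py -k test_ovmodel_weight_compression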