Commit 2338c7f

Added 8-bit weight compression for OVModel
1 parent 681b946
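
For context, 8-bit weight compression stores each weight tensor as INT8 plus a floating-point scale and dequantizes on the fly at inference; activations stay in floating point, so no calibration dataset is needed. A schematic of symmetric per-channel weight quantization (an illustration of the idea only, not NNCF's actual algorithm):

import numpy as np

def compress_weight(w: np.ndarray):
    # Symmetric per-output-channel INT8 quantization: one scale per row.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 127.0, 1e-12)
    w_int8 = np.clip(np.round(w / scale), -128, 127).astype(np.int8)
    return w_int8, scale

def decompress_weight(w_int8: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Dequantize back to float for the matmul at inference time.
    return w_int8.astype(np.float32) * scale

w = np.random.randn(16, 32).astype(np.float32)
w_int8, scale = compress_weight(w)
# Round-trip error is bounded by half a quantization step per channel.
assert np.abs(decompress_weight(w_int8, scale) - w).max() <= scale.max()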

2 files changed: +39, -8 lines

optimum/intel/openvino/quantization.py

Lines changed: 14 additions & 4 deletions
@@ -163,10 +163,6 @@ def quantize(
             raise ValueError("`save_directory` needs to be specified")
 
         if weights_only:
-            if isinstance(self.model, OVBaseModel):
-                raise ValueError(
-                    "`weights_only` currently not supported for `OVModels`, only available for torch.nn.Module."
-                )
             if calibration_dataset is not None:
                 logger.warning(
                     "`calibration_dataset` was provided but will not be used as `weights_only` is set to `True`."
@@ -185,6 +181,7 @@ def quantize(
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, OVBaseModel):
@@ -194,6 +191,7 @@ def quantize(
                 batch_size,
                 data_collator,
                 remove_unused_columns,
+                weights_only,
                 **kwargs,
             )
         elif isinstance(self.model, torch.nn.Module):
@@ -217,11 +215,17 @@ def _quantize_ovbasemodel(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
 
+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
@@ -247,11 +251,17 @@ def _quantize_ovcausallm(
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
+        weights_only: bool = False,
         **kwargs,
     ):
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
 
+        if weights_only:
+            self.model.model = nncf.compress_weights(self.model.model)
+            self.model.save_pretrained(save_directory)
+            return
+
         calibration_dataloader = self._get_calibration_dataloader(
             calibration_dataset=calibration_dataset,
             batch_size=batch_size,
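
Both private methods now short-circuit to nncf.compress_weights when weights_only is set, instead of raising a ValueError as before this commit. A minimal usage sketch of the new path, mirroring the test added below (the model name and output directory are illustrative):

from optimum.intel.openvino import OVModelForCausalLM, OVQuantizer

# Export a Transformers checkpoint to an OpenVINO model in memory.
model = OVModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True)

# With weights_only=True no calibration dataset is required: the quantizer
# compresses the weights of the underlying OpenVINO graph and saves the result.
quantizer = OVQuantizer.from_pretrained(model, task=model.export_feature)
quantizer.quantize(save_directory="tiny-gpt2-int8", weights_only=True)

# The compressed model loads back like any other OVModel.
model_int8 = OVModelForCausalLM.from_pretrained("tiny-gpt2-int8")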

tests/openvino/test_quantization.py

Lines changed: 25 additions & 4 deletions
@@ -146,12 +146,12 @@ def preprocess_function(examples, tokenizer):
 class OVWeightCompressionTest(unittest.TestCase):
     # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
-        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 39),
-        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 5),
+        (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 39, 35),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 5, 23),
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
-    def test_automodel_weight_compression(self, model_cls, model_name, expected_int8):
+    def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
 
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -166,7 +166,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8
 
             # TODO: uncomment once move to a newer version of NNCF which has some fixes
             _, num_int8 = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_int8)
+            self.assertEqual(expected_pt_int8, num_int8)
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
@@ -177,6 +177,27 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_int8
             loaded_config = OVConfig.from_pretrained(tmp_dir)
             self.assertEqual(expected_config.to_dict()["compression"], loaded_config.to_dict()["compression"])
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS)
+    def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            transformers_model = model_cls.from_pretrained(model_name, export=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer.quantize(save_directory=tmp_dir, weights_only=True)
+            model = model_cls.from_pretrained(tmp_dir)
+
+            _, num_int8 = get_num_quantized_nodes(model)
+            self.assertEqual(expected_ov_int8, num_int8)
+
+            tokens = tokenizer("This is a sample input", return_tensors="pt")
+            outputs = model(**tokens)
+            self.assertTrue("logits" in outputs)
+
 
 class OVQuantizerQATest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),)
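
get_num_quantized_nodes is imported from the test utilities and is not part of this diff. A plausible sketch of what it measures, assuming the OpenVINO Python graph API (get_ops, get_output_element_type); the exact counting logic is an assumption, not the repository's actual helper:

def get_num_quantized_nodes(ov_model):
    # Walk the OpenVINO graph held by the OVModel wrapper: FakeQuantize nodes
    # indicate full quantization, 8-bit output types indicate compressed weights.
    num_fake_quantize = 0
    num_int8 = 0
    for op in ov_model.model.get_ops():
        if op.get_type_name() == "FakeQuantize":
            num_fake_quantize += 1
        for i in range(op.get_output_size()):
            if op.get_output_element_type(i).get_type_name() in ("i8", "u8"):
                num_int8 += 1
    return num_fake_quantize, num_int8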
