Skip to content

Commit 5c5d38c

Browse files
committed
add check_neural_compressor_min_version for 4 bit behavior
Signed-off-by: Xin <[email protected]>
1 parent 5debcdf commit 5c5d38c

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

examples/text-generation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
check_habana_frameworks_version,
4141
check_optimum_habana_min_version,
4242
get_habana_frameworks_version,
43+
check_neural_compressor_min_version,
4344
set_seed,
4445
)
4546

@@ -269,9 +270,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
269270
original_model=org_model,
270271
**model_kwargs,
271272
)
272-
# TODO: This will be removed in v1.19 Synapse release
273-
# the loaded model should have the same dtype as original_model
274-
model = model.to(model_kwargs["torch_dtype"])
273+
if not check_neural_compressor_min_version("3.2"):
274+
model = model.to(model_kwargs["torch_dtype"])
275275
else:
276276
if args.assistant_model is not None:
277277
assistant_model = AutoModelForCausalLM.from_pretrained(

optimum/habana/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,14 @@ def check_habana_frameworks_version(req_version):
384384
)
385385

386386

387+
def check_neural_compressor_min_version(req_version):
    """
    Tell whether the installed `neural_compressor` is at least `req_version`.

    Args:
        req_version: the minimum acceptable version, in any form accepted by
            `version.Version` (for example the string "3.2").

    Returns:
        `True` when `neural_compressor.__version__` is greater than or equal to
        `req_version`, `False` otherwise.
    """
    # The package is imported at call time, not when this module is loaded.
    import neural_compressor

    installed = version.Version(neural_compressor.__version__)
    required = version.Version(req_version)
    return installed >= required
393+
394+
387395
def get_device_name():
388396
"""
389397
Returns the name of the current device: Gaudi or Gaudi2.

0 commit comments

Comments
 (0)