
Commit 13c4c81

mengniwang95 authored and chensuyue committed
Fix code for llmc llama4 quantization (#1161)
Signed-off-by: Mengni Wang <[email protected]> (cherry picked from commit 3c88b3b)
1 parent e99da5c commit 13c4c81

File tree: 3 files changed, +32 −20 lines

auto_round/compressors/base.py

Lines changed: 10 additions & 4 deletions
@@ -3004,10 +3004,16 @@ def _quantize_block(
         if not self.not_use_best_mse:
             last_loss = best_loss
             best_iter = last_best_iter
-        dump_info = (
-            f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
-            f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
-        )
+        if self.iters > 0:
+            dump_info = (
+                f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
+                f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
+            )
+        else:
+            dump_info = (
+                f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
+                "layers in the block"
+            )
 
         if self.low_gpu_mem_usage:
             clear_memory(device_list=self.device_list)  # clear cached memory during training
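
When iters is 0 (pure RTN rounding, no tuning loop) the loss statistics are never populated, so the summary line now drops them. A minimal standalone sketch of that behavior; the helper name and argument list below are illustrative only and not part of the commit:

# Hypothetical helper mirroring the patched logic (names are illustrative only).
def build_dump_info(quantized_layer_names, unquantized_layer_names, iters,
                    init_loss=None, best_iter=None, last_loss=None):
    total = len(quantized_layer_names) + len(unquantized_layer_names)
    if iters > 0:
        # Tuning ran, so the loss trajectory is meaningful and gets reported.
        return (
            f"quantized {len(quantized_layer_names)}/{total} layers in the block, "
            f"loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
        )
    # iters == 0 (RTN-style path): no loop ran, so skip the loss fields.
    return f"quantized {len(quantized_layer_names)}/{total} layers in the block"

print(build_dump_info(["q_proj", "k_proj"], ["gate"], iters=0))
# -> quantized 2/3 layers in the block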

auto_round/compressors/mllm/compressor.py

Lines changed: 21 additions & 15 deletions
@@ -24,7 +24,7 @@
 
 from auto_round.compressors.base import BaseCompressor
 from auto_round.compressors.mllm.dataset import get_mllm_dataloader
-from auto_round.compressors.mllm.template import Template, get_template
+from auto_round.compressors.mllm.template import TEMPLATES, Template, get_template
 from auto_round.logger import logger
 from auto_round.schemes import QuantizationScheme
 from auto_round.special_model_handler import (
@@ -200,23 +200,27 @@ def __init__(
         self.image_processor = image_processor
         from transformers import PreTrainedModel
 
-        if model.config.model_type == "llava" and isinstance(model, PreTrainedModel):
+        # if model is not the object of transformers PreTrainedModel, there maybe no config attribute
+        if isinstance(model, PreTrainedModel) and model.config.model_type == "llava":
             template = "default"
         if hasattr(model, "name_or_path") and any([name in model.name_or_path for name in MISTRAL_3_2_MODELS]):
             template = "mistral3_2"
         if iters > 0:
-            self.template = template if template is not None else model.config.model_type
-            if not isinstance(dataset, torch.utils.data.DataLoader):
-                self.template = get_template(
-                    self.template,
-                    model=model,
-                    tokenizer=tokenizer,
-                    processor=processor,
-                    image_processor=image_processor,
-                    use_rtn=iters == 0,
-                    quiet=not self.quant_nontext_module,
-                )
-            dataset = self.template.default_dataset if dataset is None else dataset
+            if template is None and model.config.model_type not in TEMPLATES:
+                self.template = None
+            else:
+                self.template = template if template is not None else model.config.model_type
+                if not isinstance(dataset, torch.utils.data.DataLoader):
+                    self.template = get_template(
+                        self.template,
+                        model=model,
+                        tokenizer=tokenizer,
+                        processor=processor,
+                        image_processor=image_processor,
+                        use_rtn=iters == 0,
+                        quiet=not self.quant_nontext_module,
+                    )
+                dataset = self.template.default_dataset if dataset is None else dataset
         else:
             self.template = None
 
@@ -233,7 +237,9 @@ def __init__(
                     " switching to liuhaotian/llava_conv_58k"
                 )
                 dataset = "liuhaotian/llava_conv_58k"
-            elif not _only_text_test(model, tokenizer, self.device, self.template.model_type):
+            elif self.template is not None and not _only_text_test(
+                model, tokenizer, self.device, self.template.model_type
+            ):
                 logger.warning(
                     f"{model.config.model_type} does not support for {dataset},"
                     " will use liuhaotian/llava_conv_58k with default config as an alternative."

auto_round/special_model_handler.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def _get_moe_converter(config):
 
 
 def _handle_special_model(model):
-    if model.config.model_type == "deepseek_vl_v2":
+    if hasattr(model, "config") and model.config.model_type == "deepseek_vl_v2":
         from functools import partial
 
         model.forward = partial(_deepseek_vl2_forward, model)
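
_handle_special_model is now safe to call on objects that lack a config attribute (for example, models produced by external export tooling). A tiny illustration of the guard, using made-up class and function names:

# Made-up example showing why the hasattr guard matters.
class BareModel:
    pass  # no .config attribute at all

def is_deepseek_vl_v2(model):
    return hasattr(model, "config") and model.config.model_type == "deepseek_vl_v2"

print(is_deepseek_vl_v2(BareModel()))  # -> False, instead of raising AttributeError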
