Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3003,10 +3003,16 @@ def _quantize_block(
if not self.not_use_best_mse:
last_loss = best_loss
best_iter = last_best_iter
dump_info = (
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
)
if self.iters > 0:
dump_info = (
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
)
else:
dump_info = (
f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
"layers in the block"
)

if self.low_gpu_mem_usage:
clear_memory(device_list=self.device_list) # clear cached memory during training
Expand Down
36 changes: 21 additions & 15 deletions auto_round/compressors/mllm/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from auto_round.compressors.base import BaseCompressor
from auto_round.compressors.mllm.dataset import get_mllm_dataloader
from auto_round.compressors.mllm.template import Template, get_template
from auto_round.compressors.mllm.template import TEMPLATES, Template, get_template
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.special_model_handler import (
Expand Down Expand Up @@ -200,23 +200,27 @@ def __init__(
self.image_processor = image_processor
from transformers import PreTrainedModel

if model.config.model_type == "llava" and isinstance(model, PreTrainedModel):
# if model is not an instance of transformers PreTrainedModel, it may not have a config attribute
if isinstance(model, PreTrainedModel) and model.config.model_type == "llava":
template = "default"
if hasattr(model, "name_or_path") and any([name in model.name_or_path for name in MISTRAL_3_2_MODELS]):
template = "mistral3_2"
if iters > 0:
self.template = template if template is not None else model.config.model_type
if not isinstance(dataset, torch.utils.data.DataLoader):
self.template = get_template(
self.template,
model=model,
tokenizer=tokenizer,
processor=processor,
image_processor=image_processor,
use_rtn=iters == 0,
quiet=not self.quant_nontext_module,
)
dataset = self.template.default_dataset if dataset is None else dataset
if template is None and model.config.model_type not in TEMPLATES:
self.template = None
else:
self.template = template if template is not None else model.config.model_type
if not isinstance(dataset, torch.utils.data.DataLoader):
self.template = get_template(
self.template,
model=model,
tokenizer=tokenizer,
processor=processor,
image_processor=image_processor,
use_rtn=iters == 0,
quiet=not self.quant_nontext_module,
)
dataset = self.template.default_dataset if dataset is None else dataset
else:
self.template = None

Expand All @@ -233,7 +237,9 @@ def __init__(
" switching to liuhaotian/llava_conv_58k"
)
dataset = "liuhaotian/llava_conv_58k"
elif not _only_text_test(model, tokenizer, self.device, self.template.model_type):
elif self.template is not None and not _only_text_test(
model, tokenizer, self.device, self.template.model_type
):
logger.warning(
f"{model.config.model_type} does not support for {dataset},"
" will use liuhaotian/llava_conv_58k with default config as an alternative."
Expand Down
2 changes: 1 addition & 1 deletion auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _get_moe_converter(config):


def _handle_special_model(model):
if model.config.model_type == "deepseek_vl_v2":
if hasattr(model, "config") and model.config.model_type == "deepseek_vl_v2":
from functools import partial

model.forward = partial(_deepseek_vl2_forward, model)
Expand Down