
from auto_round.compressors.base import BaseCompressor
from auto_round.compressors.mllm.dataset import get_mllm_dataloader
-from auto_round.compressors.mllm.template import Template, get_template
+from auto_round.compressors.mllm.template import TEMPLATES, Template, get_template
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.special_model_handler import (
@@ -200,23 +200,27 @@ def __init__(
        self.image_processor = image_processor
        from transformers import PreTrainedModel

-        if model.config.model_type == "llava" and isinstance(model, PreTrainedModel):
+        # if model is not an instance of transformers PreTrainedModel, it may not have a config attribute
+        if isinstance(model, PreTrainedModel) and model.config.model_type == "llava":
            template = "default"
        if hasattr(model, "name_or_path") and any([name in model.name_or_path for name in MISTRAL_3_2_MODELS]):
            template = "mistral3_2"
        if iters > 0:
-            self.template = template if template is not None else model.config.model_type
-            if not isinstance(dataset, torch.utils.data.DataLoader):
-                self.template = get_template(
-                    self.template,
-                    model=model,
-                    tokenizer=tokenizer,
-                    processor=processor,
-                    image_processor=image_processor,
-                    use_rtn=iters == 0,
-                    quiet=not self.quant_nontext_module,
-                )
-            dataset = self.template.default_dataset if dataset is None else dataset
+            if template is None and model.config.model_type not in TEMPLATES:
+                self.template = None
+            else:
+                self.template = template if template is not None else model.config.model_type
+                if not isinstance(dataset, torch.utils.data.DataLoader):
+                    self.template = get_template(
+                        self.template,
+                        model=model,
+                        tokenizer=tokenizer,
+                        processor=processor,
+                        image_processor=image_processor,
+                        use_rtn=iters == 0,
+                        quiet=not self.quant_nontext_module,
+                    )
+                dataset = self.template.default_dataset if dataset is None else dataset
        else:
            self.template = None

@@ -233,7 +237,9 @@ def __init__(
233237 " switching to liuhaotian/llava_conv_58k"
234238 )
235239 dataset = "liuhaotian/llava_conv_58k"
236- elif not _only_text_test (model , tokenizer , self .device , self .template .model_type ):
240+ elif self .template is not None and not _only_text_test (
241+ model , tokenizer , self .device , self .template .model_type
242+ ):
237243 logger .warning (
238244 f"{ model .config .model_type } does not support for { dataset } ,"
239245 " will use liuhaotian/llava_conv_58k with default config as an alternative."