diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py
index b8b7bd368..a74ce02fa 100644
--- a/gptqmodel/models/_const.py
+++ b/gptqmodel/models/_const.py
@@ -99,6 +99,7 @@ def get_best_device(backend: BACKEND=BACKEND.AUTO) -> torch.device:
"baichuan",
"internlm",
"internlm2",
+ "internvl_chat",
"qwen",
"xverse",
"deci",
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 5b2d5da33..d20348e2c 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -3,6 +3,7 @@
import os
import sys
+from .definitions.internvl_chat import InternVLChatGPTQ
# TODO: waiting for pytorch implementgation of aten ops for MPS
if sys.platform == "darwin":
@@ -94,6 +95,7 @@
"baichuan": BaiChuanGPTQ,
"internlm": InternLMGPTQ,
"internlm2": InternLM2GPTQ,
+ "internvl_chat": InternVLChatGPTQ,
"qwen": QwenGPTQ,
"mistral": MistralGPTQ,
"Yi": YiGPTQ,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index 90eecc33a..811e1de68 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -21,6 +21,7 @@
from .hymba import HymbaGPTQ
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
+from .internvl_chat import InternVLChatGPTQ
from .llama import LlamaGPTQ
from .longllama import LongLlamaGPTQ
from .minicpm3 import MiniCPM3GPTQ
diff --git a/gptqmodel/models/definitions/internvl_chat.py b/gptqmodel/models/definitions/internvl_chat.py
new file mode 100644
index 000000000..1819fec7e
--- /dev/null
+++ b/gptqmodel/models/definitions/internvl_chat.py
@@ -0,0 +1,153 @@
+from typing import Dict
+
+import torch
+
+from transformers import AutoTokenizer
+from ..base import BaseGPTQModel
+from ...utils.calibration import batched
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+
+from ...utils.image import fetch_image
+from ...utils.model import MODALITY
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
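+# Per-tile transform: force RGB, resize to input_size x input_size with bicubic
+# interpolation, and normalize with ImageNet mean/std.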
+def build_transform(input_size):
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ return transform
+
+
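+# Choose the tiling grid (cols, rows) whose aspect ratio is closest to the source image;
+# ties prefer the larger grid when the source image has enough pixels to fill it.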
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
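+# Tile the image into min_num..max_num squares of image_size pixels, using the grid that
+# best preserves the original aspect ratio; optionally append a full-image thumbnail.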
+def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
+
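+# Preprocess a PIL image into a (num_tiles, 3, input_size, input_size) tensor of tiles.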
+def load_image(image, input_size=448, max_num=12):
+ image = image.convert('RGB')
+ transform = build_transform(input_size=input_size)
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+ pixel_values = [transform(image) for image in images]
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
+
+
+class InternVLChatGPTQ(BaseGPTQModel):
+ IMG_START_TOKEN = '<img>'
+ IMG_END_TOKEN = '</img>'
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+
+ require_pkgs_version = ["transformers<=4.44.2", "timm>=1.0.12", "torchvision>=0.20.1"]
+
+ base_modules = ["language_model.model.tok_embeddings", "language_model.model.norm"]
+
+ layers_node = "language_model.model.layers"
+ layer_type = "InternLM2DecoderLayer"
+ layer_modules = [
+ ["attention.wqkv", "attention.wo"],
+
+ ["feed_forward.w1", "feed_forward.w3"],
+ ["feed_forward.w2"],
+ ]
+
+ modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT]
+
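+ # Render one calibration sample through the model's conversation template, then expand
+ # the '<image>' placeholder into the image token sequence (num_image_token IMG_CONTEXT
+ # tokens per tile, wrapped in IMG_START/IMG_END).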
+ def preprocess_dataset(self, sample: Dict) -> Dict:
+ template = self.model.conv_template
+ template.append_message(template.roles[0], sample["question"])
+ template.append_message(template.roles[1], sample["answer"])
+ query = template.get_prompt()
+
+ pixel_values = load_image(fetch_image(sample), max_num=12).to(torch.bfloat16)
+ num_patches = pixel_values.size(0)
+ image_tokens = self.IMG_START_TOKEN + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + self.IMG_END_TOKEN
+ query = query.replace('<image>', image_tokens, 1)
+ image_flags = torch.tensor([1] * num_patches, dtype=torch.long)
+ return {
+ "query": query,
+ "pixel_values": pixel_values,
+ "image_flags": image_flags,
+ }
+
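+ # Build calibration batches: left-pad and tokenize the text queries, then concatenate
+ # pixel values and image flags across the samples in each batch.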
+ def prepare_dataset(
+ self,
+ calibration_dataset,
+ batch_size: int = 1,
+ tokenizer=None, ):
+ if tokenizer is None:
+ tokenizer = AutoTokenizer.from_pretrained(self.model_local_path, trust_remote_code=True)
+
+ tokenizer.padding_side = 'left'
+
+ calib_data = []
+ for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
+ queries, pixel_values, image_flags = tuple(
+ [instance[key] for instance in batch] for key in ("query", "pixel_values", "image_flags"))
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+ input_ids = model_inputs['input_ids']
+ attention_mask = model_inputs['attention_mask']
+
+ pixel_values = torch.cat(pixel_values, dim=0)
+ image_flags = torch.cat(image_flags, dim=0)
+
+ calib_data.append({
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "image_flags": image_flags,
+ })
+ return calib_data
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 10e2d44b1..c8a442ff6 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -138,9 +138,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut
is_quantized = model.quantized
- # ovis cannot load processor
- is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
- need_create_processor = is_image_to_text_model and not is_ovis_model
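+ # currently only the Qwen2-VL path builds an AutoProcessor for inference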
+ is_qwen2vl_model = model.__class__.__name__ == "Qwen2VLGPTQ"
+ need_create_processor = is_image_to_text_model and is_qwen2vl_model
if not is_quantized:
model.quantize(calibration_dataset, batch_size=batch_size)
diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py
index bc0eccecb..08ec97e11 100644
--- a/tests/models/ovis/image_to_test_dataset.py
+++ b/tests/models/ovis/image_to_test_dataset.py
@@ -1,4 +1,4 @@
-from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ
+from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ, InternVLChatGPTQ
def format_ovis_dataset(image, assistant):
@@ -29,6 +29,12 @@ def format_qwen2_vl_dataset(image, assistant):
{"role": "assistant", "content": assistant},
]
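+# '<image>' is the placeholder that InternVLChatGPTQ.preprocess_dataset replaces with image tokens.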
+def format_internlm2_vl_dataset(image, assistant):
+ return {
+ "image": image,
+ "question": f"\nDescribe the image in detail.",
+ "answer": assistant,
+ }
def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]:
from datasets import load_dataset
@@ -49,4 +55,7 @@ def get_calib_dataset(model):
if isinstance(model, Qwen2VLGPTQ):
return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)
+ if isinstance(model, InternVLChatGPTQ):
+ return prepare_dataset(format_internlm2_vl_dataset, n_sample=1)
+
raise NotImplementedError(f"Unsupported MODEL: {model.__class__}")
diff --git a/tests/models/test_internvl_chat.py b/tests/models/test_internvl_chat.py
new file mode 100644
index 000000000..abcd3cdcf
--- /dev/null
+++ b/tests/models/test_internvl_chat.py
@@ -0,0 +1,20 @@
+from model_test import ModelTest
+
+
+class TestInternlm2_VL(ModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/InternVL2-8B-MPO"
+ NATIVE_ARC_CHALLENGE_ACC = 0.3217
+ NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3575
+ APPLY_CHAT_TEMPLATE = True
+ TRUST_REMOTE_CODE = True
+ BATCH_SIZE = 6
+ USE_VLLM = False
+
+
+ def test_internlm2_5(self):
+ # runs normally with transformers<=4.44.2
+ model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE,
+ torch_dtype=self.TORCH_DTYPE, use_flash_attn=False)
+
+
+