Skip to content

mtmd : support InternVL 2.5 and 3 #13422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,11 @@ def load_hparams(dir_model: Path):
logger.warning(f"Failed to load model config from {dir_model}: {e}")
logger.warning("Trying to load config.json instead")
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
config = json.load(f)
if "llm_config" in config:
# rename for InternVL
config["text_config"] = config["llm_config"]
return config

@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
Expand Down Expand Up @@ -2606,6 +2610,11 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Yield the converted tensors for one checkpoint tensor, skipping visual ones.

    Args:
        data_torch: tensor data, forwarded unchanged to the parent implementation.
        name: tensor name as found in the checkpoint.
        bid: block index (or None), forwarded unchanged to the parent implementation.

    Yields:
        Whatever the parent class's ``modify_tensors`` yields for the
        normalized name. Nothing is yielded for tensors whose normalized name
        starts with ``mlp`` or ``vision_model``.
    """
    if self.hf_arch == "Qwen2Model":
        name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
    # Strip the "language_model." prefix (used by InternVL checkpoints).
    # str.replace is a no-op when the substring is absent, so no guard is needed.
    name = name.replace("language_model.", "")
    if name.startswith(("mlp", "vision_model")):
        # Skip visual tensors. This function is a generator, so a bare
        # `return` ends the iteration; a returned list would never reach the caller.
        return
    yield from super().modify_tensors(data_torch, name, bid)


Expand Down Expand Up @@ -2709,6 +2718,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [] # skip other tensors


@ModelBase.register("InternVisionModel")
class InternVisionModel(VisionModel):
    """Converter for the InternVL vision tower (`vision_model.*`) and `mlp*` tensors."""

    def set_gguf_parameters(self):
        """Write the InternVL vision metadata on top of the base-class parameters.

        Raises:
            ValueError: if ``hidden_act`` is neither ``"silu"`` nor ``"gelu"``.
            AssertionError: if the global config has no ``downsample_ratio``.
        """
        super().set_gguf_parameters()
        vision_hparams = self.hparams
        writer = self.gguf_writer

        writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
        writer.add_vision_attention_layernorm_eps(vision_hparams["layer_norm_eps"])

        # hidden_act: only silu and gelu are handled
        activation = vision_hparams["hidden_act"]
        if activation == "silu":
            writer.add_vision_use_silu(True)
        elif activation == "gelu":
            writer.add_vision_use_gelu(True)
        else:
            raise ValueError(f"Unsupported hidden_act: {activation}")

        # downsample_ratio lives in the global config; the projector scale
        # factor is its truncated reciprocal
        ratio = self.global_config.get("downsample_ratio")
        assert ratio is not None
        scale_factor = int(1.0 / ratio)
        writer.add_vision_projector_scale_factor(scale_factor)

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        """Pick a forced storage type from the mapped tensor name.

        Returns F16 when ``new_name`` contains ``.patch_embd.``, F32 when it
        contains ``.position_embd.`` (checked in that order), else ``False``.
        """
        del bid, name, n_dims  # unused
        forced_types = (
            (".patch_embd.", gguf.GGMLQuantizationType.F16),
            (".position_embd.", gguf.GGMLQuantizationType.F32),
        )
        for marker, quant_type in forced_types:
            if marker in new_name:
                return quant_type
        return False

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its output name(s), splitting fused QKV.

        Only tensors whose name starts with ``vision_model`` or ``mlp`` are
        processed; every other tensor yields an empty list.
        """
        del bid  # unused
        if not name.startswith(("vision_model", "mlp")):
            return []  # skip other tensors

        # correct name
        if name.startswith("vision_model"):
            name = "vision_tower." + name
        # ".ls" (the ls1/ls2 layer-scale entries) and position-embedding names
        # get a ".weight" suffix when they lack one
        wants_suffix = ".ls" in name or "position_embedding" in name
        if wants_suffix and not name.endswith(".weight"):
            name += ".weight"

        if ".qkv." not in name:
            return [(self.map_tensor_name(name), data_torch)]

        # split the fused QKV tensor into three equal slices along dim 0
        # (dim 0 is the fused dimension for both the weight and the bias)
        fused_rows = data_torch.shape[0]
        assert fused_rows % 3 == 0
        rows = fused_rows // 3
        pieces = (
            ("self_attn.q_proj", data_torch[:rows]),
            ("self_attn.k_proj", data_torch[rows: rows * 2]),
            ("self_attn.v_proj", data_torch[rows * 2:]),
        )
        return [
            (self.map_tensor_name(name.replace("attn.qkv", target)), piece)
            for target, piece in pieces
        ]


@ModelBase.register("WavTokenizerDec")
class WavTokenizerDecModel(TextModel):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
Expand Down Expand Up @@ -3360,6 +3425,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
head_dim = n_embd // num_heads
num_groups = num_heads // q_per_kv

name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model"):
# skip visual tensors
return []

if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
qkv = data_torch

Expand Down Expand Up @@ -3433,6 +3503,10 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model"):
# skip visual tensors
return []
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
Expand Down
8 changes: 8 additions & 0 deletions docs/multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,12 @@ NOTE: some models may require large context window, for example: `-c 8192`

# Mistral Small 3.1 24B (IQ2_M quantization)
(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF

# InternVL 2.5 and 3
(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF
(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF
(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF
(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF
(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
```
7 changes: 7 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ class MODEL_TENSOR(IntEnum):
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
V_LAYER_SCALE_1 = auto()
V_LAYER_SCALE_2 = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
V_MM_INP_NORM = auto()
Expand Down Expand Up @@ -748,6 +750,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
Expand Down Expand Up @@ -786,6 +790,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_LAYER_SCALE_1,
MODEL_TENSOR.V_LAYER_SCALE_2,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ,
Expand Down Expand Up @@ -2167,6 +2173,7 @@ class VisionProjectorType:
PIXTRAL = "pixtral"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
INTERNVL = "internvl"


# Items here are (block size, type size)
Expand Down
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,7 @@ class TensorNameMap:

MODEL_TENSOR.V_MMPROJ_MLP: (
"model.mm_projector.mlp.mlp.{bid}",
"mlp1.{bid}", # InternVL
),

MODEL_TENSOR.V_MMPROJ_PEG: (
Expand Down Expand Up @@ -955,6 +956,7 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
Expand All @@ -963,6 +965,7 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
Expand All @@ -971,6 +974,7 @@ class TensorNameMap:

MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
Expand Down Expand Up @@ -1000,6 +1004,14 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
),

MODEL_TENSOR.V_LAYER_SCALE_1: (
"vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
),

MODEL_TENSOR.V_LAYER_SCALE_2: (
"vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
),

MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral
Expand Down
1 change: 1 addition & 0 deletions tools/mtmd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: conversion of `InternVL3-*-hf` models is not supported, only the non-HF versions are; the `InternLM2Model` **text** model is not supported either)

For older models, please refer to the relevant guide for instructions on how to obtain or create them:

Expand Down
11 changes: 6 additions & 5 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"

#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
Expand All @@ -60,8 +57,10 @@
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
#define TN_LLAVA_PROJ "mm.%d.%s"
Expand Down Expand Up @@ -105,6 +104,7 @@ enum projector_type {
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_UNKNOWN,
};

Expand All @@ -119,6 +119,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
Expand Down
Loading
Loading