diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index df3f8a55d5320..ff82a85a9d7cd 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1089,6 +1089,8 @@ def __init__(self, *args, **kwargs):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
         # get n_embd of the text model
+        if "text_config" not in self.hparams:
+            self.hparams["text_config"] = {}
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
@@ -2583,6 +2585,82 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        if "embed_dim" in self.hparams:  # qwen2vl
+            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2:  # weight
+                    c3, _ = data_torch.shape
+                else:  # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                return [
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]),
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return []  # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
diff --git a/examples/llava/README.md b/examples/llava/README.md
index 3b62627ce829f..b97b9e8c54367 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -35,6 +35,16 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
 
+# Qwen 2 VL
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
+
+# Qwen 2.5 VL
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
+
 # Mistral Small 3.1 24B (IQ2_M quantization)
 llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
@@ -60,7 +70,17 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta
 
 ## How to obtain `mmproj`
 
-Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
+Multimodal projector (`mmproj`) files are specific to each model architecture.
+
+For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
 
 - [LLaVA](../../docs/multimodal/llava.md)
 - [MobileVLM](../../docs/multimodal/MobileVLM.md)
@@ -70,10 +90,3 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
 - [Google Gemma 3](../../docs/multimodal/gemma3.md)
-
-For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
-- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
-- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
-- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
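As a rough illustration of the conversion workflow the README changes above describe, a minimal sketch is shown below; the local checkpoint path is a placeholder and the exact `convert_hf_to_gguf.py` invocation may differ:

```sh
# sketch only: convert a locally downloaded, transformers-format Qwen2.5-VL checkpoint
# and emit the multimodal projector via the --mmproj flag mentioned in the README
python convert_hf_to_gguf.py ./Qwen2.5-VL-3B-Instruct --mmproj

# alternatively, skip local conversion and use one of the prebuilt GGUFs listed above
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
```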
diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py
deleted file mode 100644
index 7951a6fa8951e..0000000000000
--- a/examples/llava/qwen2_vl_surgery.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import argparse
-from typing import Dict, List, Optional
-
-import torch
-import numpy as np
-from gguf import *
-from transformers import (
-    AutoProcessor,
-    Qwen2VLConfig,
-    Qwen2VLProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLConfig,  # type: ignore[reportAttributeAccessIssue]
-    Qwen2_5_VLForConditionalGeneration,  # type: ignore[reportAttributeAccessIssue]
-)
-
-
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
-    if fullatt_block_indexes is None:
-        return 0
-    n_wa = fullatt_block_indexes[0]
-    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
-        if b - a - 1 != n_wa:
-            raise ValueError(
-                f"window/full attention layer should have fix pattern of "
-                f"for each full-attention layer followed by {n_wa} window-attention layers"
-            )
-    return n_wa + 1
-
-
-class VL2:
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
-        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[to_gguf_name] {og} --> {name}")
-        return name
-
-    @classmethod
-    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
-        vision_model = qwen2vl.visual
-        tensor_map = {}
-        for name, ten in vision_model.state_dict().items():
-            ten = ten.numpy()
-            if 'qkv' in name:
-                if ten.ndim == 2:  # weight
-                    c3, _ = ten.shape
-                else:  # bias
-                    c3 = ten.shape[0]
-                assert c3 % 3 == 0
-                c = c3 // 3
-                wq = ten[:c]
-                wk = ten[c: c * 2]
-                wv = ten[c * 2:]
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
-            elif 'merger' in name:
-                if name.endswith("ln_q.weight"):
-                    tensor_map['v.post_ln.weight'] = ten
-                elif name.endswith("ln_q.bias"):
-                    tensor_map['v.post_ln.bias'] = ten
-                else:
-                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
-                    tensor_map[cls.to_gguf_name(name)] = ten
-            elif 'patch_embed.proj.weight' in name:
-                # NOTE: split Conv3D into Conv2Ds
-                c1, c2, kt, kh, kw = ten.shape
-                assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
-                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
-                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
-            else:
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
-
-        for new_name, ten in tensor_map.items():
-            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
-                tensor_map[new_name] = ten.astype(np.float32)
-            else:
-                tensor_map[new_name] = ten.astype(dtype)
-        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
-        return tensor_map
-
-
-class VL25(VL2):
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
-        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[vl25][to_gguf_name] {og} --> {name}")
-        return name
-
-
-def main(args):
-    if args.data_type == 'fp32':
-        dtype = torch.float32
-        np_dtype = np.float32
-        ftype = 0
-    elif args.data_type == 'fp16':
-        dtype = torch.float16
-        np_dtype = np.float16
-        ftype = 1
-    else:
-        raise ValueError()
-
-    local_model = False
-    model_path = ""
-    model_name = args.model_name
-    print("model_name: ", model_name)
-    if args.model_type == "qwen2vl":
-        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-    else:
-        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-
-    if os.path.isdir(model_name):
-        local_model = True
-        if model_name.endswith(os.sep):
-            model_name = model_name[:-1]
-        model_path = model_name
-        model_name = os.path.basename(model_name)
-    fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
-
-    fout = GGUFWriter(path=fname_out, arch="clip")
-    fout.add_description("image encoder for Qwen2VL")
-
-    fout.add_file_type(ftype)
-    fout.add_bool("clip.has_text_encoder", False)
-    fout.add_bool("clip.has_vision_encoder", True)
-    fout.add_bool("clip.has_qwen2vl_merger", True)
-
-    print(cfg.vision_config)
-    if 'silu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", True)
-        fout.add_bool("clip.use_gelu", False)
-    elif 'gelu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", False)
-        fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
-    else:
-        raise ValueError()
-
-    if args.model_type == "qwen2.5vl":
-        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
-        fout.add_string("clip.projector_type", "qwen2.5vl_merger")
-    else:
-        fout.add_string("clip.projector_type", "qwen2vl_merger")
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
-
-    if args.model_type == "qwen2.5vl":
-        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
-    else:
-        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
-    for name, data in tensor_map.items():
-        fout.add_tensor(name, data)
-
-    fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
-    fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divable by (14*2)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # not sure what this does, put 0 here as a placeholder
-    fout.add_name(model_name)
-    """
-    HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
-    it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
-    """
-
-    if local_model:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)
-    else:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
-    fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean)  # type: ignore[reportAttributeAccessIssue]
-    fout.add_array("clip.vision.image_std", processor.image_processor.image_std)  # type: ignore[reportAttributeAccessIssue]
-
-    fout.write_header_to_file()
-    fout.write_kv_data_to_file()
-    fout.write_tensors_to_file()
-    fout.close()
-    print("save model as: ", fname_out)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
-    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
-    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
-    args = parser.parse_args()
-    main(args)
diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh
index 4af370064086f..22c2374971331 100755
--- a/examples/llava/tests.sh
+++ b/examples/llava/tests.sh
@@ -36,12 +36,6 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
-add_test_big() {
-    if [ "$RUN_BIG_TESTS" = true ]; then
-        add_test "$@"
-    fi
-}
-
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@@ -58,8 +52,16 @@ add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
-add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
-add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+if [ "$RUN_BIG_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a2540bd93fd91..74e46c3ee0f95 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -234,6 +234,7 @@ class ClipVision:
         SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
+        N_WA_PATTERN = "clip.vision.n_wa_pattern"  # used by qwen2.5vl
 
         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"
@@ -2162,6 +2163,8 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    QWEN2VL = "qwen2vl_merger"
+    QWEN25VL = "qwen2.5vl_merger"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index a30c49e32b351..ff50d3de31287 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -984,6 +984,9 @@ def add_vision_use_silu(self, value: bool) -> None:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2f6326104ffa7..2b089f84a841a 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -896,6 +896,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "visual.merger.mlp.{bid}",  # qwen2vl
         ),
 
         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -919,6 +920,7 @@ class TensorNameMap:
            "vpm.embeddings.patch_embedding",
            "model.vision_model.embeddings.patch_embedding",  # SmolVLM
            "vision_tower.patch_conv",  # pixtral
+            "visual.patch_embed.proj",  # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -932,6 +934,7 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.self_attn.q_proj",
            "model.vision_model.encoder.layers.{bid}.self_attn.q_proj",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.attention.q_proj",  # pixtral
+            "visual.blocks.{bid}.attn.q",  # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K: (
@@ -939,6 +942,7 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.self_attn.k_proj",
            "model.vision_model.encoder.layers.{bid}.self_attn.k_proj",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.attention.k_proj",  # pixtral
+            "visual.blocks.{bid}.attn.k",  # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_V: (
@@ -946,6 +950,7 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.self_attn.v_proj",
            "model.vision_model.encoder.layers.{bid}.self_attn.v_proj",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.attention.v_proj",  # pixtral
+            "visual.blocks.{bid}.attn.v",  # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -953,6 +958,7 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.layer_norm1",
            "model.vision_model.encoder.layers.{bid}.layer_norm1",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.attention_norm",  # pixtral
+            "visual.blocks.{bid}.norm1",  # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT: (
@@ -960,6 +966,7 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.self_attn.out_proj",
            "model.vision_model.encoder.layers.{bid}.self_attn.out_proj",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.attention.o_proj",  # pixtral
+            "visual.blocks.{bid}.attn.proj",  # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
@@ -967,17 +974,24 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.layer_norm2",
            "model.vision_model.encoder.layers.{bid}.layer_norm2",  # SmolVLM
            "vision_tower.transformer.layers.{bid}.ffn_norm",  # pixtral
+            "visual.blocks.{bid}.norm2",  # qwen2vl
         ),
 
+        # some namings are messed up because the original llava code swapped fc1 and fc2
+        # we have no better way to fix it, just be careful
+        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2",  # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj",  # pixtral
+            "visual.blocks.{bid}.mlp.fc2",  # qwen2vl
+            "visual.blocks.{bid}.mlp.up_proj",  # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
             "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj",  # pixtral
+            "visual.blocks.{bid}.mlp.gate_proj",  # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
@@ -985,6 +999,8 @@ class TensorNameMap:
            "vpm.encoder.layers.{bid}.mlp.fc2",
            "model.vision_model.encoder.layers.{bid}.mlp.fc1",  # SmolVLM, gemma3 (note: name is swapped)
            "vision_tower.transformer.layers.{bid}.feed_forward.down_proj",  # pixtral
+            "visual.blocks.{bid}.mlp.fc1",  # qwen2vl
+            "visual.blocks.{bid}.mlp.down_proj",  # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
@@ -995,6 +1011,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_POST_NORM: (
             "vision_tower.vision_model.post_layernorm",
             "model.vision_model.post_layernorm",  # SmolVLM
+            "visual.merger.ln_q",  # qwen2vl
         ),
 
         MODEL_TENSOR.V_MM_INP_PROJ: (