
convert : improve model arch handling #13122


Merged
merged 7 commits into from Apr 30, 2025
Changes from all commits
98 changes: 57 additions & 41 deletions convert_hf_to_gguf.py
@@ -16,6 +16,7 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
from transformers import AutoConfig

import math
import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
block_count: int
tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
gguf_writer: gguf.GGUFWriter
model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
# subclasses should define this!
model_arch: gguf.MODEL_ARCH

# subclasses should initialize this!
block_count: int
tensor_map: gguf.TensorNameMap

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
self.metadata_override = metadata_override
self.model_name = model_name
@@ -417,15 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]

@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
architectures = hparams.get("architectures")
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
if architectures is not None:
# preserve "architectures" from root level config
hparams["architectures"] = architectures
return hparams
try:
return AutoConfig.from_pretrained(dir_model).to_dict()
except Exception as e:
logger.warning(f"Failed to load model config from {dir_model}: {e}")
logger.warning("Trying to load config.json instead")
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)

@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -454,6 +453,23 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


class TextModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if "text_config" in self.hparams:
# move the text_config to the root level
self.hparams = {**self.hparams, **self.hparams["text_config"]}

self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
# would require using decorated functions instead of simply defining the property
if "model_arch" not in cls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")

def set_vocab(self):
self._set_vocab_gpt2()

@@ -1070,9 +1086,9 @@ def __init__(self, *args, **kwargs):
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

# small hack to correct the number of layers
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
# get n_embd of the text model
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert self.n_embd_text > 0, "n_embd not found in hparams"

if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ def __init__(self, *args, **kwargs):
self.global_config = self.hparams
self.hparams = self.hparams["vision_config"]

self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)

# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
@@ -1098,7 +1117,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))

# preprocessor config
@@ -1719,23 +1738,12 @@ def prepare_tensors(self):
"LlamaForCausalLM",
"MistralForCausalLM",
"MixtralForCausalLM",
"Idefics3ForConditionalGeneration",
"SmolVLMForConditionalGeneration",
"VLlama3ForCausalLM",
"LlavaForConditionalGeneration")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# fix for SmolVLM2, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
# fix for Pixtral, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
and self.hparams.get("model_type") == "mistral":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "pixtral":
# fix missing config.json values
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json
else:
@@ -1913,7 +1917,6 @@ def set_gguf_parameters(self):
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
# default values below are taken from HF tranformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)

@@ -1944,13 +1947,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
class SmolVLMModel(VisionModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# fix for SmolVLM2, missing some keys in config.json
# default values are taken from transformers code
if self.hparams["model_type"] == "smolvlm_vision":
# fix for SmolVLM2, missing some keys in config.json
# default values are taken from transformers code
self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)

def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -3505,6 +3507,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

@ModelBase.register("NomicBertModel")
class NomicBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
@@ -5849,6 +5853,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
return n


def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
return arch
Comment on lines +5856 to +5866
ngxson (Collaborator, Author) commented on May 1, 2025

@compilade In this case, maybe we can patch get_model_architecture to introduce an exception for Mamba? (so that you no longer need to manually add the "architectures" to config.json)

Maybe something like

if "ssm_cfg" in hparams and hparams.get("ssm_cfg").get("layer") == "Mamba":
    return "MambaForCausalLM"

And since "architectures" is not present in config.json, the AutoConfig will throw an error, which will trigger old method

compilade (Collaborator) commented on May 2, 2025

@ngxson I've tried this, and for some reason it doesn't work; it seems like AutoConfig doesn't fail when there is no architectures field. It still uses wrong default values for what Mamba2 needs.

The error I'm getting suggests hparams.n_groups defaults to 8, and hparams.intermediate_size defaults to 8192, which are the values for Mamba-Codestral-7B-v0.1, not the Mamba2-370m model I'm actually converting.

It's as if AutoConfig can still detect it's Mamba2, but doesn't use the correct values from the config.

Here's what I've tried
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 532cc879d..4a3713e4d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5968,12 +5968,21 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # TODO: more general extra mappings
+        ssm_mapping = {"Mamba": "MambaForCausalLM", "Mamba2": "Mamba2ForCausalLM"}
+        arch = ssm_mapping.get(hparams["ssm_cfg"].get("layer", "Mamba"), None)
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch

And then converting https://huggingface.co/state-spaces/mamba2-370m.

(when using #9126)
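
For reference, one quick way to see which values AutoConfig injects is to diff its view of the config against the raw config.json. This is a diagnostic sketch only (not part of this PR); the local directory name is an assumption standing in for a clone of state-spaces/mamba2-370m.

import json
from pathlib import Path
from transformers import AutoConfig

dir_model = Path("mamba2-370m")  # assumed local clone of state-spaces/mamba2-370m
auto_cfg = AutoConfig.from_pretrained(dir_model).to_dict()
raw_cfg = json.loads((dir_model / "config.json").read_text(encoding="utf-8"))

# keys present only in the AutoConfig view are defaults injected by transformers
for key in sorted(set(auto_cfg) - set(raw_cfg)):
    print(f"injected default: {key} = {auto_cfg[key]}")
# keys whose values differ between the raw config and the AutoConfig view
for key in sorted(set(auto_cfg) & set(raw_cfg)):
    if auto_cfg[key] != raw_cfg[key]:
        print(f"changed: {key}: {raw_cfg[key]} -> {auto_cfg[key]}")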

ngxson (Collaborator, Author) commented

> It's as if AutoConfig can still detect it's Mamba2, but doesn't use the correct values from the config.

Hmm this is quite magic tbh, I have no idea how AutoConfig works under the hood.

Another solution, though: we can add an argument to load_hparams, say use_raw_config: bool. Then in __init__, you can rewrite it as self.hparams = load_hparams(..., use_raw_config=True).
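
A minimal sketch of that idea, assuming the flag name use_raw_config and keeping the AutoConfig fallback from this PR (not part of the merged change):

@staticmethod
def load_hparams(dir_model: Path, use_raw_config: bool = False) -> dict[str, Any]:
    # hypothetical flag: when set, skip AutoConfig and return config.json untouched
    if not use_raw_config:
        try:
            return AutoConfig.from_pretrained(dir_model).to_dict()
        except Exception as e:
            logger.warning(f"Failed to load model config from {dir_model}: {e}")
            logger.warning("Trying to load config.json instead")
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)

# in a subclass __init__ that needs the raw config.json:
#     self.hparams = ModelBase.load_hparams(self.dir_model, use_raw_config=True)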



def main() -> None:
args = parse_args()

@@ -5901,16 +5918,15 @@ def main() -> None:

logger.info(f"Loading model: {dir_model.name}")

hparams = ModelBase.load_hparams(dir_model)

if args.mmproj:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_architecture = hparams["architectures"][0]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
model_architecture = get_model_architecture(dir_model, model_type)
logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
except NotImplementedError: