@@ -4243,6 +4243,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA

+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
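Note on the hunk above: original state-spaces Mamba checkpoints ship a config.json without the usual transformers fields (no "architectures", keys like "d_model" and "ssm_cfg"), which is presumably why the constructor reads the file directly instead of going through AutoConfig. A minimal standalone sketch of that direct load; the helper name and the sample path are illustrative, not part of the diff:

import json
from pathlib import Path

def load_raw_hparams(dir_model: Path) -> dict:
    # Read the checkpoint's config.json verbatim instead of routing it
    # through transformers.AutoConfig, so non-HF keys such as "d_model"
    # or "ssm_cfg" are preserved exactly as written.
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)

# Hypothetical usage with a local non-HF Mamba checkpoint directory:
# hparams = load_raw_hparams(Path("./mamba-130m"))
# print(hparams.get("d_model"), hparams.get("ssm_cfg"))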
@@ -4321,8 +4329,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
         self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
         self.n_group = self.hparams.get("n_groups", 1)
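To make the fallbacks in this constructor concrete, here is a rough approximation of how the lookups resolve for an original-style Mamba2 config; pick() only stands in for TextModel.find_hparam, and the config values are illustrative:

from typing import Any

def pick(hparams: dict[str, Any], keys: list[str], default=None):
    # Rough stand-in for TextModel.find_hparam: first matching key wins.
    for k in keys:
        if k in hparams:
            return hparams[k]
    return default

hparams = {"d_model": 2048, "n_groups": 1}          # illustrative values
d_model = pick(hparams, ["hidden_size", "d_model", "dim"])
d_inner = pick(hparams, ["intermediate_size", "d_inner"]) or 2 * d_model
n_group = hparams.get("n_groups", 1)
print(d_model, d_inner, n_group)                     # 2048 4096 1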
@@ -6225,12 +6239,20 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
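A self-contained sketch of the detection order above (the text/vision sub-config branches are omitted), with a few illustrative inputs showing how non-HF Mamba and Mamba2 configs resolve:

from typing import Any

def detect_arch(hparams: dict[str, Any]) -> str:
    # Same precedence as get_model_architecture, minus the sub-config overrides.
    arches = hparams.get("architectures")
    if arches:
        return arches[0]
    if "ssm_cfg" in hparams:
        # Non-HF Mamba configs carry "ssm_cfg"; "layer" names the SSM variant.
        return hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    raise ValueError("Failed to detect model architecture")

print(detect_arch({"architectures": ["LlamaForCausalLM"]}))  # LlamaForCausalLM
print(detect_arch({"ssm_cfg": {}}))                          # MambaForCausalLM
print(detect_arch({"ssm_cfg": {"layer": "Mamba2"}}))         # Mamba2ForCausalLM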