Skip to content

Model: Granite MoE shared #13269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5661,6 +5661,39 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GraniteMoeSharedForCausalLM")
class GraniteMoeSharedModel(GraniteMoeModel):
"""Conversion for IBM's GraniteMoeSharedForCausalLM"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE_SHARED

def set_gguf_parameters(self):
"""GraniteMoeShared uses GraniteMoe parameters plus the following:
- shared_intermediate_size
"""
super().set_gguf_parameters()
if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""In modeling_granitemoeshared, the implementation of parallel experts
is used. This essentially merges w1 and w3 into a single tensor with 2x
the hidden size that is then split during forward. To keep compatibility
with existing shared expert support, we pull them apart here.
"""

if name.endswith("shared_mlp.input_linear.weight"):
ffn_dim = self.hparams["shared_intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
]

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
Expand Down
292 changes: 156 additions & 136 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,74 +255,75 @@ class GGUFType:


class MODEL_ARCH(IntEnum):
CLIP_VISION = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
STARCODER = auto()
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
QWEN2 = auto()
QWEN2MOE = auto()
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
COHERE2 = auto()
DBRX = auto()
OLMO = auto()
OLMO2 = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
GLM4 = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()
CLIP_VISION = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
STARCODER = auto()
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
QWEN2 = auto()
QWEN2MOE = auto()
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
COHERE2 = auto()
DBRX = auto()
OLMO = auto()
OLMO2 = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
GLM4 = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_MOE_SHARED = auto()
Copy link
Collaborator

@ngxson ngxson May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think a new arch is needed here. Some other archs also have exp/shared exp and they are being controlled via the n_ff_shexp

CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -512,74 +513,75 @@ class MODEL_TENSOR(IntEnum):


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.QWEN2MOE: "qwen2moe",
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO2: "olmo2",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.QWEN2MOE: "qwen2moe",
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO2: "olmo2",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_MOE_SHARED: "granitemoeshared",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -1894,6 +1896,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.GRANITE_MOE_SHARED: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
],
MODEL_ARCH.CHAMELEON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
1 change: 1 addition & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoeshared
),

MODEL_TENSOR.ATTN_Q_NORM: (
Expand Down
Loading
Loading