
[Doc - Guidance] Proper MoE quantization #732

@mratsim

Description

Hello team,

I've been looking into quantizing some of the recent models (GLM-4.x, MiniMax-M2.1, MiMo-V2-Flash, ...), but it seems that most frameworks do not forward calibration data to all experts unless the model has been explicitly added to the quantization framework.
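
To make the problem concrete, here is a back-of-the-envelope sketch (the expert count, top-k and token budget are illustrative numbers I picked, not taken from any particular model): with plain top-k routing during calibration, each expert only ever sees a small slice of the calibration set.

# Illustrative numbers only (assumptions, not from any specific model/config).
num_experts = 256          # routed experts per MoE layer
top_k = 8                  # experts activated per token
calib_tokens = 512 * 2048  # calibration samples * sequence length

# Under normal top-k routing, an expert sees on average top_k / num_experts
# of the calibration tokens; routing imbalance means rarely-selected experts
# can see far fewer, or none at all.
avg_tokens_per_expert = calib_tokens * top_k / num_experts
print(f"average tokens per expert: {avg_tokens_per_expert:.0f} "
      f"({100 * top_k / num_experts:.1f}% of {calib_tokens} calibration tokens)")

This is exactly what the calibration-time overrides quoted below (temporarily setting top_k to the number of experts) work around.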

This can lead to significant quality issues, as shown in the side-by-side comparison at https://avtc.github.io/aquarium-side-by-side/ (see the discussion in ModelCloud/GPTQModel#2235 (comment)).

Looking at the repo, I see all-expert calibration mentioned in two places:

  1. def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Forward all tokens to all experts for calibration
         self.gate.topk = self.n_routed_experts
         self.gate.topk_groups = self.gate.n_groups
         super().forward(x)
         # Restore the original topk and topk_groups
         self.gate.topk = self._original_topk
         self.gate.topk_groups = self._original_topk_groups
         return super().forward(x)

  2. class _QuantSparseMoe(QuantModule):
         """Module to support special handling of token dispatching during calibration.

         During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
         However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
         returns.

         If calibration is not enabled, this module behaves as a normal MoELayer.
         """

         def _setup(self):
             pass

         def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
                 # If any of the experts are in calibration mode, we will forward all tokens to all experts
                 # This is used only for calibration, we need to re-calculate the actual outputs again using
                 # the original top_k
                 original_top_k = self.top_k
                 self.top_k = self.num_experts
                 super().forward(hidden_states)
                 self.top_k = original_top_k
             return super().forward(hidden_states)

However, in the latter case, only a handful of models are supported:

try:
    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

    if Llama4TextMoe not in QuantModuleRegistry:
        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
    if Llama4TextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
            _QuantLlama4TextExperts
        )
except ImportError:
    pass

try:
    from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU, DbrxExperts, DbrxFFN

    if DbrxExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({DbrxExperts: "hf.DbrxExperts"})(_QuantDbrxExperts)
    if DbrxExpertGLU not in QuantModuleRegistry:
        QuantModuleRegistry.register({DbrxExpertGLU: "hf.DbrxExpertGLU"})(_QuantDbrxExpertGLU)
    if DbrxFFN not in QuantModuleRegistry:
        QuantModuleRegistry.register({DbrxFFN: "hf.DbrxFFN"})(_QuantDbrxFFN)
except ImportError:
    pass

try:
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    if MixtralSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
            _QuantSparseMoe
        )
except ImportError:
    pass

try:
    from transformers.models.falcon.modeling_falcon import FalconLinear

    if FalconLinear not in QuantModuleRegistry:
        QuantModuleRegistry.register({FalconLinear: "hf.FalconLinear"})(_QuantLinear)
except ImportError:
    pass

try:
    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

    if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
            _QuantSparseMoe
        )
except ImportError:
    pass

try:
    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

    if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
            _QuantSparseMoe
        )
except ImportError:
    pass

try:
    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
            _QuantSparseMoe
        )
except ImportError:
    pass
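
For a model whose MoE block follows the same standard shape (a router plus top_k / num_experts attributes and a self.experts container), I would expect the "easy" case below to boil down to one more try/except block in this list. The following is only a hypothetical sketch: the module path and class name for MiniMax-M2 are my guesses and may not match what actually ships in transformers.

# Hypothetical registration sketch; module path and class name are assumptions.
try:
    from transformers.models.minimax_m2.modeling_minimax_m2 import MiniMaxM2SparseMoeBlock

    if MiniMaxM2SparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register({MiniMaxM2SparseMoeBlock: "hf.MiniMaxM2SparseMoeBlock"})(
            _QuantSparseMoe
        )
except ImportError:
    pass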

And some models, like gpt-oss or DeepSeek, need moderate to significant customization.

Can you provide guidance, ideally reference documentation, on how to add support for optimally quantizing new MoE models in your framework?

In particular, this should cover:

  • How the plugin system works
  • the "easy" case, where registering _QuantSparseMoe is enough; I believe MiniMax-M2.1 falls into this case (see the hypothetical registration sketch after the listing above)
  • the "mid" case, like GPT-OSS:
    class _QuantGptOssExperts(_QuantFunctionalMixin):
        """Quantized wrapper for `transformers.GptOssExperts`.

        Quantizes `gate_up_proj` and `down_proj` weights via dynamic attributes inside `quantize_weight()`.
        Activations into `gate_up_proj` are quantized by `gate_up_proj_input_quantizer`. For `down_proj`
        activation quantization, we intercept `torch.Tensor.__matmul__`/`torch.bmm` and quantize inputs
        on every second call (since the first call computes `gate_up_proj` outputs and second call
        computes `down_proj` outputs).
        """

        @staticmethod
        def _get_quantized_weight(quantizer, module, weight):
            # MoE weight is accessed for each expert in one forward pass, so let's cache it
            if module._enable_weight_quantization:
                if hasattr(quantizer, "_cached_quant_val"):
                    return getattr(quantizer, "_cached_quant_val")
                quantizer._cached_quant_val = _transposed_quantize(weight, quantizer)
                return quantizer._cached_quant_val
            return weight

        def _setup_for_weight_quantization(self):
            self._register_dynamic_attribute(
                "gate_up_proj", partial(self._get_quantized_weight, self.gate_up_proj_weight_quantizer)
            )
            self._register_dynamic_attribute(
                "down_proj", partial(self._get_quantized_weight, self.down_proj_weight_quantizer)
            )

        def _setup(self):
            assert not hasattr(self, "kernel_layer_name"), (
                "ModelOpt quantization does not support patched forward for kernel_hub"
            )
            self.gate_up_proj_input_quantizer = TensorQuantizer()
            self.gate_up_proj_weight_quantizer = TensorQuantizer()
            self.down_proj_input_quantizer = TensorQuantizer()
            self.down_proj_weight_quantizer = TensorQuantizer()
            self._register_temp_attribute("_enable_weight_quantization", False)
            self._register_temp_attribute("_down_proj_mul", False)
            self._setup_for_weight_quantization()

        @property
        def functionals_to_replace(self):
            def _quantized_bmm(batch1, batch2):
                batch1 = self.down_proj_input_quantizer(batch1) if self._down_proj_mul else batch1
                self._down_proj_mul = not self._down_proj_mul  # toggle the flag
                return torch._bmm(batch1, batch2)

            def _tensor_matmul(self_t, other):
                self_t = self.down_proj_input_quantizer(self_t) if self._down_proj_mul else self_t
                self._down_proj_mul = not self._down_proj_mul
                return torch.matmul(self_t, other)

            return [
                (torch, "bmm", _quantized_bmm),
                (torch.Tensor, "__matmul__", _tensor_matmul),
            ]

        @contextmanager
        def quantize_weight(self):
            """Context in which MoE weight is quantized."""
            self._enable_weight_quantization = True
            try:
                yield
            finally:
                for module in self.modules():
                    if isinstance(module, TensorQuantizer) and hasattr(module, "_cached_quant_val"):
                        delattr(module, "_cached_quant_val")
                self._enable_weight_quantization = False

        def forward(
            self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
        ) -> torch.Tensor:
            """Forward method to add quantization."""
            hidden_states = self.gate_up_proj_input_quantizer(hidden_states)
            with self.quantize_weight():
                return super().forward(hidden_states, router_indices, routing_weights)


    try:
        from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts

        if GptOssExperts not in QuantModuleRegistry:
            QuantModuleRegistry.register({GptOssExperts: "hf.GptOssExperts"})(_QuantGptOssExperts)
    except ImportError:
        pass
  • the "hard" case, like DeepSeek-V3.2, which seems to need a whole custom pipeline: https://github.com/NVIDIA/Model-Optimizer/tree/0.40.0/examples/deepseek
  • and what determines which of these cases a given model falls into
  • Lastly, some models are not in transformers (and transformers v5 is a huge version change that rewrote all LLM model files; see "Refactor weight loading", huggingface/transformers#41580). Can you explain how to support models that need trust_remote_code, like MiniMax-M2.x ("Add support for MiniMax-M2", huggingface/transformers#42028)? A hedged sketch of one possible approach is included after this list.
  • Finally, I see that there are KL-divergence tools, but as of 0.40.0 there is no guide on how to use them on a new quant: https://github.com/NVIDIA/Model-Optimizer/tree/0.40.0/docs/source/guides (a generic comparison sketch is also included after this list).
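
Regarding the trust_remote_code question above, one workaround I can imagine (a hedged sketch, not an official ModelOpt recipe) is to register the dynamically loaded MoE class after the model has been instantiated, since it cannot be imported from transformers.models.*. The import paths for QuantModuleRegistry and _QuantSparseMoe, the repo id, and the model.model.layers[0].mlp attribute path are all assumptions and will differ per architecture; it also only helps if the remote-code block exposes the top_k / num_experts interface that _QuantSparseMoe relies on.

import torch
from transformers import AutoModelForCausalLM

# Assumed import paths for the ModelOpt classes quoted above; adjust to the
# actual module layout of the release you use.
from modelopt.torch.quantization.nn import QuantModuleRegistry
from modelopt.torch.quantization.plugins.huggingface import _QuantSparseMoe

# Example repo id; any trust_remote_code MoE checkpoint would follow the same pattern.
model = AutoModelForCausalLM.from_pretrained(
    "MiniMaxAI/MiniMax-M2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Grab the remote-code MoE class from the instantiated model
# (attribute path is an assumption and varies by architecture).
moe_cls = type(model.model.layers[0].mlp)

if moe_cls not in QuantModuleRegistry:
    QuantModuleRegistry.register({moe_cls: f"hf.{moe_cls.__name__}"})(_QuantSparseMoe)

If the attributes line up, calibration should then be able to dispatch calibration tokens to all experts the same way it does for the in-tree models listed earlier.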
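
And on the KL-divergence point: until there is a guide, here is a generic plain-PyTorch sketch (not ModelOpt's KL tooling) of what I have in mind, comparing the next-token distributions of the baseline and the quantized model on the same prompts.

import torch
import torch.nn.functional as F

@torch.no_grad()
def mean_token_kl(baseline_model, quant_model, input_ids: torch.Tensor) -> float:
    """Mean per-token KL(baseline || quantized) over the next-token distributions."""
    p_logits = baseline_model(input_ids).logits.float()  # reference distribution P
    q_logits = quant_model(input_ids).logits.float()     # quantized distribution Q
    # F.kl_div(input=log Q, target=log P, log_target=True) computes P * (log P - log Q)
    per_token_kl = F.kl_div(
        F.log_softmax(q_logits, dim=-1),
        F.log_softmax(p_logits, dim=-1),
        log_target=True,
        reduction="none",
    ).sum(dim=-1)  # sum over the vocabulary -> KL per token position
    return per_token_kl.mean().item()

Averaged over a held-out prompt set, this gives a single scalar that tracks how far the quantized model's distributions drift from the baseline, which is the kind of regression the side-by-side comparison above shows qualitatively.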
