|
| 1 | +# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py |
| 2 | + |
| 3 | +from functools import partial |
| 4 | +from typing import Any, Dict, Iterable, Optional, Tuple |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch import nn |
| 8 | +from transformers import PretrainedConfig |
| 9 | + |
| 10 | +from sglang.srt.distributed import ( |
| 11 | + get_tensor_model_parallel_rank, |
| 12 | + get_tensor_model_parallel_world_size, |
| 13 | + split_tensor_along_last_dim, |
| 14 | + tensor_model_parallel_all_gather, |
| 15 | +) |
| 16 | +from sglang.srt.layers.layernorm import RMSNorm |
| 17 | +from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear |
| 18 | +from sglang.srt.layers.logits_processor import LogitsProcessor |
| 19 | +from sglang.srt.layers.pooler import Pooler, PoolingType |
| 20 | +from sglang.srt.layers.quantization.base_config import QuantizationConfig |
| 21 | +from sglang.srt.layers.radix_attention import RadixAttention |
| 22 | +from sglang.srt.layers.rotary_embedding import get_rope |
| 23 | +from sglang.srt.layers.vocab_parallel_embedding import ( |
| 24 | + ParallelLMHead, |
| 25 | + VocabParallelEmbedding, |
| 26 | +) |
| 27 | +from sglang.srt.model_executor.forward_batch_info import ForwardBatch |
| 28 | +from sglang.srt.model_loader.weight_utils import default_weight_loader |
| 29 | +from sglang.srt.models.mimo import MiMoForCausalLM |
| 30 | +from sglang.srt.models.qwen2 import ( |
| 31 | + Qwen2Attention, |
| 32 | + Qwen2DecoderLayer, |
| 33 | + Qwen2MLP, |
| 34 | + Qwen2Model, |
| 35 | +) |
| 36 | +from sglang.srt.utils import add_prefix |
| 37 | + |
| 38 | + |
| 39 | +class MiMoMultiTokenPredictorLayer(nn.Module): |
| 40 | + |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + config: PretrainedConfig, |
| 44 | + prefix: str, |
| 45 | + quant_config: Optional[QuantizationConfig] = None, |
| 46 | + ) -> None: |
| 47 | + super().__init__() |
| 48 | + |
| 49 | + self.embed_tokens = VocabParallelEmbedding( |
| 50 | + config.vocab_size, |
| 51 | + config.hidden_size, |
| 52 | + ) |
| 53 | + self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 54 | + self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 55 | + self.input_proj = nn.Linear( |
| 56 | + config.hidden_size * 2, config.hidden_size, bias=False |
| 57 | + ) |
| 58 | + self.mtp_block = Qwen2DecoderLayer( |
| 59 | + config=config, quant_config=quant_config, prefix=prefix |
| 60 | + ) |
| 61 | + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 62 | + |
| 63 | + def forward( |
| 64 | + self, |
| 65 | + input_ids: torch.Tensor, |
| 66 | + positions: torch.Tensor, |
| 67 | + forward_batch: ForwardBatch, |
| 68 | + input_embeds: torch.Tensor = None, |
| 69 | + ) -> torch.Tensor: |
| 70 | + |
| 71 | + if input_embeds is None: |
| 72 | + hidden_states = self.embed_tokens(input_ids) |
| 73 | + else: |
| 74 | + hidden_states = input_embeds |
| 75 | + # masking inputs at position 0, as not needed by MTP |
| 76 | + hidden_states[positions == 0] = 0 |
| 77 | + |
| 78 | + hidden_states = self.input_proj( |
| 79 | + torch.cat( |
| 80 | + ( |
| 81 | + self.hidden_layernorm(forward_batch.spec_info.hidden_states), |
| 82 | + self.token_layernorm(hidden_states), |
| 83 | + ), |
| 84 | + dim=-1, |
| 85 | + ) |
| 86 | + ) |
| 87 | + |
| 88 | + hidden_states, residual = self.mtp_block( |
| 89 | + positions=positions, |
| 90 | + hidden_states=hidden_states, |
| 91 | + forward_batch=forward_batch, |
| 92 | + residual=None, |
| 93 | + ) |
| 94 | + hidden_states = residual + hidden_states |
| 95 | + hidden_states = self.final_layernorm(hidden_states) |
| 96 | + return hidden_states |
| 97 | + |
| 98 | + |
| 99 | +class MiMoMTP(nn.Module): |
| 100 | + def __init__( |
| 101 | + self, |
| 102 | + config: PretrainedConfig, |
| 103 | + quant_config: Optional[QuantizationConfig] = None, |
| 104 | + prefix: str = "", |
| 105 | + ) -> None: |
| 106 | + nn.Module.__init__(self) |
| 107 | + self.config = config |
| 108 | + self.tp_size = get_tensor_model_parallel_world_size() |
| 109 | + self.quant_config = quant_config |
| 110 | + |
| 111 | + self.model = MiMoMultiTokenPredictorLayer( |
| 112 | + config, |
| 113 | + prefix, |
| 114 | + quant_config, |
| 115 | + ) |
| 116 | + self.lm_head = ParallelLMHead( |
| 117 | + config.vocab_size, |
| 118 | + config.hidden_size, |
| 119 | + quant_config=quant_config, |
| 120 | + ) |
| 121 | + self.logits_processor = LogitsProcessor(config) |
| 122 | + |
| 123 | + @torch.no_grad() |
| 124 | + def forward( |
| 125 | + self, |
| 126 | + input_ids: torch.Tensor, |
| 127 | + positions: torch.Tensor, |
| 128 | + forward_batch: ForwardBatch, |
| 129 | + ) -> torch.Tensor: |
| 130 | + hidden_states = self.model(input_ids, positions, forward_batch) |
| 131 | + return self.logits_processor( |
| 132 | + input_ids, hidden_states, self.lm_head, forward_batch |
| 133 | + ) |
| 134 | + |
| 135 | + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
| 136 | + stacked_params_mapping = [ |
| 137 | + # (param_name, shard_name, shard_id) |
| 138 | + ("qkv_proj", "q_proj", "q"), |
| 139 | + ("qkv_proj", "k_proj", "k"), |
| 140 | + ("qkv_proj", "v_proj", "v"), |
| 141 | + ("gate_up_proj", "gate_proj", 0), |
| 142 | + ("gate_up_proj", "up_proj", 1), |
| 143 | + ] |
| 144 | + |
| 145 | + params_dict = dict(self.named_parameters()) |
| 146 | + for name, loaded_weight in weights: |
| 147 | + if "rotary_emb.inv_freq" in name or "projector" in name: |
| 148 | + continue |
| 149 | + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: |
| 150 | + # Models trained using ColossalAI may include these tensors in |
| 151 | + # the checkpoint. Skip them. |
| 152 | + continue |
| 153 | + if self.config.tie_word_embeddings and "lm_head.weight" in name: |
| 154 | + continue |
| 155 | + if name.startswith("model.vision_tower") and name not in params_dict: |
| 156 | + continue |
| 157 | + name = self.map_model_name_to_mtp_param_name(name) |
| 158 | + |
| 159 | + for param_name, weight_name, shard_id in stacked_params_mapping: |
| 160 | + if weight_name not in name: |
| 161 | + continue |
| 162 | + if "mtp_block" not in name: |
| 163 | + break |
| 164 | + name = name.replace(weight_name, param_name) |
| 165 | + # Skip loading extra bias for GPTQ models. |
| 166 | + if name.endswith(".bias") and name not in params_dict: |
| 167 | + continue |
| 168 | + param = params_dict[name] |
| 169 | + weight_loader = param.weight_loader |
| 170 | + weight_loader(param, loaded_weight, shard_id) |
| 171 | + break |
| 172 | + else: |
| 173 | + # Skip loading extra bias for GPTQ models. |
| 174 | + if name.endswith(".bias") and name not in params_dict: |
| 175 | + continue |
| 176 | + if "mtp_block" not in name and ( |
| 177 | + "embed_tokens" not in name |
| 178 | + and "lm_head" not in name |
| 179 | + and "token_layernorm" not in name |
| 180 | + and "hidden_layernorm" not in name |
| 181 | + and "input_proj" not in name |
| 182 | + and "final_layernorm" not in name |
| 183 | + ): |
| 184 | + continue |
| 185 | + param = params_dict[name] |
| 186 | + weight_loader = getattr(param, "weight_loader", default_weight_loader) |
| 187 | + weight_loader(param, loaded_weight) |
| 188 | + |
| 189 | + def map_model_name_to_mtp_param_name(self, name: str) -> str: |
| 190 | + import re |
| 191 | + |
| 192 | + name_without_prefix = [ |
| 193 | + "token_layernorm", |
| 194 | + "hidden_layernorm", |
| 195 | + "input_proj", |
| 196 | + "final_layernorm", |
| 197 | + ] |
| 198 | + pattern = r"model.mtp_layers.(\d+)." |
| 199 | + group = re.match(pattern, name) |
| 200 | + if group is not None: |
| 201 | + for sub_name in name_without_prefix: |
| 202 | + if sub_name in name: |
| 203 | + name = name.replace(group.group(), "model.") |
| 204 | + return name |
| 205 | + name = name.replace(group.group(), "model.mtp_block.") |
| 206 | + return name |
| 207 | + |
| 208 | + def get_embed_and_head(self): |
| 209 | + return self.model.embed_tokens.weight, self.lm_head.weight |
| 210 | + |
| 211 | + def set_embed_and_head(self, embed, head): |
| 212 | + del self.model.embed_tokens.weight |
| 213 | + del self.lm_head.weight |
| 214 | + self.model.embed_tokens.weight = embed |
| 215 | + self.lm_head.weight = head |
| 216 | + torch.cuda.empty_cache() |
| 217 | + torch.cuda.synchronize() |
| 218 | + |
| 219 | + |
| 220 | +EntryClass = MiMoMTP |
0 commit comments