
Commit a6ae3af

Support XiaomiMiMo inference with mtp (#6059)
1 parent 0b07c4a commit a6ae3af

File tree

6 files changed: +344 −6 lines changed

docs/backend/speculative_decoding.ipynb

Lines changed: 54 additions & 0 deletions
@@ -283,6 +283,60 @@
     "terminate_process(server_process)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multi-Token Prediction\n",
+    "\n",
+    "We support [MTP (Multi-Token Prediction)](https://arxiv.org/pdf/2404.19737) in SGLang through speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the [DeepSeek docs](../references/deepseek.md#multi-token-prediction))."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "server_process, port = launch_server_cmd(\n",
+    "    \"\"\"\n",
+    "python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \\\n",
+    "    --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \\\n",
+    "    --mem-fraction-static 0.5\n",
+    "\"\"\"\n",
+    ")\n",
+    "\n",
+    "wait_for_server(f\"http://localhost:{port}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import requests\n",
+    "\n",
+    "url = f\"http://localhost:{port}/v1/chat/completions\"\n",
+    "\n",
+    "data = {\n",
+    "    \"model\": \"XiaomiMiMo/MiMo-7B-RL\",\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
+    "}\n",
+    "\n",
+    "response = requests.post(url, json=data)\n",
+    "print_highlight(response.json())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "terminate_process(server_process)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

python/sglang/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
model_override_args=self.model_override_args,
7474
**kwargs,
7575
)
76+
7677
self.hf_text_config = get_hf_text_config(self.hf_config)
7778
self.attention_chunk_size = getattr(
7879
self.hf_text_config, "attention_chunk_size", None
@@ -97,6 +98,8 @@ def __init__(
9798
):
9899
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
99100

101+
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
102+
self.hf_config.architectures[0] = "MiMoMTP"
100103
# Check model type
101104
self.is_generation = is_generation_model(
102105
self.hf_config.architectures, is_embedding
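The model_config.py change is just a string rewrite on the draft model's config, so that the model loader resolves the draft worker to the MiMoMTP entry class (defined below) instead of the full MiMoForCausalLM. A minimal sketch of the behavior, using a SimpleNamespace as a stand-in for the real hf_config:

from types import SimpleNamespace

hf_config = SimpleNamespace(architectures=["MiMoForCausalLM"])
is_draft_model = True  # set when constructing the speculative draft worker

# Same override as in ModelConfig.__init__ above.
if is_draft_model and hf_config.architectures[0] == "MiMoForCausalLM":
    hf_config.architectures[0] = "MiMoMTP"

assert hf_config.architectures[0] == "MiMoMTP"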

python/sglang/srt/model_executor/model_runner.py

Lines changed: 9 additions & 6 deletions
@@ -782,12 +782,15 @@ def profile_max_num_token(self, total_gpu_memory: int):
             distributed=get_world_group().world_size > 1,
             cpu_group=get_world_group().cpu_group,
         )
-        if self.use_mla_backend:
-            num_layers = (
-                self.model_config.num_hidden_layers
-                if not self.is_draft_worker
-                else self.model_config.hf_config.num_nextn_predict_layers
+        if self.is_draft_worker:
+            num_layers = getattr(
+                self.model_config.hf_config,
+                "num_nextn_predict_layers",
+                self.num_effective_layers,
             )
+        else:
+            num_layers = self.num_effective_layers
+        if self.use_mla_backend:
             # FIXME: pipeline parallelism is not compatible with mla backend
             assert self.pp_size == 1
             cell_size = (
@@ -799,7 +802,7 @@ def profile_max_num_token(self, total_gpu_memory: int):
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.num_effective_layers
+                * num_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
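For the non-MLA branch, the per-token KV-cache cell size in profile_max_num_token is num_kv_heads * head_dim * num_layers * 2 (K and V) * bytes per element, and the change above makes a draft worker count only its num_nextn_predict_layers (falling back to the effective layer count if the config lacks that field) instead of all layers. A back-of-the-envelope check with illustrative values, not taken from any real config:

# Hypothetical shapes: 8 KV heads, head dim 128, 1 MTP draft layer, fp16 (2 bytes).
num_kv_heads, head_dim, num_layers, elem_size = 8, 128, 1, 2

cell_size = num_kv_heads * head_dim * num_layers * 2 * elem_size  # K and V planes
print(cell_size)  # 4096 bytes of KV cache per token for the draft worker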

python/sglang/srt/models/mimo_mtp.py

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py

from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.mimo import MiMoForCausalLM
from sglang.srt.models.qwen2 import (
    Qwen2Attention,
    Qwen2DecoderLayer,
    Qwen2MLP,
    Qwen2Model,
)
from sglang.srt.utils import add_prefix


class MiMoMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_proj = nn.Linear(
            config.hidden_size * 2, config.hidden_size, bias=False
        )
        self.mtp_block = Qwen2DecoderLayer(
            config=config, quant_config=quant_config, prefix=prefix
        )
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        # Mask inputs at position 0, as they are not needed by MTP.
        hidden_states[positions == 0] = 0

        hidden_states = self.input_proj(
            torch.cat(
                (
                    self.hidden_layernorm(forward_batch.spec_info.hidden_states),
                    self.token_layernorm(hidden_states),
                ),
                dim=-1,
            )
        )

        hidden_states, residual = self.mtp_block(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            residual=None,
        )
        hidden_states = residual + hidden_states
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class MiMoMTP(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        self.model = MiMoMultiTokenPredictorLayer(
            config,
            prefix,
            quant_config,
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            name = self.map_model_name_to_mtp_param_name(name)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mtp_block" not in name:
                    break
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if "mtp_block" not in name and (
                    "embed_tokens" not in name
                    and "lm_head" not in name
                    and "token_layernorm" not in name
                    and "hidden_layernorm" not in name
                    and "input_proj" not in name
                    and "final_layernorm" not in name
                ):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def map_model_name_to_mtp_param_name(self, name: str) -> str:
        import re

        name_without_prefix = [
            "token_layernorm",
            "hidden_layernorm",
            "input_proj",
            "final_layernorm",
        ]
        pattern = r"model.mtp_layers.(\d+)."
        group = re.match(pattern, name)
        if group is not None:
            for sub_name in name_without_prefix:
                if sub_name in name:
                    name = name.replace(group.group(), "model.")
                    return name
            name = name.replace(group.group(), "model.mtp_block.")
        return name

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


EntryClass = MiMoMTP
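The core of MiMoMultiTokenPredictorLayer.forward is the fusion step: the target model's last hidden states and the draft input's token embeddings are RMS-normalized separately, concatenated along the feature dimension, and projected from 2 * hidden_size back to hidden_size before running the single Qwen2 decoder block. A standalone sketch of just that step, assuming torch.nn.RMSNorm (PyTorch 2.4+) in place of SGLang's fused RMSNorm, with purely illustrative shapes:

import torch
from torch import nn

hidden_size, num_tokens = 16, 4
hidden_layernorm = nn.RMSNorm(hidden_size)
token_layernorm = nn.RMSNorm(hidden_size)
input_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

prev_hidden = torch.randn(num_tokens, hidden_size)  # stands in for forward_batch.spec_info.hidden_states
tok_embeds = torch.randn(num_tokens, hidden_size)   # stands in for embed_tokens(input_ids)

# Normalize each stream, concatenate, and project back to hidden_size.
fused = input_proj(
    torch.cat((hidden_layernorm(prev_hidden), token_layernorm(tok_embeds)), dim=-1)
)
print(fused.shape)  # torch.Size([4, 16])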

test/srt/models/test_mtp_models.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestMiMoMTP(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "XiaomiMiMo/MiMo-7B-RL"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-num-steps",
                "1",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "2",
                "--mem-fraction-static",
                "0.5",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.7)


if __name__ == "__main__":
    unittest.main()
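The test launches a real server with the same speculative-decoding flags as the notebook and checks few-shot GSM8K accuracy above 0.7. Thanks to the __main__ guard it can be run directly with python3 test/srt/models/test_mtp_models.py, assuming a GPU with enough free memory for the 7B model.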
