
Commit 8fefdd3

[Feature] add support kimi vl model (#5383)

liwenju0 and wenju.li authored
Co-authored-by: wenju.li <[email protected]>
1 parent 403b855 commit 8fefdd3

File tree

13 files changed: +1189 -11 lines changed

docs/supported_models/vision_language_models.md

Lines changed: 2 additions & 1 deletion

@@ -28,4 +28,5 @@ python3 -m sglang.launch_server \
 | **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
 | **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
 | **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
-| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
+| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
+| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is a multimodal model that can understand and generate text from images. |
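Following the `python3 -m sglang.launch_server` invocation shown in this doc's hunk context, a hypothetical launch of the new entry might look like the sketch below (flag names assumed from the usual sglang CLI; the `kimi-vl` template should also auto-match via the `match_moonshot_kimivl` matcher added later in this commit, so `--chat-template` is shown only for explicitness):

    python3 -m sglang.launch_server \
      --model-path moonshotai/Kimi-VL-A3B-Instruct \
      --chat-template kimi-vl \
      --trust-remote-code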

python/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ runtime_common = [
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.17",
+    "blobfile==3.0.0"
 ]

 srt = [
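The new `blobfile` pin is presumably here because Kimi-VL's tokenizer is tiktoken-based, and tiktoken's BPE loader reads vocabulary files through blobfile. A quick sanity check that the pinned version resolves (my sketch, any recent Python environment):

    python -c "import importlib.metadata as m; print(m.version('blobfile'))"   # expect 3.0.0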

python/sglang/srt/configs/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -3,11 +3,15 @@
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
+from sglang.srt.configs.kimi_vl import KimiVLConfig
+from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

 __all__ = [
     "ExaoneConfig",
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
     "MultiModalityConfig",
+    "KimiVLConfig",
+    "MoonViTConfig",
 ]

python/sglang/srt/configs/kimi_vl.py

Lines changed: 38 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig

from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig


class KimiVLConfig(PretrainedConfig):
    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        if vision_config is None:
            vision_config = MoonViTConfig()
        elif isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = DeepseekV2Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
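A minimal construction sketch (my example, not from the commit; it mirrors how HF materializes nested configs from config.json dicts):

    from sglang.srt.configs import KimiVLConfig

    # dict sub-configs are promoted to their typed classes; None falls back to defaults
    cfg = KimiVLConfig(vision_config={"patch_size": 14}, text_config=None)
    print(type(cfg.vision_config).__name__)  # MoonViTConfig
    print(cfg.media_placeholder_token_id)    # 163605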
python/sglang/srt/configs/kimi_vl_moonvit.py

Lines changed: 32 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig


class MoonViTConfig(PretrainedConfig):
    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        # Positional embedding config
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Patch merger config
        self.merge_kernel_size = merge_kernel_size
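With the defaults above, a back-of-envelope token count per image (my arithmetic, assuming non-overlapping patches and exact divisibility; the 448x448 resolution is just an example):

    patch_size, merge = 14, (2, 2)
    h = w = 448                                      # example input resolution
    patches = (h // patch_size) * (w // patch_size)  # 32 * 32 = 1024 ViT patches
    tokens = patches // (merge[0] * merge[1])        # (2, 2) merger -> 256 LLM tokens
    print(patches, tokens)                           # 1024 256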

python/sglang/srt/configs/model_config.py

Lines changed: 8 additions & 0 deletions

@@ -176,6 +176,13 @@ def __init__(
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
+            self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_text_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA

@@ -530,6 +537,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
+    "KimiVLForConditionalGeneration",
 ]
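The new branch reads the MLA geometry off the nested text config; a hypothetical probe of those fields (my sketch, assuming the checkpoint's config.json exposes a DeepSeek-style text_config and that trust_remote_code is acceptable):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True)
    t = cfg.text_config
    # the same four fields the elif branch copies onto ModelConfig
    print(t.kv_lora_rank, t.qk_rope_head_dim, t.qk_nope_head_dim, t.v_head_dim)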

python/sglang/srt/conversation.py

Lines changed: 25 additions & 0 deletions

@@ -806,6 +806,24 @@ def generate_chat_conv(
     )
 )

+# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
+register_conv_template(
+    Conversation(
+        name="kimi-vl",
+        system_message="You are a helpful assistant",
+        system_template="<|im_system|>system<|im_middle|>{system_message}",
+        roles=(
+            "<|im_user|>user<|im_middle|>",
+            "<|im_assistant|>assistant<|im_middle|>",
+        ),
+        messages=[],
+        sep="<|im_end|>",
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        stop_str="<|im_end|>",
+        image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
+    )
+)
+

 @register_conv_template_matching_function
 def match_deepseek_janus_pro(model_path: str):

@@ -888,3 +906,10 @@ def match_openbmb_minicpm(model_path: str):
         return "minicpmv"
     elif "minicpm-o" in model_path:
         return "minicpmo"
+
+
+@register_conv_template_matching_function
+def match_moonshot_kimivl(model_path: str):
+    model_path = model_path.lower()
+    if "kimi" in model_path and "vl" in model_path:
+        return "kimi-vl"

python/sglang/srt/hf_transformers_utils.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,7 @@
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
+    KimiVLConfig,
     MultiModalityConfig,
 )
 from sglang.srt.connector import create_remote_connector

@@ -46,6 +47,7 @@
     ExaoneConfig.model_type: ExaoneConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    KimiVLConfig.model_type: KimiVLConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
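The registry loop is cut off by the hunk boundary; a plausible body (my assumption, not shown in this diff) registers each entry so AutoConfig can resolve the custom model_type:

    from transformers import AutoConfig

    for name, cls in _CONFIG_REGISTRY.items():
        AutoConfig.register(name, cls)  # e.g. maps model_type "kimi_vl" -> KimiVLConfig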
python/sglang/srt/managers/multimodal_processors/kimi_vl.py

Lines changed: 73 additions & 0 deletions (new file)

import asyncio
import math
from typing import List, Union

import torch
from PIL import Image

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration


# Compatible with KimiVLForConditionalGeneration
class KimiVLImageProcessor(SGLangBaseProcessor):
    models = [KimiVLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<|media_pad|>"
        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)

        self.im_start = "<|media_start|>"
        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)

        self.im_end = "<|media_end|>"
        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)

        self.im_content = "<|media_content|>"
        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        if isinstance(image_data, str):
            image_data = [image_data]

        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        ret = self.process_mm_data(
            input_text=base_output.input_text,
            images=base_output.images,
        )
        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=ret["pixel_values"],
                    image_grid_thws=ret["image_grid_hws"],
                    modality=Modality.IMAGE,
                )
            ],
            "im_token_id": self.im_token_id,
            "im_start_id": self.im_start_id,
            "im_end_id": self.im_end_id,
            "im_content_id": self.im_content_id,
        }
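A small standalone sketch of the special-token lookups the constructor performs (hypothetical usage, not from the commit; assumes network access and trust_remote_code to fetch the HF processor):

    from transformers import AutoProcessor

    proc = AutoProcessor.from_pretrained(
        "moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True
    )
    # the same convert_tokens_to_ids calls as KimiVLImageProcessor.__init__
    for tok in ("<|media_pad|>", "<|media_start|>", "<|media_end|>", "<|media_content|>"):
        print(tok, proc.tokenizer.convert_tokens_to_ids(tok))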

python/sglang/srt/models/deepseek_v2.py

Lines changed: 4 additions & 1 deletion

@@ -752,7 +752,7 @@ def forward_absorb(
         q_nope_out = q_nope_out.transpose(0, 1)

         k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+        k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

@@ -1391,6 +1391,9 @@ def __init__(

         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
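The `.contiguous()` addition matters because slicing the last dimension of `latent_cache` returns a strided view, and fused layernorm kernels typically require contiguous input; a toy check (my sketch, arbitrary shapes):

    import torch

    latent_cache = torch.randn(3, 576)    # stand-in: 512 kv_lora_rank + 64 rope dims
    k_nope = latent_cache[..., :512]      # last-dim slice is a view, not a copy
    print(k_nope.is_contiguous())               # False: row stride still spans 576 elements
    print(k_nope.contiguous().is_contiguous())  # True after the explicit copy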
