
Commit 8c4dc64

Commit message: update
Parent: bb8d4ba

File tree: 2 files changed, +4 -23 lines


python/sglang/srt/model_loader/loader.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1125,6 +1125,8 @@ def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
                 num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
+            # Make torch infer_schema happy (compatible with vLLM)
+            offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
         if load_8bit:
```
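For illustration, a minimal sketch of the change above with made-up per-shard counts (in the real code, num_elements is filled by the surrounding loop):

```python
import numpy as np
import torch

num_elements = [16, 8, 24]  # hypothetical per-shard element counts

# Cumulative sums with a leading 0 give each shard's start offset.
offsets = np.concatenate(([0], np.cumsum(num_elements)))
print(offsets)  # [ 0 16 24 48]

# Storing the offsets as a CPU torch.Tensor instead of an np.ndarray keeps
# the attribute a type that torch's custom-op schema inference accepts,
# matching vLLM's behavior.
offsets = torch.tensor(offsets).cpu()
print(offsets)  # tensor([ 0, 16, 24, 48])
```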

python/sglang/srt/models/qwen2_5_vl.py

Lines changed: 2 additions & 23 deletions
```diff
@@ -141,7 +141,7 @@ def __init__(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
```
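Note: flipping use_qkv_parallel to True presumably lets VisionAttention's parallel qkv projection consume the checkpoint's fused [q; k; v] layout directly, which is what makes the manual qkv reshaping removed from load_weights below unnecessary.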
```diff
@@ -325,7 +325,7 @@ def get_window_index(self, grid_thw):
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.gate_proj.weight.dtype
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
```
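Note: reading the dtype from patch_embed.proj (a convolution) rather than from an MLP linear is presumably more robust under bitsandbytes quantization: bnb replaces the linear layers, whose storage dtype (e.g. torch.uint8 for 4-bit) no longer reflects the model's compute dtype, while the patch embedding stays unquantized.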
```diff
@@ -572,10 +572,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
-        # Just for bnb 4bit
-        is_bnb_weights = hasattr(
-            weights, "gi_code"
-        ) and weights.gi_code.co_name.startswith("_quantized_4bit_generator")
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
```
```diff
@@ -596,23 +592,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                if "visual" in name and "qkv.weight" in name and not is_bnb_weights:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name and not is_bnb_weights:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
```
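The deleted branches re-interleaved the fused visual qkv weight per attention head. A hedged sketch of that permutation with made-up sizes, shown only to document the old code path:

```python
import torch

# Hypothetical tiny sizes; the real values come from config.vision_config.
visual_num_heads, head_size, visual_embed_dim = 2, 3, 6

# HF checkpoints store the fused qkv weight as [q_all_heads; k_all_heads;
# v_all_heads] stacked along dim 0: (3 * num_heads * head_size, embed_dim).
qkv = torch.randn(3 * visual_num_heads * head_size, visual_embed_dim)

w = qkv.view(3, visual_num_heads, head_size, visual_embed_dim)
w = w.transpose(0, 1)                # -> (num_heads, 3, head_size, embed_dim)
w = w.reshape(-1, visual_embed_dim)  # rows regrouped per head: q, k, v, q, k, v, ...
print(w.shape)  # torch.Size([18, 6])

# With use_qkv_parallel=True, the qkv projection presumably consumes the
# original stacked [q; k; v] layout directly, so this shuffle was dropped.
```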
