Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
27531c8
[BugFix] qwen2.5vl enable_thinking=true and image_patch_id bug fix
CSWYF3634076 Sep 5, 2025
053df06
Merge branch 'PaddlePaddle:develop' into develop
CSWYF3634076 Sep 12, 2025
da7bfcd
[Docs]offine infer add apply_chat_template add_generation_prompt para…
CSWYF3634076 Sep 12, 2025
24da2ed
[Model]qwen2.5VL support --use-cudagraph
CSWYF3634076 Sep 15, 2025
485a41e
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test
CSWYF3634076 Sep 18, 2025
9753961
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test
CSWYF3634076 Sep 23, 2025
bd5ccac
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v2
CSWYF3634076 Sep 23, 2025
6a00349
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v3
CSWYF3634076 Sep 23, 2025
09283b7
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v4
CSWYF3634076 Sep 23, 2025
e4cb8c8
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v5
CSWYF3634076 Sep 23, 2025
dc50bba
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v6
CSWYF3634076 Sep 24, 2025
05e27f6
[Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v7
CSWYF3634076 Sep 24, 2025
3283aa4
Merge branch 'PaddlePaddle:develop' into develop
CSWYF3634076 Sep 24, 2025
103ea32
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Sep 24, 2025
047f875
qwen25vl v1 loader
CSWYF3634076 Sep 29, 2025
c025e2d
qwen25vl v1 loader v2
CSWYF3634076 Oct 13, 2025
db94b51
Merge branch 'PaddlePaddle:develop' into develop
CSWYF3634076 Oct 14, 2025
203d398
qwen25vl v1 loader v3
CSWYF3634076 Oct 14, 2025
415dd6c
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 14, 2025
f2b53a2
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 15, 2025
d7ed928
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 16, 2025
cbe70c1
qwen25vl v1 loader fix tp2 weight PySafeSlice
CSWYF3634076 Oct 16, 2025
eed8289
qwen25vl v1 loader no test
CSWYF3634076 Oct 20, 2025
e9fda8a
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 20, 2025
4e266c3
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 20, 2025
9a2ffbc
qwen25vl v1 loader add unit test
CSWYF3634076 Oct 22, 2025
7053ca8
Merge branch 'qwen-v1-loader' of https://github.com/CSWYF3634076/Fast…
CSWYF3634076 Oct 22, 2025
a25a59b
Merge branch 'PaddlePaddle:develop' into develop
CSWYF3634076 Oct 22, 2025
ab7c91d
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 22, 2025
5594a17
Merge branch 'PaddlePaddle:develop' into develop
CSWYF3634076 Oct 22, 2025
db2a3c0
qwen25vl v1 loader add unit test v2
CSWYF3634076 Oct 22, 2025
e92f2c2
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 22, 2025
16d83a4
qwen25vl v1 loader add torch unit test v3
CSWYF3634076 Oct 23, 2025
44a4ce7
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 23, 2025
65b7585
qwen25vl v1 loader add torch unit test v4
CSWYF3634076 Oct 24, 2025
076f1ef
Merge branch 'qwen-v1-loader' of https://github.com/CSWYF3634076/Fast…
CSWYF3634076 Oct 24, 2025
7316f49
qwen25vl v1 loader add torch unit test v5
CSWYF3634076 Oct 24, 2025
0f78bf0
qwen25vl v1 loader add torch unit test v6
CSWYF3634076 Oct 24, 2025
0b7f290
Merge branch 'develop' into qwen-v1-loader
CSWYF3634076 Oct 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def process_request_dict(self, request, max_model_len=None):
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request.setdefault("enable_thinking", True)
request.setdefault("enable_thinking", False)
outputs = self.processor.request2ids(request)

else:
Expand All @@ -249,11 +249,8 @@ def process_request_dict(self, request, max_model_len=None):
if request.get("completion_token_ids"):
self.append_completion_tokens(outputs, request["completion_token_ids"])

enable_thinking = False
if request.get("chat_template_kwargs"):
chat_template_kwargs = request.get("chat_template_kwargs")
enable_thinking = chat_template_kwargs.get("enable_thinking", False)
request["enable_thinking"] = enable_thinking
# qwen25_vl not support thinking
request["enable_thinking"] = False

outputs = self.pack_outputs(outputs)

Expand Down
93 changes: 88 additions & 5 deletions fastdeploy/model_executor/models/qwen2_5_vl/dfnrope/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

from functools import partial
from typing import Optional

import numpy as np
import paddle
Expand All @@ -30,7 +31,8 @@
)
from paddleformers.transformers.model_utils import PretrainedModel

from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.layers.utils import divide, get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs

from .activation import ACT2FN
from .configuration import DFNRopeVisionTransformerConfig
Expand Down Expand Up @@ -74,10 +76,18 @@ class VisionFlashAttention2(nn.Layer):
nn (_type_): _description_
"""

def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int = 1) -> None:
def __init__(
self,
dim: int,
num_heads: int = 16,
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
model_format: str = "",
) -> None:
super().__init__()
self.num_heads = num_heads
self.tensor_parallel_degree = tensor_parallel_degree
self.tensor_parallel_rank = tensor_parallel_rank

if tensor_parallel_degree > 1:
self.qkv = ColumnParallelLinear(
Expand All @@ -96,11 +106,52 @@ def __init__(self, dim: int, num_heads: int = 16, tensor_parallel_degree: int =
input_is_parallel=True,
has_bias=True,
)

# TODO(wangyafeng) Referring to the current situation of combining ernie vl
# with the framework, it should be possible to optimize it in the future
set_weight_attrs(self.qkv.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(
self.qkv.bias, {"weight_loader": self.weight_loader, "load_bias": True, "output_dim": True}
)
set_weight_attrs(self.proj.weight, {"output_dim": False})

else:
self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
self.proj = nn.Linear(dim, dim, bias_attr=True)

set_weight_attrs(self.qkv.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"})
self.head_dim = dim // num_heads # must added
self.num_heads = num_heads
self.hidden_size = dim
self.num_heads_per_rank = divide(self.num_heads, self.tensor_parallel_degree)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
load_bias = getattr(param, "load_bias", None)
if load_bias:
head_dim = self.hidden_size // self.num_heads
shard_weight = loaded_weight[...].reshape([3, self.num_heads, head_dim])
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([-1])
else:
shard_weight = loaded_weight[...].reshape(
[
self.hidden_size,
3,
self.num_heads,
self.head_dim,
]
)
shard_weight = np.split(shard_weight, self.tensor_parallel_degree, axis=-2)[self.tensor_parallel_rank]
shard_weight = shard_weight.reshape([self.hidden_size, -1])
shard_weight = get_tensor(shard_weight)
assert param.shape == shard_weight.shape, (
f" Attempted to load weight ({shard_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(shard_weight, False)

def forward(
self,
Expand Down Expand Up @@ -217,6 +268,7 @@ def __init__(
bias: bool = False,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
model_format: str = "",
) -> None:
super().__init__()
self.tensor_parallel_degree = tensor_parallel_degree
Expand Down Expand Up @@ -245,11 +297,23 @@ def __init__(
input_is_parallel=True,
has_bias=bias,
)
set_weight_attrs(self.gate_proj.weight, {"output_dim": True})
set_weight_attrs(self.up_proj.weight, {"output_dim": True})
set_weight_attrs(self.down_proj.weight, {"output_dim": False})
if bias:
set_weight_attrs(self.gate_proj.bias, {"output_dim": True})
set_weight_attrs(self.up_proj.bias, {"output_dim": True})
# set_weight_attrs(self.down_proj.bias, {"output_dim": False})

else:
self.gate_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias_attr=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias_attr=bias)

set_weight_attrs(self.gate_proj.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.up_proj.weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.down_proj.weight, {"weight_need_transpose": model_format == "torch"})

self.act = ACT2FN[hidden_act]

def forward(self, x) -> paddle.Tensor:
Expand Down Expand Up @@ -353,7 +417,9 @@ def __init__(
mlp_hidden_dim: int,
hidden_act: str = "gelu",
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
attn_implementation: str = "sdpa",
model_format: str = "",
) -> None:
"""_summary_

Expand All @@ -362,14 +428,15 @@ def __init__(
attn_implementation (str, optional): _description_. Defaults to "sdpa".
"""
super().__init__()

self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)

self.attn = VisionFlashAttention2(
dim=dim,
num_heads=num_heads,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
model_format=model_format,
)

self.mlp = VisionMlp(
Expand All @@ -378,6 +445,7 @@ def __init__(
bias=True,
hidden_act=hidden_act,
tensor_parallel_degree=tensor_parallel_degree,
model_format=model_format,
)

def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
Expand Down Expand Up @@ -408,7 +476,13 @@ class PatchMerger(nn.Layer):
nn (_type_): _description_
"""

def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
def __init__(
self,
dim: int,
context_dim: int,
spatial_merge_size: int = 2,
model_format: str = "",
) -> None:
"""_summary_

Args:
Expand All @@ -425,6 +499,9 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N
nn.Linear(self.hidden_size, dim, bias_attr=True),
)

set_weight_attrs(self.mlp[0].weight, {"weight_need_transpose": model_format == "torch"})
set_weight_attrs(self.mlp[2].weight, {"weight_need_transpose": model_format == "torch"})

def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""_summary_

Expand Down Expand Up @@ -470,6 +547,8 @@ def __init__(self, config, prefix_name: str = "") -> None:
hidden_size=config.vision_config.hidden_size,
)

model_format = getattr(config, "model_format", "")

head_dim = config.vision_config.hidden_size // config.vision_config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

Expand All @@ -481,13 +560,17 @@ def __init__(self, config, prefix_name: str = "") -> None:
mlp_hidden_dim=config.vision_config.intermediate_size,
hidden_act=config.vision_config.hidden_act,
tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree,
tensor_parallel_rank=config.pretrained_config.tensor_parallel_rank,
model_format=model_format,
)
for _ in range(config.vision_config.depth)
]
)

self.merger = PatchMerger(
dim=config.vision_config.out_hidden_size, context_dim=config.vision_config.hidden_size
dim=config.vision_config.out_hidden_size,
context_dim=config.vision_config.hidden_size,
model_format=model_format,
)

@property
Expand Down
65 changes: 63 additions & 2 deletions fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import re
from functools import partial
from typing import Dict, Optional, Union

Expand Down Expand Up @@ -182,6 +183,65 @@ def _init_vision_model(self, model_config) -> nn.Layer:
def name(self):
return "Qwen2_5_VLForConditionalGeneration"

@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.

Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""

from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
# 参数变量名与权重key不同的要做映射
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]

params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
# model_format = self.fd_config.model_config.model_format
# Because the prefix for Paddle is qwen2, and for Hugging Face it is model.
# if model_format == "torch":
# loaded_weight_name = loaded_weight_name.replace("model", "qwen2")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块是不是可以删掉

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)

if self.tie_word_embeddings:
# because we use lazy guard and is not initialized by default
if not self.lm_head.linear.weight._is_initialized():
self.lm_head.linear.weight.initialize()
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})

@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
"""
Expand Down Expand Up @@ -235,8 +295,9 @@ def get_input_embeddings(
video_mask = ids_remove_padding == self.model.video_token_id
video_token_num = video_mask.sum()

# 由于框架只有 image_features,所以目前不支持图片和视频混合
# TODO(wangyafeng) 后续考虑支持传入 video_features
# Due to the fact that the framework only has image_features,
# it currently does not support mixing images and videos
# TODO(wangyafeng) Consider supporting the input of video_features in the future
if image_token_num > 0:
input_embeddings[image_mask] = image_features.cast(self.model._dtype)
if video_token_num > 0:
Expand Down
Loading