Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/usage/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ environment_variables: dict[str, Callable[[], Any]] = {

# Enable debug mode (0 or 1)
"FD_DEBUG":
lambda: os.getenv("FD_DEBUG", "0"),
lambda: int(os.getenv("FD_DEBUG", "0")),

# FastDeploy log retention days
"FD_LOG_BACKUP_COUNT":
Expand Down
2 changes: 1 addition & 1 deletion docs/zh/usage/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ environment_variables: dict[str, Callable[[], Any]] = {

# 是否启用调试模式,可设置为 0 或 1
"FD_DEBUG":
lambda: os.getenv("FD_DEBUG", "0"),
lambda: int(os.getenv("FD_DEBUG", "0")),

# FastDeploy 日志保留天数
"FD_LOG_BACKUP_COUNT":
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.utils import current_package_version, envs

if envs.FD_DEBUG != "1":
if envs.FD_DEBUG != 1:
import logging

pf_logger.logger.setLevel(logging.INFO)
Expand Down
14 changes: 14 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import argparse
import json
import os
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -54,6 +55,16 @@ def nullable_str(x: str) -> Optional[str]:
return x if x else None


def get_model_architecture(model: str, model_config_name: Optional[str] = "config.json") -> Optional[str]:
    """Resolve the architecture name declared in a model directory's config.

    Args:
        model: Path to the model directory (may also be a bare model
            identifier; it is returned unchanged when no config file exists).
        model_config_name: Name of the JSON config file inside ``model``.

    Returns:
        The first entry of the config's ``architectures`` list when present,
        otherwise ``model`` itself as a fallback identifier.
    """
    config_path = os.path.join(model, model_config_name)
    if not os.path.exists(config_path):
        return model
    # Use a context manager so the file handle is closed deterministically
    # (the original `json.load(open(...))` left closing to the GC).
    with open(config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)
    # Tolerate configs with a missing or empty "architectures" list instead
    # of raising KeyError/IndexError; fall back to the model identifier.
    architectures = model_config.get("architectures") or []
    return architectures[0] if architectures else model


@dataclass
class EngineArgs:
# Model configuration parameters
Expand Down Expand Up @@ -432,6 +443,9 @@ def __post_init__(self):
if self.guided_decoding_backend != "off":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
envs.FD_ENABLE_MAX_PREFILL = 1

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""
Expand Down
1 change: 0 additions & 1 deletion fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def _get_num_new_tokens(self, request, token_budget):
if not self.config.model_config.enable_mm:
return num_new_tokens

request.with_image = False
inputs = request.multimodal_inputs
if inputs.get("patch_idx", None) is not None and inputs.get("patch_map", None) is not None:
pre_end_idx = request.num_computed_tokens
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# Log directory.
"FD_LOG_DIR": lambda: os.getenv("FD_LOG_DIR", "log"),
# Whether to use debug mode, can set 0 or 1
"FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
"FD_DEBUG": lambda: int(os.getenv("FD_DEBUG", "0")),
# Number of days to keep fastdeploy logs.
"FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
# Model download source, can set "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE".
Expand Down
10 changes: 5 additions & 5 deletions fastdeploy/input/paddleocr_vl_processor/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,15 @@ def __init__(
self.temporal_conv_size = self.image_processor.temporal_patch_size

# Special tokens and IDs
self.image_token = "<|image_pad|>"

self.image_token = "<|IMAGE_PLACEHOLDER|>"
self.video_token = "<|video_pad|>"

self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.image_patch_id = self.image_token_id

self.vision_start = "<|vision_start|>"
self.vision_start = "<|IMAGE_START|>"
self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start)

self.tokens_per_second = tokens_per_second
Expand Down Expand Up @@ -167,9 +168,8 @@ def text2ids(self, text, images=None, videos=None):
"vit_position_ids": [],
}
# Define placeholders and their lengths
# IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
IMAGE_PLACEHOLDER = "<|image_pad|>"
VIDEO_PLACEHOLDER = "<|video@placeholder|>"
IMAGE_PLACEHOLDER = self.image_token
VIDEO_PLACEHOLDER = self.video_token
IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)

Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _get_legacy_logger(self, name, file_name, without_formater=False, print_to_c
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)

is_debug = int(envs.FD_DEBUG)
is_debug = envs.FD_DEBUG
# logger = logging.getLogger(name)
# 为了兼容原有接口,使用命名空间进行隔离,避免logger覆盖、混乱等问题
legacy_name = f"legacy.{name}"
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/model_executor/models/paddleocr_vl/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddleformers.transformers.configuration_utils import PretrainedConfig


class PPOCRVisionConfig(PretrainedConfig):
class PaddleOCRVisionConfig(PretrainedConfig):
model_type = "paddleocr_vl"
base_config_key = "vision_config"

Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
class PaddleOCRConfig(PretrainedConfig):
model_type = "paddleocr_vl"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"vision_config": PPOCRVisionConfig}
sub_configs = {"vision_config": PaddleOCRVisionConfig}

base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
Expand Down
17 changes: 17 additions & 0 deletions fastdeploy/model_executor/models/paddleocr_vl/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import math
from typing import Optional

import paddle
import paddle.nn as nn
Expand Down Expand Up @@ -57,8 +58,10 @@ def __init__(self, text_config, vision_config, prefix=""):

self.pre_norm = nn.LayerNorm(self.vision_config.hidden_size, epsilon=1e-05)
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size)
self.linear_1.weight.weight_loader = self.weight_loader
self.act = GELUActivation()
self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size)
self.linear_2.weight.weight_loader = self.weight_loader

def forward(self, image_features, image_grid_thw):
m1, m2 = self.merge_kernel_size
Expand Down Expand Up @@ -94,6 +97,20 @@ def forward(self, image_features, image_grid_thw):
hidden_states = self.linear_2(hidden_states)
return hidden_states

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
    """Copy a checkpoint weight into one of the projector's linear parameters.

    The stored tensor is transposed before copying — presumably the
    checkpoint keeps linear weights in the layout transposed relative to
    the paddle parameter (TODO confirm against the checkpoint format).

    Args:
        param: Destination model parameter.
        loaded_weight: Raw weight from the checkpoint (converted via
            ``get_tensor``).
        loaded_shard_id: Unused here; kept for loader-interface
            compatibility.
    """
    loaded_weight = get_tensor(loaded_weight)
    loaded_weight = loaded_weight.transpose([1, 0])
    assert param.shape == loaded_weight.shape, (
        f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
    )
    # Ensure loaded weight dtype matches model param dtype
    if loaded_weight.dtype != param.dtype:
        if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
            # int8 storage reinterpreted as float8 — a bitwise view, not a
            # numeric cast.
            loaded_weight = loaded_weight.view(param.dtype)
        else:
            loaded_weight = loaded_weight.cast(param.dtype)
    param.copy_(loaded_weight, False)

def load_state_dict(self, state_dict):
params_dict = dict(self.named_parameters())
for param_name, param in params_dict.items():
Expand Down
122 changes: 76 additions & 46 deletions fastdeploy/model_executor/models/paddleocr_vl/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,20 @@
# limitations under the License.
"""

import os
from typing import List, Optional, Tuple, Union

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import flash_attn_unpadded
from paddleformers.transformers.activations import ACT2FN
from paddleformers.transformers.model_utils import PretrainedModel

from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn

try:
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except:
flash_attention_v3_varlen = None

from .config import PPOCRVisionConfig
from .config import PaddleOCRVisionConfig


def rotate_half(x):
Expand Down Expand Up @@ -61,22 +56,51 @@ def apply_rotary_pos_emb_vision(x, cos, sin):
return x_embed.astype(orig_dtype)


class QKVLinear(nn.Linear):
def __init__(self, config, in_features, out_features, weight_attr=None, bias_attr=None):
super().__init__(in_features, out_features, weight_attr, bias_attr)
class SiglipAttention(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.in_features = in_features
self.out_features = out_features
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim
self.weight.weight_loader = self.weight_loader
self.bias.weight_loader = self.weight_loader
self.scale = self.head_dim**-0.5

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# qkv_linear
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias_attr=True)
self.qkv_proj.weight.weight_loader = self.qkv_weight_loader
self.qkv_proj.bias.weight_loader = self.qkv_weight_loader

# out_linear
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj.weight.weight_loader = self.out_proj_weight_loader

enable_fa3 = False
flash_attn_version = int(os.environ.get("FLAGS_flash_attn_version", "2"))
if flash_attn_version == 3:
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
is_current_sm_supported = cc >= 90
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
enable_fa3 = is_current_sm_supported and is_paddle_supported

if enable_fa3:
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen

self.flash_attn_func = flash_attention_v3_varlen
self.flash_attn_kwargs = {}
else:
from paddle.nn.functional.flash_attention import flash_attn_unpadded

self.flash_attn_func = flash_attn_unpadded
self.flash_attn_kwargs = {"scale": self.scale, "training": False}

def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# Tensor parallelism splits the weight along the output_dim
loaded_weight = get_tensor(loaded_weight)
if loaded_weight.dim() == 2:
loaded_weight = loaded_weight.transpose([1, 0])

if not param._is_initialized():
param.initialize()
if loaded_shard_id == "q":
Expand All @@ -90,7 +114,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
param_shard_offset = self.num_heads * self.head_dim * 2
param_shard_size = self.num_heads * self.head_dim

param = slice_fn(param, self.out_features, start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, -1, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
Expand All @@ -102,30 +126,19 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
loaded_weight = loaded_weight.cast(param.dtype)
param.copy_(loaded_weight, False)


class SiglipAttention(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim
self.scale = self.head_dim**-0.5

self.qkv_proj = QKVLinear(config, self.embed_dim, self.embed_dim * 3, bias_attr=True)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
is_current_sm_supported = cc >= 90
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
if is_current_sm_supported and is_paddle_supported:
self.flash_attn_func = flash_attention_v3_varlen
self.flash_attn_kwargs = {}
else:
self.flash_attn_func = flash_attn_unpadded
self.flash_attn_kwargs = {"scale": self.scale, "training": False}
def out_proj_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
    """Load the attention output-projection weight from the checkpoint.

    Transposes the stored tensor to the parameter layout, then matches the
    parameter dtype (bitwise view for int8 -> float8_e4m3fn, numeric cast
    otherwise) before copying in place.
    """
    tensor = get_tensor(loaded_weight).transpose([1, 0])
    assert param.shape == tensor.shape, (
        f" Attempted to load weight ({tensor.shape}) " f"into parameter ({param.shape})"
    )
    # Ensure loaded weight dtype matches model param dtype
    if tensor.dtype != param.dtype:
        is_fp8_reinterpret = tensor.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn
        tensor = tensor.view(param.dtype) if is_fp8_reinterpret else tensor.cast(param.dtype)
    param.copy_(tensor, False)

def forward(
self,
Expand Down Expand Up @@ -170,7 +183,7 @@ def forward(
)[0]
# --------

attn_output = attn_output.reshape(seq_length, -1)
attn_output = attn_output.reshape((seq_length, -1))
attn_output = self.out_proj(attn_output)

return attn_output
Expand Down Expand Up @@ -315,11 +328,28 @@ def __init__(self, config):
super().__init__()
self.config = config
if config.hidden_act == "gelu_pytorch_tanh":
config.hidden_act = "silu"
config.hidden_act = "gelu_new"

self.activation_fn = ACT2FN[config.hidden_act]

self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.fc2.weight.weight_loader = self.weight_loader

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
    """Copy a checkpoint weight into an MLP linear parameter (fc1/fc2).

    The stored tensor is transposed before copying — presumably the
    checkpoint keeps linear weights in the layout transposed relative to
    the paddle parameter (TODO confirm against the checkpoint format).

    Args:
        param: Destination model parameter.
        loaded_weight: Raw weight from the checkpoint (converted via
            ``get_tensor``).
        loaded_shard_id: Unused here; kept for loader-interface
            compatibility.
    """
    loaded_weight = get_tensor(loaded_weight)
    loaded_weight = loaded_weight.transpose([1, 0])
    assert param.shape == loaded_weight.shape, (
        f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
    )
    # Ensure loaded weight dtype matches model param dtype
    if loaded_weight.dtype != param.dtype:
        if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
            # int8 storage reinterpreted as float8 — a bitwise view, not a
            # numeric cast.
            loaded_weight = loaded_weight.view(param.dtype)
        else:
            loaded_weight = loaded_weight.cast(param.dtype)
    param.copy_(loaded_weight, False)

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states)
Expand Down Expand Up @@ -576,7 +606,7 @@ def forward(
class SiglipMultiheadAttentionPoolingHead(nn.Layer):
"""Multihead Attention Pooling."""

def __init__(self, config: PPOCRVisionConfig):
def __init__(self, config: PaddleOCRVisionConfig):
super().__init__()

self.probe = self.create_parameter(
Expand All @@ -601,7 +631,7 @@ def forward(self, hidden_state, key_padding_mask=None):


class SiglipVisionTransformer(nn.Layer):
def __init__(self, config: PPOCRVisionConfig):
def __init__(self, config: PaddleOCRVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
Expand Down Expand Up @@ -666,10 +696,10 @@ def forward(


class SiglipVisionModel(PretrainedModel):
config_class = PPOCRVisionConfig
config_class = PaddleOCRVisionConfig
main_input_name = "pixel_values"

def __init__(self, config: PPOCRVisionConfig, prefix=""):
def __init__(self, config: PaddleOCRVisionConfig, prefix=""):
super().__init__(config)
self.prefix_name = prefix
self.vision_model = SiglipVisionTransformer(config)
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/worker/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
update_fd_config_for_mm(fd_config)
if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
fd_config.load_config.load_choices = "default"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个为啥要两处都配置呢? 有点冗余

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为前一个在服务层,这个在引擎层。测试的时候发现服务层的设置传不到引擎层,参考775-789行中ENABLE_V1_KVCACHE_SCHEDULER的设置,在引擎层又设置了一遍。

architecture = fd_config.model_config.architectures[0]
if "PaddleOCR" in architecture:
envs.FD_ENABLE_MAX_PREFILL = 1
return fd_config


Expand Down
Loading