Skip to content
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
8e0c8d4
support pool
lizexu123 Sep 2, 2025
57795ea
update pooling
lizexu123 Sep 8, 2025
302adb0
merge develop
lizexu123 Sep 8, 2025
87237b0
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Sep 9, 2025
cef99ec
add pooler_config and check
lizexu123 Sep 9, 2025
926c796
update
lizexu123 Sep 10, 2025
a76a43e
support AutoWeightsLoader load weight
lizexu123 Sep 12, 2025
344a8df
fix
lizexu123 Sep 12, 2025
5ec6a93
update
lizexu123 Sep 15, 2025
2235785
merge develop
lizexu123 Sep 15, 2025
98b32fc
delete print
lizexu123 Sep 15, 2025
4f90dfc
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 Sep 15, 2025
0ef4c9a
update pre-commit
lizexu123 Sep 15, 2025
1daacb7
fix
lizexu123 Sep 15, 2025
9d1a011
fix xpu
lizexu123 Sep 15, 2025
2eff16e
fix ModelRegistry->model_registry
lizexu123 Sep 15, 2025
9fcd05e
fix Copilot review
lizexu123 Sep 16, 2025
945cba5
fix pooler.py
lizexu123 Sep 16, 2025
6f545aa
delete StepPooler
lizexu123 Sep 16, 2025
222d1b2
fix abstract
lizexu123 Sep 16, 2025
97b4649
fix default_loader_v1
lizexu123 Sep 16, 2025
4cf6164
fix Pre Commit
lizexu123 Sep 16, 2025
db0a4bf
support torch qwen3 dense
lizexu123 Sep 16, 2025
2681765
add test and fix torch-qwen
lizexu123 Sep 16, 2025
10969cd
fix
lizexu123 Sep 16, 2025
669d712
fix
lizexu123 Sep 16, 2025
dd45025
adapter ci:
lizexu123 Sep 16, 2025
31b7311
fix review
lizexu123 Sep 17, 2025
e023c6b
fix pooling_params.py
lizexu123 Sep 17, 2025
cb80ce8
fix
lizexu123 Sep 17, 2025
893cdbb
fix tasks.py 2025
lizexu123 Sep 17, 2025
d4dcc3c
fix print and logger
lizexu123 Sep 17, 2025
41aa2c5
Modefy ModelRegistry and delete AutoWeightsLoader
lizexu123 Sep 18, 2025
8e92eb4
fix logger
lizexu123 Sep 18, 2025
91f777e
fix test_embedding
lizexu123 Sep 18, 2025
a90a091
delete T
lizexu123 Sep 18, 2025
6f6c549
fix ci bug
lizexu123 Sep 18, 2025
1fdf477
ernie4_5 model_registry
lizexu123 Sep 18, 2025
9e4d1fa
fix test
lizexu123 Sep 19, 2025
ebf4e0c
fix test
lizexu123 Sep 19, 2025
27ec018
support Qwen3-Embedding-0.6B tp=1 load
lizexu123 Sep 19, 2025
adc5b8f
fix extra code
lizexu123 Sep 19, 2025
5e264f3
fix
lizexu123 Sep 19, 2025
798f788
delete fix vocab_size
lizexu123 Sep 19, 2025
b72ac60
delete prepare_params_dict
lizexu123 Sep 19, 2025
a1de646
fix:
lizexu123 Sep 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/features/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained cla

```python
# File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
from fastdeploy.model_registry import ModelRegistry
from fastdeploy.model_executor.models.model_base import ModelRegistry
from my_custom_model import MyModelForCasualLM, MyPretrainedModel
from fastdeploy.config import ErnieArchitectures

Expand Down
2 changes: 1 addition & 1 deletion docs/zh/features/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ FastDeploy 利用 Python 的 `entry_points` 机制来发现并加载插件。开

```python
# 文件:fd_add_dummy_model/__init__.py
from fastdeploy.model_registry import ModelRegistry
from fastdeploy.model_executor.models.model_base import ModelRegistry
from my_custom_model import MyModelForCasualLM, MyPretrainedModel

def register():
Expand Down
296 changes: 294 additions & 2 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,83 @@

import json
import os
from dataclasses import field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

import paddle
import paddle.distributed as dist
from paddleformers.transformers.configuration_utils import PretrainedConfig
from typing_extensions import assert_never

import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.transformer_utils.config import get_pooling_config
from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger

logger = get_logger("config", "config.log")

TaskOption = Literal["generate"]
# NOTE(review): "embedding" and "embed" both appear here and seem to be aliases
# for the same pooling task — confirm the intended difference and document it
# (this was questioned in PR review).
TaskOption = Literal["auto", "generate", "embedding", "embed"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里embedding 和embed的区别是什么


RunnerType = Literal["generate", "pooling"]

RunnerOption = Literal["auto", "generate", "pooling"]

ConvertOption = Literal["auto", "none", "embed"]

ConvertType = Literal["none", "embed"]

_ResolvedTask = Literal["generate", "encode", "embed"]

_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"generate": [],
"pooling": ["embed"],
}

# Some model suffixes are based on auto classes from Transformers:
# https://huggingface.co/docs/transformers/en/model_doc/auto
# NOTE: Items higher on this list priority over lower ones
_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForCausalLM", ("generate", "none")),
("ForConditionalGeneration", ("generate", "none")),
("ChatModel", ("generate", "none")),
("LMHeadModel", ("generate", "none")),
("ForTextEncoding", ("pooling", "embed")),
("EmbeddingModel", ("pooling", "embed")),
("ForSequenceClassification", ("pooling", "classify")),
("ForAudioClassification", ("pooling", "classify")),
("ForImageClassification", ("pooling", "classify")),
("ForVideoClassification", ("pooling", "classify")),
("ClassificationModel", ("pooling", "classify")),
("ForRewardModeling", ("pooling", "reward")),
("RewardModel", ("pooling", "reward")),
# Let other `*Model`s take priority
("Model", ("pooling", "embed")),
]


def iter_architecture_defaults():
yield from _SUFFIX_TO_DEFAULTS


def try_match_architecture_defaults(
architecture: str,
*,
runner_type: Optional[RunnerType] = None,
convert_type: Optional[ConvertType] = None,
):
for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
if (
(runner_type is None or runner_type == default_runner_type)
and (convert_type is None or convert_type == default_convert_type)
and architecture.endswith(suffix)
):
return suffix, (default_runner_type, default_convert_type)
return None


class MoEPhase:
Expand Down Expand Up @@ -133,6 +192,12 @@ def __init__(
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.runner = "auto"
self.convert = "auto"
self.pooler_config: Optional["PoolerConfig"] = field(init=False)
self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
self.revision = None
Comment on lines +195 to +199
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不单独加一个PoolerConfig?把runner/convert等都加进去

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pooler.py中创建DisPatchPooler,里面的ResolvedPoolingConfig它的from_config中需要pooler_config


self.partial_rotary_factor: float = 1.0
for key, value in args.items():
if hasattr(self, key) and value != "None":
Expand Down Expand Up @@ -160,6 +225,7 @@ def __init__(
self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)

architectures = self.architectures[0]

if MultimodalRegistry.contains_model(architectures):
self.enable_mm = True
else:
Expand All @@ -170,6 +236,39 @@ def __init__(
self.override_name_from_config()
self.read_from_env()
self.read_model_config()
self.runner_type = self._get_runner_type(self.architectures, self.runner)
self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)

registry = self.registry
is_generative_model = registry.is_text_generation_model(self.architectures, self)
is_pooling_model = registry.is_pooling_model(self.architectures, self)

if self.runner_type == "generate" and not is_generative_model:
generate_converts = _RUNNER_CONVERTS["generate"]
if self.convert_type not in generate_converts:
raise ValueError("This model does not support '--runner generate.")
if self.runner_type == "pooling" and not is_pooling_model:
pooling_converts = _RUNNER_CONVERTS["pooling"]
if self.convert_type not in pooling_converts:
convert_option = "<" + "|".join(pooling_converts) + ">"
raise ValueError(
"This model does not support `--runner pooling`. "
f"You can pass `--convert {convert_option} to adapt "
"it into a pooling model."
)

self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
model_info, arch = registry.inspect_model_cls(self.architectures, self)
self._model_info = model_info
self._architecture = arch

self.pooler_config = self._init_pooler_config()

@property
def registry(self):
    """Return a fresh ``ModelRegistry`` instance.

    The import happens lazily inside the property — presumably to avoid a
    circular import between the config module and the model registry
    (NOTE(review): confirm).
    """
    from fastdeploy.model_executor.models.model_base import ModelRegistry

    model_registry = ModelRegistry()
    return model_registry
Comment on lines +273 to +276
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_config怎么还返回了一个ModelRegistry呢

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我写的方法都不是类方法,旧的还是类方法


def override_name_from_config(self):
"""
Expand All @@ -193,7 +292,6 @@ def override_name_from_config(self):
def read_from_env(self):
"""
Read configuration information from environment variables and update the object's attributes.

If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
Expand Down Expand Up @@ -234,6 +332,165 @@ def read_model_config(self):
f"Config file path: {config_path}"
)

def _get_default_runner_type(
    self,
    architectures: list[str],
) -> RunnerType:
    """Infer the default runner kind for *architectures*.

    Priority order: a pooling config shipped with the checkpoint, then the
    model registry's knowledge of the architectures, then the
    architecture-suffix defaults, finally falling back to ``"generate"``.
    """
    model_registry = self.registry
    # A pooling config in the checkpoint marks the model as a pooler outright.
    if get_pooling_config(self.model, self.revision):
        return "pooling"
    for architecture in architectures:
        if architecture in model_registry.get_supported_archs():
            if model_registry.is_pooling_model(architectures, self):
                return "pooling"
            if model_registry.is_text_generation_model(architectures, self):
                return "generate"
        matched = try_match_architecture_defaults(architecture)
        if matched is not None:
            _, (default_runner, _) = matched
            return default_runner
    return "generate"

def _get_default_convert_type(
    self,
    architectures: list[str],
    runner_type: RunnerType,
) -> ConvertType:
    """Infer the default ``--convert`` value for *architectures* under *runner_type*.

    Architectures the registry natively supports for the chosen runner need
    no conversion (``"none"``); otherwise the architecture-suffix defaults
    decide, and unknown pooling models fall back to ``"embed"``.
    """
    model_registry = self.registry

    for architecture in architectures:
        if architecture in model_registry.get_supported_archs():
            if runner_type == "generate":
                if model_registry.is_text_generation_model(architectures, self):
                    return "none"
            elif runner_type == "pooling":
                if model_registry.is_pooling_model(architectures, self):
                    return "none"
        matched = try_match_architecture_defaults(architecture, runner_type=runner_type)
        if matched is not None:
            _, (_, default_convert) = matched
            return default_convert

    # This is to handle Sentence Transformers models that use *ForCausalLM
    # and also multi-modal pooling models which are not defined as
    # Sentence Transformers models
    return "embed" if runner_type == "pooling" else "none"

def _get_runner_type(
    self,
    architectures: list[str],
    runner: RunnerOption,
) -> RunnerType:
    """Resolve the ``--runner`` option, expanding ``"auto"`` via architecture defaults.

    An explicit runner value is returned unchanged; only a non-default
    ("pooling") auto-resolution is logged so the user can silence it.
    """
    if runner == "auto":
        resolved = self._get_default_runner_type(architectures)
        if resolved != "generate":
            logger.info(
                "Resolved `--runner auto` to `--runner %s`. Pass the value explicitly to silence this message.",
                resolved,
            )
        return resolved
    return runner

def _get_convert_type(
    self,
    architectures: list[str],
    runner_type: RunnerType,
    convert: ConvertOption,
) -> ConvertType:
    """Resolve the ``--convert`` option, expanding ``"auto"`` via architecture defaults.

    An explicit convert value is returned unchanged; only a non-trivial
    (not ``"none"``) auto-resolution is logged so the user can silence it.
    """
    if convert == "auto":
        resolved = self._get_default_convert_type(architectures, runner_type)
        if resolved != "none":
            logger.info(
                "Resolved `--convert auto` to `--convert %s`. Pass the value explicitly to silence this message.",
                resolved,
            )
        return resolved
    return convert

def _get_supported_generation_tasks(
    self,
    architectures: list[str],
    convert_type: ConvertType,
) -> list[_ResolvedTask]:
    """Return the tasks available when running in "generate" mode.

    Native text-generation models, or models adapted via a
    generate-compatible conversion, can serve ``"generate"``.
    """
    model_registry = self.registry

    tasks: list[_ResolvedTask] = []
    if model_registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
        tasks.append("generate")

    # TODO: Temporarily does not support transcription.
    return tasks

def _get_default_pooling_task(
    self,
    architectures: list[str],
) -> Literal["embed"]:
    """Return the pooling task implied by the architecture suffix.

    Falls back to ``"embed"`` when no suffix default matches.
    Temporarily does not support classification and reward.
    """
    for architecture in architectures:
        matched = try_match_architecture_defaults(architecture, runner_type="pooling")
        if matched is not None:
            _, (_, task) = matched
            # Pooling suffix defaults never map to "none".
            assert task != "none"
            return task

    return "embed"

def _get_supported_pooling_tasks(
    self,
    architectures: list[str],
    convert_type: ConvertType,
) -> list[_ResolvedTask]:
    """Return the tasks available when running in "pooling" mode.

    Native pooling models (or pooling-convertible ones) support raw
    ``"encode"``; the concrete pooling task comes from the convert type,
    defaulting per the architecture suffix when no conversion was requested.
    """
    model_registry = self.registry

    tasks: list[_ResolvedTask] = []
    if model_registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
        tasks.append("encode")

    if convert_type == "none":
        tasks.append(self._get_default_pooling_task(architectures))
    else:
        tasks.append(convert_type)

    return tasks

def _get_supported_tasks(
    self,
    architectures: list[str],
    runner_type: RunnerType,
    convert_type: ConvertType,
) -> list[_ResolvedTask]:
    """Dispatch supported-task discovery to the runner-specific helper."""
    if runner_type == "pooling":
        return self._get_supported_pooling_tasks(architectures, convert_type)
    if runner_type == "generate":
        return self._get_supported_generation_tasks(architectures, convert_type)

    # Statically unreachable for well-typed callers; raises otherwise.
    assert_never(runner_type)

def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(**self.override_pooler_config)

pooler_config = self.override_pooler_config or PoolerConfig()

base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
for k, v in base_config.items():
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)

default_pooling_type = self._model_info.default_pooling_type
if pooler_config.pooling_type is None:
pooler_config.pooling_type = default_pooling_type

return pooler_config

return None

def _get_download_model(self, model_name, model_type="default"):
# TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
pass
Expand Down Expand Up @@ -850,6 +1107,41 @@ def __init__(
setattr(self, key, value)


class PoolerConfig:
    """Controls the behavior of output pooling in pooling models.

    All fields default to ``None`` so that callers can distinguish "unset"
    from an explicit user override when merging in checkpoint-provided
    pooling settings.
    """

    def __init__(
        self,
        pooling_type: Optional[str] = None,
        normalize: Optional[bool] = None,
        dimensions: Optional[int] = None,
        enable_chunked_processing: Optional[bool] = None,
        max_embed_len: Optional[int] = None,
    ) -> None:
        # BUG FIX: the class previously declared only class-level attribute
        # defaults with no __init__, so `PoolerConfig(**override_dict)` (as
        # done when resolving `override_pooler_config`) raised TypeError.
        # An explicit keyword __init__ keeps `PoolerConfig()` working while
        # accepting per-field overrides.

        # The pooling method of the pooling model.
        self.pooling_type = pooling_type
        # For embedding models: whether to normalize the embeddings outputs.
        # Defaults to True when left unset.
        self.normalize = normalize
        # Reduce the dimensions of embeddings if the model supports
        # matryoshka representation. Defaults to None.
        self.dimensions = dimensions
        # Whether to enable chunked processing for long inputs that exceed
        # the model's maximum position embeddings. When enabled, long inputs
        # are split into chunks, processed separately, and aggregated using
        # weighted averaging, allowing arbitrarily long text without CUDA
        # errors. Defaults to False when left unset.
        self.enable_chunked_processing = enable_chunked_processing
        # Maximum input length allowed for embedding generation. When set,
        # inputs longer than max_embed_len are accepted for embedding models;
        # inputs exceeding it fall back to the original max_model_len
        # validation logic. Defaults to None (i.e. set to max_model_len).
        self.max_embed_len = max_embed_len


class LoRAConfig:
"""LoRA Config"""

Expand Down
Loading