diff --git a/docs/features/plugins.md b/docs/features/plugins.md index 0fe97ef7b64..ed63ed5940f 100644 --- a/docs/features/plugins.md +++ b/docs/features/plugins.md @@ -18,7 +18,7 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained cla ```python # File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py -from fastdeploy.model_registry import ModelRegistry +from fastdeploy.model_executor.models.model_base import ModelRegistry from my_custom_model import MyModelForCasualLM, MyPretrainedModel from fastdeploy.config import ErnieArchitectures diff --git a/docs/zh/features/plugins.md b/docs/zh/features/plugins.md index 040233ef85e..e1601081667 100644 --- a/docs/zh/features/plugins.md +++ b/docs/zh/features/plugins.md @@ -18,7 +18,7 @@ FastDeploy 利用 Python 的 `entry_points` 机制来发现并加载插件。开 ```python # 文件:fd_add_dummy_model/__init__.py -from fastdeploy.model_registry import ModelRegistry +from fastdeploy.model_executor.models.model_base import ModelRegistry from my_custom_model import MyModelForCasualLM, MyPretrainedModel def register(): diff --git a/fastdeploy/config.py b/fastdeploy/config.py index f17791f147e..897e4e1ba1d 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -18,12 +18,14 @@ import json import os +from dataclasses import field from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union import paddle import paddle.distributed as dist from paddleformers.transformers.configuration_utils import PretrainedConfig +from typing_extensions import assert_never import fastdeploy from fastdeploy import envs @@ -31,11 +33,68 @@ from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.scheduler import SchedulerConfig +from fastdeploy.transformer_utils.config import get_pooling_config from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger logger = get_logger("config", "config.log") -TaskOption = Literal["generate"] +TaskOption = Literal["auto", "generate", "embedding", "embed"] + +RunnerType = Literal["generate", "pooling"] + +RunnerOption = Literal["auto", "generate", "pooling"] + +ConvertOption = Literal["auto", "none", "embed"] + +ConvertType = Literal["none", "embed"] + +_ResolvedTask = Literal["generate", "encode", "embed"] + +_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { + "generate": [], + "pooling": ["embed"], +} + +# Some model suffixes are based on auto classes from Transformers: +# https://huggingface.co/docs/transformers/en/model_doc/auto +# NOTE: Items higher on this list priority over lower ones +_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ + ("ForCausalLM", ("generate", "none")), + ("ForConditionalGeneration", ("generate", "none")), + ("ChatModel", ("generate", "none")), + ("LMHeadModel", ("generate", "none")), + ("ForTextEncoding", ("pooling", "embed")), + ("EmbeddingModel", ("pooling", "embed")), + ("ForSequenceClassification", ("pooling", "classify")), + ("ForAudioClassification", ("pooling", "classify")), + ("ForImageClassification", ("pooling", "classify")), + ("ForVideoClassification", ("pooling", "classify")), + ("ClassificationModel", ("pooling", "classify")), + ("ForRewardModeling", ("pooling", "reward")), + ("RewardModel", ("pooling", "reward")), + # Let other `*Model`s take priority + ("Model", ("pooling", "embed")), +] + + +def iter_architecture_defaults(): + yield from _SUFFIX_TO_DEFAULTS + + +def try_match_architecture_defaults( + architecture: str, + *, + 
runner_type: Optional[RunnerType] = None,
+    convert_type: Optional[ConvertType] = None,
+):
+    for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
+        if (
+            (runner_type is None or runner_type == default_runner_type)
+            and (convert_type is None or convert_type == default_convert_type)
+            and architecture.endswith(suffix)
+        ):
+            return suffix, (default_runner_type, default_convert_type)
+    return None
 
 
 class MoEPhase:
@@ -133,6 +192,12 @@ def __init__(
         self.eos_tokens_lens: int = 2
         self.lm_head_fp32: bool = False
         self.model_format = "auto"
+        self.runner = "auto"
+        self.convert = "auto"
+        self.pooler_config: Optional["PoolerConfig"] = field(init=False)
+        self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+        self.revision = None
+        self.partial_rotary_factor: float = 1.0
         self.num_nextn_predict_layers = 0
 
         for key, value in args.items():
@@ -161,6 +226,7 @@ def __init__(
             self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
 
         architectures = self.architectures[0]
+
         if MultimodalRegistry.contains_model(architectures):
             self.enable_mm = True
         else:
@@ -171,6 +237,43 @@
         self.override_name_from_config()
         self.read_from_env()
         self.read_model_config()
+        self.runner_type = self._get_runner_type(self.architectures, self.runner)
+        self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
+
+        registry = self.registry
+        is_generative_model = registry.is_text_generation_model(self.architectures, self)
+        is_pooling_model = registry.is_pooling_model(self.architectures, self)
+        is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
+
+        if self.runner_type == "generate" and not is_generative_model:
+            if is_multimodal_model:
+                pass
+            else:
+                generate_converts = _RUNNER_CONVERTS["generate"]
+                if self.convert_type not in generate_converts:
+                    raise ValueError("This model does not support `--runner generate`.")
+        if self.runner_type == "pooling" and not is_pooling_model:
+            pooling_converts = _RUNNER_CONVERTS["pooling"]
+            if self.convert_type not in pooling_converts:
+                convert_option = "<" + "|".join(pooling_converts) + ">"
+                raise ValueError(
+                    "This model does not support `--runner pooling`. "
+                    f"You can pass `--convert {convert_option}` to adapt "
+                    "it into a pooling model."
+                )
+
+        self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
+        model_info, arch = registry.inspect_model_cls(self.architectures, self)
+        self._model_info = model_info
+        self._architecture = arch
+
+        self.pooler_config = self._init_pooler_config()
+
+    @property
+    def registry(self):
+        from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+        return ModelRegistry()
 
     def override_name_from_config(self):
         """
@@ -194,7 +297,6 @@ def override_name_from_config(self):
     def read_from_env(self):
         """
         Read configuration information from environment variables and update the object's attributes.
-
         If an attribute is not present or is an empty string in the environment variables, use the default value.
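+        For example, `FD_MAX_STOP_SEQS_NUM` populates `max_stop_seqs_num` below.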
""" self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM) @@ -235,6 +337,165 @@ def read_model_config(self): f"Config file path: {config_path}" ) + def _get_default_runner_type( + self, + architectures: list[str], + ) -> RunnerType: + registry = self.registry + if get_pooling_config(self.model, self.revision): + return "pooling" + for arch in architectures: + if arch in registry.get_supported_archs(): + if registry.is_pooling_model(architectures, self): + return "pooling" + if registry.is_text_generation_model(architectures, self): + return "generate" + match = try_match_architecture_defaults(arch) + if match: + _, (runner_type, _) = match + return runner_type + return "generate" + + def _get_default_convert_type( + self, + architectures: list[str], + runner_type: RunnerType, + ) -> ConvertType: + registry = self.registry + + for arch in architectures: + if arch in registry.get_supported_archs(): + if runner_type == "generate" and registry.is_text_generation_model(architectures, self): + return "none" + if runner_type == "pooling" and registry.is_pooling_model(architectures, self): + return "none" + match = try_match_architecture_defaults(arch, runner_type=runner_type) + if match: + _, (_, convert_type) = match + return convert_type + + # This is to handle Sentence Transformers models that use *ForCausalLM + # and also multi-modal pooling models which are not defined as + # Sentence Transformers models + if runner_type == "pooling": + return "embed" + + return "none" + + def _get_runner_type( + self, + architectures: list[str], + runner: RunnerOption, + ) -> RunnerType: + if runner != "auto": + return runner + + runner_type = self._get_default_runner_type(architectures) + if runner_type != "generate": + logger.info( + "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message.", + runner_type, + ) + + return runner_type + + def _get_convert_type( + self, + architectures: list[str], + runner_type: RunnerType, + convert: ConvertOption, + ) -> ConvertType: + if convert != "auto": + return convert + + convert_type = self._get_default_convert_type(architectures, runner_type) + + if convert_type != "none": + logger.info( + "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message.", + convert_type, + ) + + return convert_type + + def _get_supported_generation_tasks( + self, + architectures: list[str], + convert_type: ConvertType, + ) -> list[_ResolvedTask]: + registry = self.registry + + supported_tasks = list[_ResolvedTask]() + if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]: + supported_tasks.append("generate") + + # TODO:Temporarily does not support transcription. + return supported_tasks + + def _get_default_pooling_task( + self, + architectures: list[str], + ) -> Literal["embed"]: + # Temporarily does not support classification and reward. 
+ for arch in architectures: + match = try_match_architecture_defaults(arch, runner_type="pooling") + if match: + _, (_, convert_type) = match + assert convert_type != "none" + return convert_type + + return "embed" + + def _get_supported_pooling_tasks( + self, + architectures: list[str], + convert_type: ConvertType, + ) -> list[_ResolvedTask]: + registry = self.registry + + supported_tasks = list[_ResolvedTask]() + if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]: + supported_tasks.append("encode") + + extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type + supported_tasks.append(extra_task) + + return supported_tasks + + def _get_supported_tasks( + self, + architectures: list[str], + runner_type: RunnerType, + convert_type: ConvertType, + ) -> list[_ResolvedTask]: + if runner_type == "generate": + return self._get_supported_generation_tasks(architectures, convert_type) + if runner_type == "pooling": + return self._get_supported_pooling_tasks(architectures, convert_type) + + assert_never(runner_type) + + def _init_pooler_config(self) -> Optional["PoolerConfig"]: + if self.runner_type == "pooling": + if isinstance(self.override_pooler_config, dict): + self.override_pooler_config = PoolerConfig(**self.override_pooler_config) + + pooler_config = self.override_pooler_config or PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + for k, v in base_config.items(): + if getattr(pooler_config, k) is None: + setattr(pooler_config, k, v) + + default_pooling_type = self._model_info.default_pooling_type + if pooler_config.pooling_type is None: + pooler_config.pooling_type = default_pooling_type + + return pooler_config + + return None + def _get_download_model(self, model_name, model_type="default"): # TODO: Provide dynamic graph for self-downloading and save to the specified download directory. pass @@ -856,6 +1117,41 @@ def __init__( setattr(self, key, value) +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. + """ + # for embeddings models + normalize: Optional[bool] = None + """ + Whether to normalize the embeddings outputs. Defaults to True. + """ + dimensions: Optional[int] = None + """ + Reduce the dimensions of embeddings if model + support matryoshka representation. Defaults to None. + """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + When an input exceeds max_embed_len, it will be handled according to + the original max_model_len validation logic. + Defaults to None (i.e. set to max_model_len). 
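+    Example (illustrative value): `--override-pooler-config '{"max_embed_len": 10240}'`.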
+    """
+
+
 class LoRAConfig:
     """LoRA Config"""
 
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 2dcbde6470c..65969a3cb09 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -18,13 +18,14 @@
 import json
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import paddle
 
 from fastdeploy import envs
 from fastdeploy.config import (
     CacheConfig,
+    ConvertOption,
     EarlyStopConfig,
     FDConfig,
     GraphOptimizationConfig,
@@ -32,6 +33,8 @@
     MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PoolerConfig,
+    RunnerOption,
     SpeculativeConfig,
     TaskOption,
 )
@@ -95,6 +98,20 @@ class EngineArgs:
     """
     The task to be executed by the model.
     """
+    runner: RunnerOption = "auto"
+    """
+    The type of model runner to use. Each FD instance only supports one model
+    runner, even if the same model can be used for multiple types.
+    """
+    convert: ConvertOption = "auto"
+    """
+    Convert the model using adapters. The most common use case is to
+    adapt a text generation model to be used for pooling tasks.
+    """
+    override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
+    """
+    Override configuration for the pooler.
+    """
     max_num_seqs: int = 8
     """
     Maximum number of sequences per iteration.
@@ -475,6 +492,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=EngineArgs.task,
         help="Task to be executed by the model.",
     )
+    model_group.add_argument(
+        "--runner",
+        type=str,
+        default=EngineArgs.runner,
+        help="The type of model runner to use",
+    )
+    model_group.add_argument(
+        "--convert", type=str, default=EngineArgs.convert, help="Convert the model using adapters"
+    )
+    model_group.add_argument(
+        "--override-pooler-config",
+        type=json.loads,
+        default=EngineArgs.override_pooler_config,
+        help="Override the pooler configuration with a JSON string.",
+    )
     model_group.add_argument(
         "--use-warmup",
         type=int,
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 482ccbf7433..4eb71d2608d 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -498,6 +498,9 @@ def _start_worker_service(self):
             f" --load_choices {self.cfg.load_config.load_choices}"
             f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
             f" --ips {ips}"
+            f" --runner {self.cfg.model_config.runner}"
+            f" --convert {self.cfg.model_config.convert}"
+            f" --override-pooler-config '{self.cfg.model_config.override_pooler_config}'"
         )
 
         worker_append_flag = {
diff --git a/fastdeploy/engine/pooling_params.py b/fastdeploy/engine/pooling_params.py
new file mode 100644
index 00000000000..13d3f01e488
--- /dev/null
+++ b/fastdeploy/engine/pooling_params.py
@@ -0,0 +1,170 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +from copy import deepcopy +from typing import TYPE_CHECKING, Annotated, Any, Optional + +import msgspec + +from fastdeploy.engine.sampling_params import RequestOutputKind +from fastdeploy.engine.tasks import PoolingTask + +if TYPE_CHECKING: + from fastdeploy.config import ModelConfig + + +class PoolingParams: + """API parameters for pooling models. + + Attributes: + normalize: Whether to normalize the embeddings outputs. + dimensions: Reduce the dimensions of embeddings + if model support matryoshka representation. + activation: Whether to apply activation function to + the classification outputs. + softmax: Whether to apply softmax to the reward outputs. + step_tag_id: Step tag ID for process reward models to identify + specific steps in multi-step reasoning tasks. + returned_token_ids: List of token IDs to return rewards for, + used for fine-grained reward calculation. + task: Internal use only. Specifies the pooling task type + ("embed" for embeddings, "encode" for reward models). + requires_token_ids: Internal use only. Whether token ID information + is required for processing. + extra_kwargs: Internal use only. Dictionary for storing additional + custom parameters for extended functionality. + output_kind: Output type specification, fixed to FINAL_ONLY + (only final outputs are returned). + """ + + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None + """If set to -1, will use the truncation size supported by the model. If + set to an integer k, will use only the last k tokens from the prompt + (i.e., left truncation). If set to `None`, truncation is disabled.""" + + # for embeddings models + dimensions: Optional[int] = None + normalize: Optional[bool] = None + + # for reward models + softmax: Optional[bool] = None + step_tag_id: Optional[int] = None + returned_token_ids: Optional[list[int]] = None + + task: Optional[PoolingTask] = None + """Internal use only.""" + + requires_token_ids: bool = False + """Internal use only.""" + + extra_kwargs: Optional[dict[str, Any]] = None + """Internal use only.""" + + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + + @property + def _all_parameters(self) -> list[str]: + return ["dimensions", "normalize", "softmax", "step_tag_id", "returned_token_ids"] + + @property + def valid_parameters(self): + return { + "embed": ["dimensions", "normalize"], + "encode": ["softmax", "step_tag_id", "returned_token_ids"], + } + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return deepcopy(self) + + def verify(self, task: PoolingTask, model_config: Optional["ModelConfig"] = None) -> None: + + if self.task is None: + self.task = task + elif self.task != task: + msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" + raise ValueError(msg) + + # NOTE: Task validation needs to done against the model instance, + # which is not available in model config. 
So, it's not included + # in this method + + self._merge_default_parameters(model_config) + self._set_default_parameters(model_config) + self._verify_valid_parameters() + + def _merge_default_parameters(self, model_config: Optional["ModelConfig"] = None) -> None: + + if model_config is None: + return + + pooler_config = model_config.pooler_config + if pooler_config is None: + return + + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + + for k in valid_parameters: + if getattr(pooler_config, k, None) is None: + continue + + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + + def _set_default_parameters(self, model_config: Optional["ModelConfig"]): + if self.task == "embed": + if self.normalize is None: + self.normalize = True + elif self.task == "encode": + if self.softmax is None: + self.softmax = True + else: + raise ValueError(f"Unknown pooling task: {self.task}") + + def _verify_valid_parameters(self): + assert self.task is not None, "task must be set" + valid_parameters = self.valid_parameters[self.task] + invalid_parameters = [] + for k in self._all_parameters: + if k in valid_parameters: + continue + + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"Task {self.task} only supports {valid_parameters} " + f"parameters, does not support " + f"{invalid_parameters} parameters" + ) + + def __repr__(self) -> str: + return ( + f"PoolingParams(" + f"task={self.task}, " + f"normalize={self.normalize}, " + f"dimensions={self.dimensions}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " + f"requires_token_ids={self.requires_token_ids}, " + f"extra_kwargs={self.extra_kwargs})" + ) + + def __post_init__(self) -> None: + assert self.output_kind == RequestOutputKind.FINAL_ONLY, "For pooling output_kind has to be FINAL_ONLY" diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 73231b3be7c..a2fa81e08fe 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -18,6 +18,7 @@ import random from dataclasses import dataclass, fields +from enum import Enum from typing import Any, List, Optional, Union @@ -268,3 +269,12 @@ def __post_init__(self): "You can only use one kind of guided decoding " "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')." ) + + +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOutput + FINAL_ONLY = 2 diff --git a/fastdeploy/engine/tasks.py b/fastdeploy/engine/tasks.py new file mode 100644 index 00000000000..70ff0eec398 --- /dev/null +++ b/fastdeploy/engine/tasks.py @@ -0,0 +1,25 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Literal, get_args + +GenerationTask = Literal["generate"] +GENERATION_TASKS = get_args(GenerationTask) + +PoolingTask = Literal["encode", "embed"] +POOLING_TASKS = get_args(PoolingTask) + +SupportedTask = Literal[GenerationTask, PoolingTask] diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index 29e86c642ce..2eb800de66a 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -146,3 +146,26 @@ def forward_gcu(self, x): if self.bias is not None: out = out + self.bias return out + + +def get_act_fn(act_fn_name: str) -> nn.Layer: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("paddle.nn.Layer"): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + + activation_map = { + "gelu": nn.GELU(), + "relu": nn.ReLU(), + "silu": nn.Silu(), + "tanh": nn.Tanh(), + "sigmoid": nn.Sigmoid(), + } + if act_fn_name in activation_map: + return activation_map[act_fn_name] + else: + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") diff --git a/fastdeploy/model_executor/layers/pool/metadata.py b/fastdeploy/model_executor/layers/pool/metadata.py new file mode 100644 index 00000000000..2dd4d13fe40 --- /dev/null +++ b/fastdeploy/model_executor/layers/pool/metadata.py @@ -0,0 +1,85 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+import paddle
+
+from fastdeploy.engine.pooling_params import PoolingParams
+
+
+@dataclass
+class PoolingCursor:
+    index: list[int]
+    first_token_indices_gpu: paddle.Tensor
+    last_token_indices_gpu: paddle.Tensor
+    prompt_lens_cpu: paddle.Tensor
+    num_scheduled_tokens_cpu: paddle.Tensor
+
+    def __getitem__(self, indices: slice):
+        return PoolingCursor(
+            index=self.index[indices],
+            first_token_indices_gpu=self.first_token_indices_gpu[indices],
+            last_token_indices_gpu=self.last_token_indices_gpu[indices],
+            prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
+        )
+
+    def is_partial_prefill(self):
+        return not paddle.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu).item()
+
+
+@dataclass
+class PoolingMetadata:
+    """Tensors for pooling."""
+
+    prompt_lens: paddle.Tensor  # CPU Tensor
+    prompt_token_ids: Optional[paddle.Tensor]
+    pooling_params: list[PoolingParams]
+    pooling_cursor: Optional[PoolingCursor] = None
+
+    def __getitem__(self, indices: slice):
+        return PoolingMetadata(
+            prompt_lens=self.prompt_lens[indices],
+            prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices],
+            pooling_params=self.pooling_params[indices],
+            pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices],
+        )
+
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: str):
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, self.prompt_lens, device)
+
+
+def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: str):
+    assert len(prompt_lens) == len(num_scheduled_tokens)
+
+    n_seq = len(num_scheduled_tokens)
+    index = list(range(n_seq))
+    num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, place=paddle.CPUPlace())
+    # paddle.to_tensor has no `device=` argument and paddle.cumsum has no `out=`
+    # argument, so build the exclusive prefix sum with a slice assignment instead.
+    cumsum = paddle.zeros([n_seq + 1], dtype="int64")
+    cumsum[1:] = paddle.cumsum(num_scheduled_tokens, axis=0)
+    if device == "gpu":
+        cumsum_device = cumsum.cuda()
+    else:
+        cumsum_device = cumsum
+    return PoolingCursor(
+        index=index,
+        first_token_indices_gpu=cumsum_device[:n_seq],
+        last_token_indices_gpu=cumsum_device[1:] - 1,
+        prompt_lens_cpu=prompt_lens,
+        num_scheduled_tokens_cpu=num_scheduled_tokens,
+    )
diff --git a/fastdeploy/model_executor/layers/pooler.py b/fastdeploy/model_executor/layers/pooler.py
new file mode 100644
index 00000000000..06b18b4672a
--- /dev/null
+++ b/fastdeploy/model_executor/layers/pooler.py
@@ -0,0 +1,550 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Set +from dataclasses import dataclass +from enum import IntEnum +from itertools import groupby +from typing import Callable, Optional, TypeVar, Union, cast + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from fastdeploy.config import FDConfig, ModelConfig, PoolerConfig +from fastdeploy.engine.tasks import PoolingTask +from fastdeploy.model_executor.layers.pool.metadata import ( + PoolingCursor, + PoolingMetadata, + PoolingParams, +) +from fastdeploy.model_executor.models.adapters import _load_st_projector +from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput +from fastdeploy.utils import get_logger + +logger = get_logger("pooler", "pooler.log") + +PoolingFn = Callable[ + [Union[paddle.Tensor, list[paddle.Tensor]], PoolingMetadata], Union[paddle.Tensor, list[paddle.Tensor]] +] +ClassifierFn = Callable[[paddle.Tensor], paddle.Tensor] + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + + LAST = 0 + ALL = 1 + CLS = 2 + STEP = 3 + MEAN = 4 + + +_T = TypeVar("_T", paddle.Tensor, list[paddle.Tensor]) + + +@dataclass(frozen=True) +class ResolvedPoolingConfig: + pooling_type: PoolingType + task: PoolingTask + + @classmethod + def from_config( + cls, + task: PoolingTask, + pooler_config: PoolerConfig, + ) -> "ResolvedPoolingConfig": + assert pooler_config.pooling_type is not None + return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type]) + + +def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: + pooling_params = pooling_metadata.pooling_params + return pooling_params + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + pooling_params = get_pooling_params(pooling_metadata) + + tasks: list[PoolingTask] = [task for pooling_param in pooling_params if (task := pooling_param.task) is not None] + assert len(pooling_params) == len(tasks) + + return tasks + + +def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[paddle.Tensor]: + assert ( + pooling_metadata.prompt_token_ids is not None + ), "Please set `requires_token_ids=True` in `get_pooling_updates`" + + return [pooling_metadata.prompt_token_ids[i, :num] for i, num in enumerate(pooling_metadata.prompt_lens)] + + +@dataclass(frozen=True) +class PoolingParamsUpdate: + requires_token_ids: bool = False + """Set this flag to enable `get_prompt_token_ids` for your pooler.""" + + def apply(self, params: PoolingParams) -> None: + params.requires_token_ids = self.requires_token_ids + + +class Pooler(nn.Layer, ABC): + """The interface required for all poolers used in pooling models in FastDeploy.""" + + @staticmethod + def for_encode(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None): + if pooler_config.pooling_type == "STEP": + return StepPooler() + + resolved_config = ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL) + return SimplePooler.from_config(resolved_config, model_config) + + @staticmethod + def for_embed(pooler_config: PoolerConfig, model_config: Optional["ModelConfig"] = None): + resolved_config = ResolvedPoolingConfig.from_config( + task="embed", + pooler_config=pooler_config, + ) + return SimplePooler.from_config(resolved_config, model_config) + + @staticmethod + def for_classify( + pooler_config: PoolerConfig, + classify: Optional[ClassifierFn], + ): + pass + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine 
which pooling tasks are supported."""
+        raise NotImplementedError
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        """
+        Construct the updated pooling parameters to use for a supported task.
+        """
+        return PoolingParamsUpdate()
+
+    @abstractmethod
+    def forward(
+        self,
+        hidden_states: Union[list[paddle.Tensor], paddle.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        raise NotImplementedError
+
+
+class BasePoolerActivation(nn.Layer, ABC):
+
+    @abstractmethod
+    def forward(self, pooled_data: _T) -> _T:
+        # shape:
+        # classify (& score) -> (batch_size, num_classes)
+        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
+        # (batch_size, dimensions) or list(dimensions) if using MRL
+        raise NotImplementedError
+
+
+class PoolerActivation(BasePoolerActivation):
+
+    @staticmethod
+    def wraps(module: nn.Layer):
+        if isinstance(module, nn.Identity):
+            return PoolerIdentity()
+        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
+            return PoolerClassify()
+
+        return LambdaPoolerActivation(module)
+
+    @abstractmethod
+    def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
+        raise NotImplementedError
+
+    def forward(self, pooled_data: _T) -> _T:
+        if isinstance(pooled_data, list):
+            return [self.forward_chunk(data) for data in pooled_data]
+
+        return self.forward_chunk(pooled_data)
+
+
+class PoolerIdentity(PoolerActivation):
+
+    def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor:
+        return pooled_data
+
+
+class PoolerClassify(PoolerActivation):
+
+    def __init__(self, *, static_num_labels: bool = True) -> None:
+        super().__init__()
+
+        if static_num_labels:
+            fd_config = FDConfig()
+            self.num_labels = getattr(fd_config.model_config, "num_labels", 0)
+            if self.num_labels == 0:
+                logger.warning(
+                    "num_labels should be > 0 for classification "
+                    "models, falling back to softmax. "
+                    "Please check if the configuration is correct."
+ ) + else: + self.num_labels = None + + def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor: + num_labels = self.num_labels if self.num_labels is not None else pooled_data.shape[-1] + if num_labels < 2: + return F.sigmoid(pooled_data.astype("float32")).astype(pooled_data.dtype) + + return F.softmax(pooled_data.astype("float32"), axis=-1).astype(pooled_data.dtype) + + +class LambdaPoolerActivation(PoolerActivation): + + def __init__(self, fn: Callable[[paddle.Tensor], paddle.Tensor]): + super().__init__() + + self.fn = fn + + def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor: + return self.fn(pooled_data) + + +class PoolerHead(nn.Layer): + + def __init__(self, activation: PoolerActivation) -> None: + super().__init__() + self.activation = activation + + def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata): + + return self.activation(pooled_data) + + +class EmbeddingPoolerHead(PoolerHead): + + def __init__(self, model_config: Optional["ModelConfig"] = None) -> None: + super().__init__(activation=PoolerNormalize()) + + self.projector = _load_st_projector(model_config) + + def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata): + + if isinstance(pooled_data, list): + pooled_data = paddle.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + projector = cast(nn.Layer, self.projector) + + def _proj(x: paddle.Tensor) -> paddle.Tensor: + orig_dtype = x.dtype + y = projector(x.astype("float32")) + return y.astype(orig_dtype) + + pooled_data = _proj(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] + + pooling_params = get_pooling_params(pooling_metadata) + + # for matryoshka representation + dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] + if any(d is not None for d in dimensions_list): + # change the output dimension + assert len(pooled_data) == len(dimensions_list) + if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list): + # if all dimensions are the same + d = dimensions_list[0] + pooled_data = pooled_data[..., :d] + else: + pooled_data = [vecs if d is None else vecs[..., :d] for vecs, d in zip(pooled_data, dimensions_list)] + # for normalize + flags = [p.normalize for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)] + + # pooled_data shape: [batchsize, embedding_dimension] + return pooled_data + + +class RewardPoolerHead(PoolerHead): + + def __init__(self, model_config: Optional["ModelConfig"] = None) -> None: + super().__init__(activation=PoolerClassify(static_num_labels=False)) + self.model_config = model_config + + def forward(self, pooled_data: Union[list[paddle.Tensor], paddle.Tensor], pooling_metadata: PoolingMetadata): + pooling_params = get_pooling_params(pooling_metadata) + + # for softmax + flags = [p.softmax for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [self.activation(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)] + + return pooled_data + + +def build_output( + all_data: Union[paddle.Tensor, list[paddle.Tensor]], +) -> PoolerOutput: + # Pooling models D2H & synchronize occurs here + if isinstance(all_data, list): + all_data = 
[d.cpu() for d in all_data] + else: + all_data = all_data.cpu() + + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] + return PoolerOutput(outputs=all_outputs) + + +class PoolingMethod(nn.Layer, ABC): + + @staticmethod + def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": + if pooling_type == PoolingType.LAST: + return LastPool() + if pooling_type == PoolingType.ALL: + return AllPool() + if pooling_type == PoolingType.CLS: + return CLSPool() + if pooling_type == PoolingType.MEAN: + return MeanPool() + raise NotImplementedError(f"Unsupported method: {pooling_type}") + + +class LastPool(PoolingMethod): + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} + + def forward_all( + self, + hidden_states: paddle.Tensor, + pooling_cursor: PoolingCursor, + ) -> Union[list[paddle.Tensor], paddle.Tensor]: + return hidden_states[pooling_cursor.last_token_indices_gpu] + + +class AllPool(PoolingMethod): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} + + def forward_all( + self, + hidden_states: paddle.Tensor, + pooling_cursor: PoolingCursor, + ) -> Union[list[paddle.Tensor], paddle.Tensor]: + + assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with ALL pooling" + + hidden_states_lst = list(hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())) + return [hidden_states_lst[i] for i in pooling_cursor.index] + + +class MeanPool(PoolingMethod): + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} + + def forward_all( + self, + hidden_states: paddle.Tensor, + pooling_cursor: PoolingCursor, + ) -> Union[list[paddle.Tensor], paddle.Tensor]: + + assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with MEAN pooling" + + if hidden_states.place.is_gpu_place(): + prompt_lens = pooling_cursor.prompt_lens_cpu.cuda() + else: + prompt_lens = pooling_cursor.prompt_lens_cpu + + # Use float32 for paddle.cumsum in MeanPool, + # otherwise precision will be lost significantly. 
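+        # Arithmetic note for the segment reduction below: cumsum[i] includes
+        # hidden_states[i] itself, so for a prompt spanning [s, e],
+        #   sum(h[s : e + 1]) == cumsum[e] - cumsum[s] + h[s]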
+ cumsum = paddle.cumsum(hidden_states.astype("float32"), axis=0) + + start_indices = pooling_cursor.first_token_indices_gpu + end_indices = pooling_cursor.last_token_indices_gpu + return (cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + + +class CLSPool(PoolingMethod): + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} + + def forward_all( + self, + hidden_states: paddle.Tensor, + pooling_cursor: PoolingCursor, + ) -> Union[list[paddle.Tensor], paddle.Tensor]: + assert not pooling_cursor.is_partial_prefill(), "partial prefill not supported with CLS pooling" + + return hidden_states[pooling_cursor.first_token_indices_gpu] + + +class StepPooler(Pooler): + def __init__( + self, + ) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = RewardPoolerHead() + + def extract_states( + self, + hidden_states: Union[paddle.Tensor, list[paddle.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[paddle.Tensor], paddle.Tensor]: + pooled_data_lst = self.pooling(hidden_states, pooling_metadata) + prompt_token_ids = get_prompt_token_ids(pooling_metadata) + + pooled_data = list[paddle.Tensor]() + + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip(pooled_data_lst, prompt_token_ids, pooling_params): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: Union[paddle.Tensor, list[paddle.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) + + +class SimplePooler(Pooler): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. 
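+
+    Example (sketch; assumes `model_config` is the loaded ModelConfig):
+        resolved = ResolvedPoolingConfig(pooling_type=PoolingType.LAST, task="embed")
+        pooler = SimplePooler.from_config(resolved, model_config)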
+ """ + + @classmethod + def from_config( + cls, + pooler_config: ResolvedPoolingConfig, + model_config: Optional["ModelConfig"] = None, + ) -> "SimplePooler": + pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) + if pooler_config.task == "embed": + head = EmbeddingPoolerHead(model_config) + elif pooler_config.task == "encode": + head = RewardPoolerHead(model_config) + else: + raise NotImplementedError(f"Unknown task: {pooler_config.task}") + return cls(pooling, head) + + def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: + super().__init__() + + self.pooling = pooling + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.pooling.get_pooling_updates(task) + + def forward( + self, + hidden_states: Union[paddle.Tensor, list[paddle.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) + + +class PoolerNormalize(PoolerActivation): + def forward_chunk(self, pooled_data: paddle.Tensor) -> paddle.Tensor: + x = F.normalize(pooled_data.astype("float32"), p=2, axis=-1) + return x.astype(pooled_data.dtype) + + +class DispatchPooler(Pooler): + """Dispatches calls to a sub-pooler based on the pooling task.""" + + def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: + super().__init__() + + for task, pooler in poolers_by_task.items(): + if task not in pooler.get_supported_tasks(): + raise ValueError( + f"{pooler=} does not support {task=}. " f"Supported tasks: {pooler.get_supported_tasks()}" + ) + + self.poolers_by_task = poolers_by_task + + def get_supported_tasks(self) -> Set[PoolingTask]: + return set(self.poolers_by_task) + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.poolers_by_task[task].get_pooling_updates(task) + + def forward( + self, + hidden_states: Union[paddle.Tensor, list[paddle.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + poolers_by_task = self.poolers_by_task + + outputs = list[PoolingSequenceGroupOutput]() + offset = 0 + for task, group in groupby(get_tasks(pooling_metadata)): + if not (pooler := poolers_by_task.get(task)): + raise ValueError(f"Unsupported task: {task} " f"Supported tasks: {self.get_supported_tasks()}") + + num_items = len(list(group)) + group_output: PoolerOutput = pooler( + hidden_states, + pooling_metadata[offset : offset + num_items], + ) + outputs.extend(group_output.outputs) + offset += num_items + + return PoolerOutput(outputs) diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py index 6ebb142532d..4ec3f139126 100644 --- a/fastdeploy/model_executor/model_loader/default_loader.py +++ b/fastdeploy/model_executor/model_loader/default_loader.py @@ -61,6 +61,7 @@ def load_weights(self, model, fd_config: FDConfig, architectures: str) -> None: fd_config, return_numpy=True, ) + model.set_state_dict(state_dict) self.clean_memory_fragments(state_dict) diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index f6ecb43f77e..3e700ca2747 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ 
-16,7 +16,7 @@ import paddle from paddle import nn -from paddleformers.utils.log import logger +from typing_extensions import assert_never from fastdeploy.config import FDConfig, LoadConfig, ModelConfig from fastdeploy.model_executor.load_weight_utils import ( @@ -27,6 +27,7 @@ save_model, ) from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader +from fastdeploy.model_executor.models.adapters import as_embedding_model from fastdeploy.model_executor.models.model_base import ModelRegistry from fastdeploy.platforms import current_platform @@ -54,11 +55,11 @@ def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) - load_weights_form_cache(model, weights_iterator) else: model.load_weights(weights_iterator) + self.clean_memory_fragments() def load_model(self, fd_config: FDConfig) -> nn.Layer: architectures = fd_config.model_config.architectures[0] - logger.info(f"Starting to load model {architectures}") context = paddle.LazyGuard() if fd_config.load_config.dynamic_load_weight: # register rl model @@ -70,6 +71,14 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer: with weight_cache_context: with context: model_cls = ModelRegistry.get_class(architectures) + convert_type = fd_config.model_config.convert_type + if convert_type == "none": + pass + elif convert_type == "embed": + model_cls = as_embedding_model(model_cls) + else: + assert_never(convert_type) + model = model_cls(fd_config) model.eval() diff --git a/fastdeploy/model_executor/models/__init__.py b/fastdeploy/model_executor/models/__init__.py index e96d65b1851..9ac761d2f68 100644 --- a/fastdeploy/model_executor/models/__init__.py +++ b/fastdeploy/model_executor/models/__init__.py @@ -47,8 +47,10 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode module = importlib.import_module(f"{register_path}.{module_file}") for attr_name in dir(module): attr = getattr(module, attr_name) + if inspect.isclass(attr) and issubclass(attr, ModelForCasualLM) and attr is not ModelForCasualLM: ModelRegistry.register_model_class(attr) + if ( inspect.isclass(attr) and issubclass(attr, PretrainedModel) @@ -56,6 +58,7 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode and hasattr(attr, "arch_name") ): ModelRegistry.register_pretrained_model(attr) + except ImportError: raise ImportError(f"{module_file=} import error") diff --git a/fastdeploy/model_executor/models/adapters.py b/fastdeploy/model_executor/models/adapters.py new file mode 100644 index 00000000000..d56c1dcb1f4 --- /dev/null +++ b/fastdeploy/model_executor/models/adapters.py @@ -0,0 +1,214 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from collections.abc import Iterable +from typing import Optional, TypeVar + +import paddle +import paddle.nn as nn + +from fastdeploy.config import ModelConfig +from fastdeploy.model_executor.layers.activation import get_act_fn +from fastdeploy.model_executor.models.interfaces_base import is_pooling_model +from fastdeploy.transformer_utils.config import get_hf_file_to_dict + +_T = TypeVar("_T", bound=type[nn.Layer]) + +_GENERATE_SUFFIXES = [ + "ForCausalLM", + "ForConditionalGeneration", + "ChatModel", + "LMHeadModel", +] + + +def _load_dense_weights(linear: nn.Linear, folder: str, model_config: "ModelConfig") -> bool: + """Load weights using vLLM's weight_loader pattern.""" + + from fastdeploy.model_executor.utils import default_weight_loader + + filename = "model.safetensors" + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_to_dict(file_path, model_config.model, model_config.revision) + if not file_bytes: + return False + + state_dict = {} + if filename.endswith(".safetensors"): + import io + + from safetensors.numpy import load as load_safetensors + + numpy_tensors = load_safetensors(io.BytesIO(file_bytes)) + for key, numpy_array in numpy_tensors.items(): + state_dict[key] = paddle.to_tensor(numpy_array) + else: + import io + + state_dict = paddle.load(io.BytesIO(file_bytes)) + + weight_keys = ["weight", "linear.weight", "dense.weight"] + + for weight_key in weight_keys: + if weight_key in state_dict: + weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader) + weight_loader(linear.weight, state_dict[weight_key].astype(paddle.float32)) + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader) + bias_loader(linear.bias, state_dict[bias_key].astype(paddle.float32)) + return True + except Exception as e: + print(f"Failed to load :{e}") + return False + return False + + +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Layer]: + try: + modules = get_hf_file_to_dict("modules.json", model_config.model, model_config.revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [m for m in modules if m.get("type") == "sentence_transformers.models.Dense"] + if not dense_modules: + return None + + layers = [] + for module in dense_modules: + folder = module.get("path", "") + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model_config.model, model_config.revision) + if not layer_config: + continue + linear = nn.Linear( + layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + ) + linear = linear.astype(paddle.float32) + + if not _load_dense_weights(linear, folder, model_config): + continue + + layers.append(linear) + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).astype(paddle.float32) + except Exception as e: + print(f"ST projector loading failed:{e}") + + return None + + +def _create_pooling_model_cls(orig_cls: _T) -> _T: + + class ModelForPooling(orig_cls): + + def __init__(self, fd_config, *args, **kwargs): + super().__init__(fd_config, *args, **kwargs) + self.fd_config = fd_config + self.is_pooling_model = True + + # These are not used in pooling models + for attr in ("lm_head", 
"logits_processor"): + if hasattr(self, attr): + delattr(self, attr) + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "pooler", None): + self._init_pooler(fd_config) + + def _init_pooler(self, fd_config): + raise NotImplementedError + + def load_weights(self, weights: Iterable[tuple[str, paddle.Tensor]]): + # TODO: Support uninitialized params tracking + + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights if not name.startswith("lm_head.")) + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or not any(child.parameters()) for name, child in self.named_children() + ) + if model_is_only_param: + weights = ((name[6:], data) for name, data in weights if name.startswith("model.")) + loaded_params = self.model.load_weights(weights) + loaded_params = {f"model.{name}" for name in loaded_params} + return loaded_params + + # For most other models + if hasattr(orig_cls, "load_weights"): + return orig_cls.load_weights(self, weights) # type: ignore + # Fallback + else: + raise ValueError("No load_weights method found in the model.") + + return ModelForPooling + + +def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: + model_name = orig_model_name + + for generate_suffix in _GENERATE_SUFFIXES: + model_name = model_name.removesuffix(generate_suffix) + return model_name + pooling_suffix + + +def as_embedding_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support embeddings. + + By default, the embeddings of the whole prompt are extracted from the + normalized hidden state corresponding to the last token. + + Note: + We assume that no extra layers are added to the original model; + please implement your own model if this is not the case. 
+ """ + # Avoid modifying existing embedding models + if is_pooling_model(cls): + return cls + + from fastdeploy.model_executor.layers.pooler import DispatchPooler, Pooler + + class ModelForEmbedding(_create_pooling_model_cls(cls)): + + def _init_pooler(self, fd_config, prefix: str = ""): + pooler_config = fd_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config, fd_config.model_config), + "embed": Pooler.for_embed(pooler_config, fd_config.model_config), + }, + ) + + ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") + + return ModelForEmbedding diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 07ca23e5d25..1c79b381f77 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -48,7 +48,11 @@ from fastdeploy.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, ) -from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) from fastdeploy.platforms import current_platform if current_platform.is_cuda(): @@ -588,6 +592,12 @@ def forward( return out +@ModelRegistry.register_model_class( + architecture="DeepseekV3ForCausalLM", + module_path="deepseek_v3", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class DeepseekV3ForCausalLM(ModelForCasualLM): """ DeepseekV3ForCausalLM diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 24a4ce5d046..b4ccfe6398d 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -45,7 +45,11 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm -from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import WeightMeta @@ -471,6 +475,12 @@ def forward( return out +@ModelRegistry.register_model_class( + architecture="Ernie4_5_MoeForCausalLM", + module_path="ernie4_5_moe", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class Ernie4_5_MoeForCausalLM(ModelForCasualLM): """ Ernie4_5_MoeForCausalLM @@ -646,6 +656,12 @@ def clear_grpah_opt_backend(self): self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) +@ModelRegistry.register_model_class( + architecture="Ernie4_5_ForCausalLM", + module_path="ernie4_5_moe", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM): """ Ernie4_5_ForCausalLM @@ -659,6 +675,12 @@ def name(self): return "Ernie4_5_ForCausalLM" +@ModelRegistry.register_model_class( + architecture="Ernie4_5ForCausalLM", + module_path="ernie4_5_moe", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM): 
""" Ernie4_5ForCausalLM 0.3B-PT diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index a722b2e5635..f5b23727401 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -31,7 +31,11 @@ from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer -from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) class Ernie4_5_MTPPretrainedModel(PretrainedModel): @@ -325,6 +329,12 @@ def forward( return hidden_states +@ModelRegistry.register_model_class( + architecture="Ernie4_5_MTPForCausalLM", + module_path="ernie4_5_mtp", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class Ernie4_5_MTPForCausalLM(ModelForCasualLM): """ Ernie4_5_MTPForCausalLM diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index a36f9d4db2f..fc71b9daf89 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -44,7 +44,11 @@ Ernie4_5_Attention, Ernie4_5_MLP, ) -from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) from fastdeploy.platforms import current_platform if current_platform.is_cuda(): @@ -792,6 +796,12 @@ def clear_grpah_opt_backend(self): self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) +@ModelRegistry.register_model_class( + architecture="Ernie4_5_VLMoeForConditionalGeneration", + module_path="ernie4_5_vl.ernie4_5_vl_moe", + category=ModelCategory.MULTIMODAL, + primary_use=ModelCategory.MULTIMODAL, +) class Ernie4_5_VLPretrainedModel(PretrainedModel): """ Ernie4_5_MoePretrainedModel diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index fdbf277afb8..2029885af52 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -39,7 +39,11 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm -from fastdeploy.model_executor.models.model_base import ModelForCasualLM +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) class Glm4MoeMLP(nn.Layer): @@ -363,6 +367,12 @@ def forward( return out +@ModelRegistry.register_model_class( + architecture="Glm4MoeForCausalLM", + module_path="glm4_moe", + category=ModelCategory.TEXT_GENERATION, + primary_use=ModelCategory.TEXT_GENERATION, +) class Glm4MoeForCausalLM(ModelForCasualLM): """ Glm4MoeForCausalLM diff --git a/fastdeploy/model_executor/models/interfaces_base.py b/fastdeploy/model_executor/models/interfaces_base.py new file mode 100644 index 00000000000..b7ece5fe69a --- /dev/null +++ b/fastdeploy/model_executor/models/interfaces_base.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
diff --git a/fastdeploy/model_executor/models/interfaces_base.py b/fastdeploy/model_executor/models/interfaces_base.py
new file mode 100644
index 00000000000..b7ece5fe69a
--- /dev/null
+++ b/fastdeploy/model_executor/models/interfaces_base.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Type
+
+from paddle import nn
+
+
+def is_text_generation_model(model_cls: Type[nn.Layer]) -> bool:
+    from .model_base import ModelForCasualLM
+
+    return issubclass(model_cls, ModelForCasualLM)
+
+
+def is_pooling_model(model_cls: Type[nn.Layer]) -> bool:
+    class_name = model_cls.__name__
+    pooling_indicators = ["Embedding", "ForSequenceClassification"]
+    return (
+        any(indicator in class_name for indicator in pooling_indicators)
+        or hasattr(model_cls, "is_embedding_model")
+        and model_cls.is_embedding_model
+    )
+
+
+def is_multimodal_model(class_name: str) -> bool:
+    multimodal_indicators = ["VL", "Vision", "ConditionalGeneration"]
+    return any(indicator in class_name for indicator in multimodal_indicators)
+
+
+def determine_model_category(class_name: str):
+    from fastdeploy.model_executor.models.model_base import ModelCategory
+
+    if any(pattern in class_name for pattern in ["VL", "Vision", "ConditionalGeneration"]):
+        return ModelCategory.MULTIMODAL
+    elif any(pattern in class_name for pattern in ["Embedding", "ForSequenceClassification"]):
+        return ModelCategory.EMBEDDING
+    return ModelCategory.TEXT_GENERATION
+
+
+def get_default_pooling_type(model_cls: Optional[Type[nn.Layer]] = None) -> str:
+    if model_cls is not None:
+        return getattr(model_cls, "default_pooling_type", "LAST")
+    return "LAST"
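These classification helpers are pure string/name heuristics, so they are easy to exercise in isolation:

```python
# The name-based heuristics above in action.
from fastdeploy.model_executor.models.interfaces_base import (
    determine_model_category,
    is_multimodal_model,
)

print(is_multimodal_model("Ernie4_5_VLMoeForConditionalGeneration"))  # True ("VL" indicator)
print(determine_model_category("Qwen3EmbeddingModel").value)          # "embedding"
print(determine_model_category("Qwen3ForCausalLM").value)             # "text_generation"
```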
diff --git a/fastdeploy/model_executor/models/model_base.py b/fastdeploy/model_executor/models/model_base.py
index 06f0d0705d7..627d9050b11 100644
--- a/fastdeploy/model_executor/models/model_base.py
+++ b/fastdeploy/model_executor/models/model_base.py
@@ -3,40 +3,269 @@
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
+
@@ -12,31 +11,265 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+import importlib
 from abc import ABC, abstractmethod
-from typing import Dict, Union
+from dataclasses import dataclass
+from enum import Enum
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import paddle
 from paddle import nn
 from paddleformers.transformers import PretrainedModel
 
+from fastdeploy.config import (
+    ModelConfig,
+    iter_architecture_defaults,
+    try_match_architecture_defaults,
+)
+from fastdeploy.model_executor.models.interfaces_base import (
+    determine_model_category,
+    get_default_pooling_type,
+    is_multimodal_model,
+    is_pooling_model,
+    is_text_generation_model,
+)
 
-class ModelRegistry:
-    """
-    Used to register and retrieve model classes.
-    """
+class ModelCategory(Enum):
+    TEXT_GENERATION = "text_generation"
+    MULTIMODAL = "multimodal"
+    EMBEDDING = "embedding"
+
+
+@dataclass(frozen=True)
+class ModelInfo:
+    architecture: str
+    category: ModelCategory
+    is_text_generation: bool
+    is_multimodal: bool
+    is_pooling: bool
+    module_path: str
+    default_pooling_type: str
+
+    @staticmethod
+    def from_model_cls(model_cls: Type[nn.Layer], module_path: str = "") -> "ModelInfo":
+        return ModelInfo(
+            architecture=model_cls.__name__,
+            category=determine_model_category(model_cls.__name__),
+            is_text_generation=is_text_generation_model(model_cls),
+            is_multimodal=is_multimodal_model(model_cls.__name__),
+            is_pooling=is_pooling_model(model_cls),
+            default_pooling_type=get_default_pooling_type(model_cls),
+            module_path=module_path,
+        )
+
+
+class BaseRegisteredModel(ABC):
+    """Base class for registered models"""
+
+    @abstractmethod
+    def load_model_cls(self) -> Type[nn.Layer]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def inspect_model_cls(self) -> ModelInfo:
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class LazyRegisteredModel(BaseRegisteredModel):
+    """Lazy loaded model"""
+
+    module_name: str
+    class_name: str
+
+    def load_model_cls(self) -> Type[nn.Layer]:
+        try:
+            full_module = f"fastdeploy.model_executor.models.{self.module_name}"
+            module = importlib.import_module(full_module)
+            return getattr(module, self.class_name)
+        except (ImportError, AttributeError) as e:
+            raise ImportError(f"Failed to load {self.class_name}: {e}") from e
+
+    def inspect_model_cls(self) -> ModelInfo:
+        model_cls = self.load_model_cls()
+        return ModelInfo.from_model_cls(model_cls, self.module_name)
+
+
+@dataclass(frozen=True)
+class RegisteredModel(BaseRegisteredModel):
+
+    model_cls: Type[nn.Layer]
+
+    def load_model_cls(self) -> Type[nn.Layer]:
+        return self.model_cls
+
+    def inspect_model_cls(self) -> ModelInfo:
+        return ModelInfo.from_model_cls(self.model_cls)
+
+
+@lru_cache(maxsize=128)
+def _try_inspect_model_cls(
+    model_arch: str,
+    model: BaseRegisteredModel,
+) -> Optional[ModelInfo]:
+    try:
+        return model.inspect_model_cls()
+    except Exception:
+        print(f"Error in inspecting model architecture '{model_arch}'")
+        return None
+
+
+class ModelRegistry:
 
     _arch_to_model_cls = {}
     _arch_to_pretrained_model_cls = {}
+    _enhanced_models: Dict[str, Dict] = {}
+
+    def __init__(self):
+        self.models: Dict[str, BaseRegisteredModel] = {}
+        self.pretrained_models: Dict[str, Type[PretrainedModel]] = {}
+        self._registered_models: Dict[str, BaseRegisteredModel] = {}
+        self._register_enhanced_models()
+
+    def _register_enhanced_models(self):
+        for arch, model_info in self._enhanced_models.items():
+            model = LazyRegisteredModel(module_name=model_info["module_path"], class_name=model_info["class_name"])
+            self.models[arch] = model
+            self._registered_models[arch] = model
+
+    @lru_cache(maxsize=128)
+    def _try_load_model_cls(self, architecture: str) -> Optional[Type[nn.Layer]]:
+        if architecture not in self.models:
+            return None
+        try:
+            return self.models[architecture].load_model_cls()
+        except Exception as e:
+            print(f"Failed to load model {architecture}: {e}")
+            return None
+
+    @lru_cache(maxsize=128)
+    def _try_inspect_model_cls(self, model_arch: str) -> Optional[ModelInfo]:
+        if model_arch not in self.models:
+            return None
+        try:
+            return self.models[model_arch].inspect_model_cls()
+        except Exception as e:
+            print(f"Failed to inspect model {model_arch}: {e}")
+            return None
+    def _normalize_arch(self, architecture: str, model_config: Optional[ModelConfig]) -> str:
+        if architecture in self.models:
+            return architecture
+
+        match = try_match_architecture_defaults(
+            architecture,
+            runner_type=getattr(model_config, "runner_type", None),
+            convert_type=getattr(model_config, "convert_type", None),
+        )
+        if match:
+            suffix, _ = match
+            for repl_suffix, _ in iter_architecture_defaults():
+                base_arch = architecture.replace(suffix, repl_suffix)
+                if base_arch in self.models:
+                    return base_arch
+
+        return architecture
+
+    def _raise_for_unsupported(self, architectures: List[str]):
+        all_supported_archs = self.get_supported_archs()
+
+        if any(arch in all_supported_archs for arch in architectures):
+            raise ValueError(
+                f"Model architectures {architectures} failed to be inspected. "
+                "Please check the logs for more details."
+            )
+
+        raise ValueError(
+            f"Model architectures {architectures} are not supported for now. "
+            f"Supported architectures: {all_supported_archs}"
+        )
+
+    def inspect_model_cls(
+        self, architectures: Union[str, List[str]], model_config: Optional[ModelConfig] = None
+    ) -> Tuple[ModelInfo, str]:
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        if not architectures:
+            raise ValueError("No model architectures are specified")
+
+        for arch in architectures:
+            normalized_arch = self._normalize_arch(arch, model_config)
+            model_info = self._try_inspect_model_cls(normalized_arch)
+            if model_info is not None:
+                return (model_info, arch)
+
+        return self._raise_for_unsupported(architectures)
 
     @classmethod
-    def register_model_class(cls, model_class):
-        """register model class"""
-        if issubclass(model_class, ModelForCasualLM) and model_class is not ModelForCasualLM:
-            cls._arch_to_model_cls[model_class.name()] = model_class
-        return model_class
+    def register_model_class(
+        cls,
+        model_class=None,
+        *,
+        architecture: Optional[str] = None,
+        module_path: Optional[str] = None,
+        category: Union[ModelCategory, List[ModelCategory]] = ModelCategory.TEXT_GENERATION,
+        primary_use: Optional[ModelCategory] = None,
+    ):
+        """
+        Enhanced model class registration supporting both traditional and decorator-style registration.
+
+        Can be used as:
+        1. Traditional decorator: @ModelRegistry.register_model_class
+        2. Enhanced decorator with metadata: @ModelRegistry.register_model_class(architecture="...", module_path="...")
+
+        Args:
+            model_class: The model class (when used as simple decorator)
+            architecture (str): Unique identifier for the model architecture
+            module_path (str): Relative path to the module containing the model
+            category: Model category or list of categories
+            primary_use: Primary category for multi-category models
+        """
+
+        def _register(model_cls):
+            # Traditional registration for ModelForCasualLM subclasses
+            if issubclass(model_cls, ModelForCasualLM) and model_cls is not ModelForCasualLM:
+                cls._arch_to_model_cls[model_cls.name()] = model_cls
+
+            # Enhanced decorator-style registration
+            if architecture and module_path:
+                categories = category if isinstance(category, list) else [category]
+
+                # Register main entry
+                arch_key = architecture
+                cls._enhanced_models[arch_key] = {
+                    "class_name": model_cls.__name__,
+                    "module_path": module_path,
+                    "category": primary_use or categories[0],
+                    "class": model_cls,
+                }
+
+                # Register category-specific entries for multi-category models
+                if len(categories) > 1:
+                    for cat in categories:
+                        key = f"{arch_key}_{cat.value}"
+                        cls._enhanced_models[key] = {
+                            "class_name": model_cls.__name__,
+                            "module_path": module_path,
+                            "category": cat,
+                            "primary_use": primary_use or categories[0],
+                            "class": model_cls,
+                        }
+            return model_cls
+
+        if model_class is not None:
+            return _register(model_class)
+        else:
+            return _register
 
     @classmethod
     def register_pretrained_model(cls, pretrained_model):
@@ -50,11 +279,6 @@ def register_pretrained_model(cls, pretrained_model):
 
         return pretrained_model
 
-    @classmethod
-    def get_pretrain_cls(cls, architectures: str):
-        """get_pretrain_cls"""
-        return cls._arch_to_pretrained_model_cls[architectures]
-
     @classmethod
     def get_class(cls, name):
         """get model class"""
@@ -62,12 +286,61 @@ def get_class(cls, name):
             raise ValueError(f"Model '{name}' is not registered!")
         return cls._arch_to_model_cls[name]
 
+    @classmethod
+    def get_pretrain_cls(cls, architectures: str):
+        """get_pretrain_cls"""
+        return cls._arch_to_pretrained_model_cls[architectures]
+
     @classmethod
     def get_supported_archs(cls):
-        assert len(cls._arch_to_model_cls) >= len(
-            cls._arch_to_pretrained_model_cls
-        ), "model class num is more than pretrained model registry num"
-        return [key for key in cls._arch_to_model_cls.keys()]
+        traditional_archs = list(cls._arch_to_model_cls.keys())
+        enhanced_archs = list(cls._enhanced_models.keys())
+        return traditional_archs + enhanced_archs
+
+    def resolve_model_cls(self, architectures: Union[str, List[str]]) -> Tuple[Type[nn.Layer], str]:
+        """Resolve model class"""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_cls = self._try_load_model_cls(arch)
+            if model_cls is not None:
+                return model_cls, arch
+
+        raise ValueError(f"Cannot find supported model: {architectures}")
+
+    def is_multimodal_model(self, architectures: Union[str, List[str]], model_config: Optional[ModelConfig] = None) -> bool:
+        """Check if it's a multimodal model"""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_multimodal
+        return False
+    def is_text_generation_model(self, architectures: Union[str, List[str]], model_config: Optional[ModelConfig] = None) -> bool:
+        """Check if it's a text generation model"""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_text_generation
+        return False
+
+    def is_pooling_model(self, architectures: Union[str, List[str]], model_config: Optional[ModelConfig] = None) -> bool:
+        """Check if it's a pooling model"""
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info.is_pooling
+        return False
 
 
 class ModelForCasualLM(nn.Layer, ABC):
@@ -88,7 +361,6 @@ def __init__(self, configs):
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
         """
         Load model parameters from a given state dictionary.
-
         Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                A dictionary containing model parameters, where keys are parameter names
@@ -105,12 +377,10 @@ def forward(
     ):
         """
         Defines the forward pass of the model for generating text.
-
        Args:
            input_ids (Tensor, optional): The input token ids to the model.
            pos_emb (Tensor, optional): position Embeddings for model.
            **model_kwargs: Additional keyword arguments for the model.
-
        Returns:
            Tensor or list of Tensors: Generated tokens or decoded outputs.
        """
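A short sketch of how the new instance-level registry API is meant to be queried. What it returns depends on which model modules have been imported (and therefore registered) at call time:

```python
# Querying the lazy registry (assumes the Qwen3 model module is importable).
from fastdeploy.model_executor.models.model_base import ModelRegistry

registry = ModelRegistry()
info, arch = registry.inspect_model_cls("Qwen3ForCausalLM")
print(arch, info.category, info.default_pooling_type)

model_cls, resolved = registry.resolve_model_cls("Qwen3ForCausalLM")
print(registry.is_text_generation_model(resolved))  # True for a causal-LM entry
```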
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index e86533f9440..1f95bdc4a6d 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -39,7 +39,11 @@
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 
 
 class Qwen2MLP(nn.Layer):
@@ -282,6 +286,12 @@ def forward(
         return out
 
 
+@ModelRegistry.register_model_class(
+    architecture="Qwen2ForCausalLM",
+    module_path="qwen2",
+    category=[ModelCategory.TEXT_GENERATION, ModelCategory.EMBEDDING],
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen2ForCausalLM(ModelForCasualLM):
     """
     Qwen2ForCausalLM
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
index 5de437ef630..8992b1c18fc 100644
--- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
@@ -33,7 +33,11 @@
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer
 from fastdeploy.platforms import current_platform
 
@@ -157,6 +161,12 @@ def forward(
         return out
 
 
+@ModelRegistry.register_model_class(
+    architecture="Qwen2_5_VLForConditionalGeneration",
+    module_path="qwen2_5_vl.qwen2_5_vl",
+    category=ModelCategory.MULTIMODAL,
+    primary_use=ModelCategory.MULTIMODAL,
+)
 class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
     """
     Qwen2_5_VLForConditionalGeneration
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index d0c532cfc5e..47ed104babf 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -34,8 +34,13 @@
 from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP
+from fastdeploy.transformer_utils.config import get_pooling_config
 
 
 class Qwen3MLP(Qwen2MLP):
@@ -218,6 +223,12 @@ def forward(
         return out
 
 
+@ModelRegistry.register_model_class(
+    architecture="Qwen3ForCausalLM",
+    module_path="qwen3",
+    category=[ModelCategory.TEXT_GENERATION],
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen3ForCausalLM(ModelForCasualLM):
     """
     Qwen3ForCausalLM
@@ -260,6 +271,8 @@ def load_weights(self, weights_iterator) -> None:
             process_weights_after_loading,
         )
 
+        is_pooling_model = hasattr(self, "is_pooling_model") and self.is_pooling_model
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -270,8 +283,18 @@
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
         ]
+
         params_dict = dict(self.named_parameters())
+        model_path = self.fd_config.model_config.model
+        revision = self.fd_config.model_config.revision
+        if is_pooling_model and get_pooling_config(model_path, revision):
+            params_dict = {
+                param_name[6:] if param_name.startswith("model.") else param_name: param
+                for param_name, param in params_dict.items()
+            }
+
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
+
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
@@ -282,6 +305,7 @@
                 param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight, shard_id)
+                break
             else:
                 model_param_name = loaded_weight_name
@@ -290,10 +314,11 @@
                 param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight)
+
             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
 
-        if self.tie_word_embeddings:
+        if self.tie_word_embeddings and not is_pooling_model:
             self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
 
     @paddle.no_grad()
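The pooling branch in `load_weights` normalizes checkpoint keys because sentence-transformers-style pooling checkpoints store parameters without the leading `model.` scope. The dict comprehension above reduces to this self-contained behavior (toy values only):

```python
# Illustration of the "model." prefix stripping used for pooling checkpoints.
params = {"model.layers.0.mlp.weight": 1, "lm_head.weight": 2}
normalized = {k[6:] if k.startswith("model.") else k: v for k, v in params.items()}
print(normalized)  # {'layers.0.mlp.weight': 1, 'lm_head.weight': 2}
```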
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index e28b5748f7e..bc270e1267b 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -39,7 +39,11 @@
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import (
+    ModelCategory,
+    ModelForCasualLM,
+    ModelRegistry,
+)
 from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
 
@@ -316,6 +320,12 @@ def forward(
         return out
 
 
+@ModelRegistry.register_model_class(
+    architecture="Qwen3MoeForCausalLM",
+    module_path="qwen3moe",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
 class Qwen3MoeForCausalLM(ModelForCasualLM):
     """
     Qwen3MoeForCausalLM
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 0b226bf7b29..754725691e1 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -158,6 +158,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
 
     def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
         """fn"""
+        output_dim = getattr(param, "output_dim", None)
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if weight_need_transpose:
diff --git a/fastdeploy/output/pooler.py b/fastdeploy/output/pooler.py
new file mode 100644
index 00000000000..86bdbce98c4
--- /dev/null
+++ b/fastdeploy/output/pooler.py
@@ -0,0 +1,69 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from typing import Any
+
+import msgspec
+import paddle
+
+
+class PoolingSequenceGroupOutput(
+    msgspec.Struct,
+    omit_defaults=True,
+    array_like=True,
+):
+    """The model output associated with a pooling sequence group."""
+
+    # Annotated as Any to be compatible with msgspec
+    # The actual type is in SequenceGroup.pooled_data
+    data: Any
+
+    def get_data_nbytes(self) -> int:
+        if isinstance(self.data, paddle.Tensor):
+            return self.data.numel() * self.data.element_size()
+        elif hasattr(self.data, "nbytes"):
+            return self.data.nbytes
+        else:
+            return 0
+
+    def __repr__(self) -> str:
+        return f"PoolingSequenceGroupOutput(data={self.data})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PoolingSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.data == other.data
+
+
+class PoolerOutput(msgspec.Struct, omit_defaults=True, array_like=True):
+    """The output from a pooling operation in the pooling model."""
+
+    outputs: list[PoolingSequenceGroupOutput]
+
+    def get_data_nbytes(self) -> int:
+        return sum(o.get_data_nbytes() for o in self.outputs)
+
+    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self.outputs == other.outputs
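A minimal sketch of how the msgspec-based containers above behave (assumes `msgspec`, `numpy`, and `paddle` are installed; the embedding width is arbitrary):

```python
import numpy as np
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput

out = PoolerOutput(outputs=[PoolingSequenceGroupOutput(data=np.zeros(768, dtype=np.float32))])
print(len(out))                # 1
print(out.get_data_nbytes())   # 3072 (768 float32 values via the .nbytes branch)
```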
diff --git a/fastdeploy/transformer_utils/__init__.py b/fastdeploy/transformer_utils/__init__.py
new file mode 100644
index 00000000000..f4ede90624a
--- /dev/null
+++ b/fastdeploy/transformer_utils/__init__.py
@@ -0,0 +1,15 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
diff --git a/fastdeploy/transformer_utils/config.py b/fastdeploy/transformer_utils/config.py
new file mode 100644
index 00000000000..4eed4745849
--- /dev/null
+++ b/fastdeploy/transformer_utils/config.py
@@ -0,0 +1,139 @@
+import json
+from pathlib import Path
+from typing import Optional, Union
+
+import huggingface_hub
+from huggingface_hub import hf_hub_download, try_to_load_from_cache
+from huggingface_hub.utils import (
+    EntryNotFoundError,
+    HfHubHTTPError,
+    LocalEntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+)
+
+from fastdeploy.utils import get_logger
+
+logger = get_logger("transformer_config", "transformer_config.log")
+
+
+def file_or_path_exists(model, config_name):
+    if (local_path := Path(model)).exists():
+        return (local_path / config_name).is_file()
+
+    return False
+
+
+def get_pooling_config_name(pooling_name: str):
+
+    if "pooling_mode_" in pooling_name:
+        pooling_name = pooling_name.replace("pooling_mode_", "")
+
+    if "_" in pooling_name:
+        pooling_name = pooling_name.split("_")[0]
+
+    if "lasttoken" in pooling_name:
+        pooling_name = "last"
+
+    supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
+    pooling_type_name = pooling_name.upper()
+
+    if pooling_type_name in supported_pooling_types:
+        return pooling_type_name
+
+    raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
+
+
+def try_get_local_file(model: Union[str, Path], file_name: str, revision: Optional[str] = "main") -> Optional[Path]:
+    file_path = Path(model) / file_name
+    if file_path.is_file():
+        return file_path
+    else:
+        try:
+            cached_filepath = try_to_load_from_cache(repo_id=model, filename=file_name, revision=revision)
+            if isinstance(cached_filepath, str):
+                return Path(cached_filepath)
+        except ValueError:
+            ...
+    return None
+
+
+def get_hf_file_to_dict(file_name: str, model: Union[str, Path], revision: Optional[str] = "main"):
+    """
+    Downloads a file from the Hugging Face Hub and returns
+    its contents as a dictionary.
+
+    Parameters:
+    - file_name (str): The name of the file to download.
+    - model (str): The name of the model on the Hugging Face Hub.
+    - revision (str): The specific version of the model.
+
+    Returns:
+    - config_dict (dict): A dictionary containing
+      the contents of the downloaded file.
+    """
+    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)
+
+    if file_path is None:
+        try:
+            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
+        except huggingface_hub.errors.OfflineModeIsEnabled:
+            return None
+        except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e:
+            logger.debug("File or repository not found in hf_hub_download: %s", e)
+            return None
+        except HfHubHTTPError as e:
+            logger.warning(
+                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, exc_info=e
+            )
+            return None
+        file_path = Path(hf_hub_file)
+
+    if file_path is not None and file_path.is_file():
+        with open(file_path) as file:
+            return json.load(file)
+
+    return None
+def get_pooling_config(model: str, revision: Optional[str] = "main"):
+    """
+    This function gets the pooling and normalize
+    config from the model - only applies to
+    sentence-transformers models.
+
+    Args:
+        model (str): The name of the Hugging Face model.
+        revision (str, optional): The specific version
+            of the model to use. Defaults to 'main'.
+
+    Returns:
+        dict: A dictionary containing the pooling
+        type and whether normalization is used.
+    """
+
+    modules_file_name = "modules.json"
+
+    modules_dict = None
+    if file_or_path_exists(model, config_name=modules_file_name):
+        modules_dict = get_hf_file_to_dict(modules_file_name, model)
+
+    if modules_dict is None:
+        return None
+
+    pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None)
+
+    normalize = bool(
+        next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Normalize"), False)
+    )
+
+    if pooling:
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model)
+        pooling_type_name = next((item for item, val in pooling_dict.items() if val is True), None)
+
+        if pooling_type_name is not None:
+            pooling_type_name = get_pooling_config_name(pooling_type_name)
+
+        return {"pooling_type": pooling_type_name, "normalize": normalize}
+
+    return None
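For context, this is the sentence-transformers metadata layout the helper reads and what it would return for it. The model directory and file contents below are a hypothetical example, not taken from this diff:

```python
# my-embed-model/modules.json (example):
#   [{"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
#    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
#    {"idx": 2, "name": "2", "path": "2_Normalize", "type": "sentence_transformers.models.Normalize"}]
# my-embed-model/1_Pooling/config.json (example):
#   {"word_embedding_dimension": 1024, "pooling_mode_mean_tokens": true}
from fastdeploy.transformer_utils.config import get_pooling_config

print(get_pooling_config("./my-embed-model"))
# -> {"pooling_type": "MEAN", "normalize": True}
```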
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 1a2dd0c79bc..924d283c3f9 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -51,6 +51,7 @@
 from fastdeploy.logger.logger import FastDeployLogger
 
 T = TypeVar("T")
+from typing import Callable, Optional
 
 # [N,2] -> every line is [config_name, enable_xxx_name]
 # Make sure enable_xxx equal to config.enable_xxx
@@ -852,3 +853,24 @@ def get_logger(name, file_name=None, without_formater=False, print_to_console=Fa
 console_logger = get_logger("console", "console.log", print_to_console=True)
 spec_logger = get_logger("speculate", "speculate.log")
 zmq_client_logger = get_logger("zmq_client", "zmq_client.log")
+
+
+def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
+
+    def _parse_type(val: str) -> T:
+        try:
+            return return_type(val)
+        except ValueError as e:
+            raise argparse.ArgumentTypeError(f"Value {val} cannot be converted to {return_type}.") from e
+
+    return _parse_type
+
+
+def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
+
+    def _optional_type(val: str) -> Optional[T]:
+        if val == "" or val == "None":
+            return None
+        return parse_type(return_type)(val)
+
+    return _optional_type
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 806c8cb75da..7ad7ad50089 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1317,8 +1317,12 @@ def _dummy_run(
                 self.parallel_config.max_model_len,
             )
 
-        # 4. Execute spec decode
-        logits = self.model.compute_logits(hidden_states)
+        logits = None
+        if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            pass
+        else:
+            # 4. Execute spec decode
+            logits = self.model.compute_logits(hidden_states)
 
         if not self.speculative_decoding:
             set_value_by_flags_and_idx(
@@ -1623,8 +1627,13 @@ class at the server level, which is too granular for ModelRunner.
                 self.parallel_config.max_model_len,
             )
 
+        logits = None
         # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hidden_states)
+        if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model:
+            pass
+        else:
+            # 4. Compute logits
+            logits = self.model.compute_logits(hidden_states)
 
         if not self.speculative_decoding:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 4b85a8a5e7f..f4e43faef5f 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -45,7 +45,7 @@
 from fastdeploy.model_executor.layers.quantization import parse_quant_config
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import get_logger, optional_type
 from fastdeploy.worker.worker_base import WorkerBase
 
 logger = get_logger("worker_process", "worker_process.log")
@@ -642,6 +642,27 @@ def parse_args():
         help="Flag to specify dtype of lm_head as FP32",
     )
 
+    parser.add_argument(
+        "--runner",
+        type=str,
+        default="auto",
+        help="The type of model runner to use. Each FD instance only supports one model runner, even if the same model can be used for multiple types.",
+    )
+
+    parser.add_argument(
+        "--convert",
+        type=str,
+        default="auto",
+        help="Convert the model using adapters. The most common use case is to adapt a text generation model to be used for pooling tasks.",
+    )
+
+    parser.add_argument(
+        "--override-pooler-config",
+        type=optional_type(json.loads),
+        default=None,
+        help="Override configuration for the pooler.",
+    )
+
     args = parser.parse_args()
     return args
diff --git a/requirements.txt b/requirements.txt
index ddad9d9b3bb..8eb02b628a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,3 +39,4 @@ opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
 partial_json_parser
+msgspec
diff --git a/tests/plugins/fd_add_dummy_model/__init__.py b/tests/plugins/fd_add_dummy_model/__init__.py
index 1c7dba0cc63..0a4923c5983 100644
--- a/tests/plugins/fd_add_dummy_model/__init__.py
+++ b/tests/plugins/fd_add_dummy_model/__init__.py
@@ -14,9 +14,8 @@
 
 from paddleformers.transformers import PretrainedModel
 
-from fastdeploy import ModelRegistry
 from fastdeploy.config import ErnieArchitectures
-from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.model_base import ModelForCasualLM, ModelRegistry
 
 
 class MyPretrainedModel(PretrainedModel):
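The `--override-pooler-config` flag relies on `optional_type(json.loads)`, which maps empty or `"None"` strings to `None` and otherwise JSON-decodes the value. A small behavior sketch:

```python
import argparse
import json

from fastdeploy.utils import optional_type

parser = argparse.ArgumentParser()
parser.add_argument("--override-pooler-config", type=optional_type(json.loads), default=None)

print(parser.parse_args(["--override-pooler-config", "None"]).override_pooler_config)
# None
print(parser.parse_args(["--override-pooler-config", '{"pooling_type": "MEAN"}']).override_pooler_config)
# {'pooling_type': 'MEAN'}
```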
diff --git a/tests/pooling/test_embedding.py b/tests/pooling/test_embedding.py
new file mode 100644
index 00000000000..d609726e235
--- /dev/null
+++ b/tests/pooling/test_embedding.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+
+import paddle
+import pytest
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    LoadConfig,
+    ModelConfig,
+    ParallelConfig,
+)
+from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from tests.model_loader.utils import get_torch_model_path
+
+
+class TestModelLoader:
+
+    @pytest.fixture(scope="session", autouse=True)
+    def setup_paddle(self):
+        if not paddle.is_compiled_with_cuda():
+            print("CUDA not available, using CPU")
+            paddle.set_device("cpu")
+        else:
+            print("Using CUDA device")
+            paddle.set_device("gpu")
+        yield
+
+    @pytest.fixture(scope="session")
+    def model_path(self):
+        try:
+            torch_model_path = get_torch_model_path("Qwen3-0.6B")
+            if os.path.exists(torch_model_path):
+                return torch_model_path
+        except Exception as e:
+            print(f"Could not get torch model path: {e}")
+
+    @pytest.fixture
+    def model_config(self, model_path):
+        model_args = {
+            "model": model_path,
+            "dtype": "bfloat16",
+            "max_model_len": 8192,
+            "tensor_parallel_size": 1,
+            "runner": "auto",
+            "convert": "auto",
+        }
+
+        try:
+            return ModelConfig(model_args)
+        except Exception as e:
+            print(f"Could not create ModelConfig: {e}")
+
+    @pytest.fixture
+    def fd_config(self, model_config):
+        try:
+            cache_args = {
+                "block_size": 64,
+                "gpu_memory_utilization": 0.9,
+                "cache_dtype": "bfloat16",
+                "model_cfg": model_config,
+                "tensor_parallel_size": 1,
+            }
+            cache_config = CacheConfig(cache_args)
+
+            parallel_args = {
+                "tensor_parallel_size": 1,
+                "data_parallel_size": 1,
+            }
+            parallel_config = ParallelConfig(parallel_args)
+
+            load_args = {}
+            load_config = LoadConfig(load_args)
+
+            graph_opt_args = {
+                "enable_cudagraph": False,
+                "cudagraph_capture_sizes": None,
+            }
+            graph_opt_config = GraphOptimizationConfig(graph_opt_args)
+
+            return FDConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                parallel_config=parallel_config,
+                load_config=load_config,
+                graph_opt_config=graph_opt_config,
+                test_mode=True,
+            )
+        except Exception as e:
+            print(f"Could not create FDConfig: {e}")
+
+    @pytest.fixture
+    def model_json_config(self, model_path):
+        config_path = os.path.join(model_path, "config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        return None
+
+    def test_embedding_with_none_convert_type(self, fd_config, model_json_config):
+        if model_json_config is None:
+            pytest.skip("Model config not available")
+
+        if fd_config is None:
+            pytest.skip("FDConfig not available")
+
+        print("=" * 60)
+        print("Testing initialize_model with convert_type='none'")
+        print("=" * 60)
+
+        architectures = model_json_config.get("architectures", [])
+        if not architectures:
+            pytest.skip("No architectures found in model config")
+
+        fd_config.model_config.convert_type = "none"
+
+        try:
+            model_cls = ModelRegistry.get_class(architectures)
+
+            if hasattr(model_cls, "__name__"):
+                assert (
+                    "ForEmbedding" not in model_cls.__name__
+                ), f"Standard model should not have 'ForEmbedding' in name, but got: {model_cls.__name__}"
+                print(f"Confirmed standard model type (no ForEmbedding): {model_cls.__name__}")
+
+            standard_methods = set(dir(model_cls))
+            assert "_init_pooler" not in standard_methods, "Standard model should not have _init_pooler method"
+
+        except Exception as e:
+            print(f"Error in none: {e}")
+
+    def test_embedding_with_embed_convert_type(self, fd_config, model_json_config):
+        if model_json_config is None:
+            pytest.skip("Model config not available")
+
+        if fd_config is None:
+            pytest.skip("FDConfig not available")
+
+        print("=" * 60)
+        print("Testing embedding with convert_type='embed'")
+        print("=" * 60)
+
+        architectures = model_json_config.get("architectures", [])
+        if not architectures:
+            pytest.skip("No architectures found in model config")
+
+        fd_config.model_config.convert_type = "embed"
+
+        try:
+            model_cls = ModelRegistry.get_class(architectures)
+            if hasattr(model_cls, "__name__"):
+                assert "ForEmbedding" in model_cls.__name__, "Embedding model should have 'ForEmbedding' in name"
+                print(f"Confirmed embedding model type: {model_cls.__name__}")
+
+            embedding_methods = set(dir(model_cls))
+            assert "_init_pooler" in embedding_methods, "Embedding model should have _init_pooler method"
+
+        except Exception as e:
+            print(f"Error in convert embed: {e}")
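The new suite can be run directly with pytest; note the `model_path` fixture expects a local Qwen3-0.6B checkpoint to be resolvable through `get_torch_model_path`, otherwise the tests skip.

```python
# Convenience runner for the new pooling tests (path taken from the diff above).
import pytest

raise SystemExit(pytest.main(["-v", "tests/pooling/test_embedding.py"]))
```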