-
Notifications
You must be signed in to change notification settings - Fork 737
[Feature] support pool #3827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] support pool #3827
Changes from 35 commits
8e0c8d4
57795ea
302adb0
87237b0
cef99ec
926c796
a76a43e
344a8df
5ec6a93
2235785
98b32fc
4f90dfc
0ef4c9a
1daacb7
9d1a011
2eff16e
9fcd05e
945cba5
6f545aa
222d1b2
97b4649
4cf6164
db0a4bf
2681765
10969cd
669d712
dd45025
31b7311
e023c6b
cb80ce8
893cdbb
d4dcc3c
41aa2c5
8e92eb4
91f777e
a90a091
6f6c549
1fdf477
9e4d1fa
ebf4e0c
27ec018
adc5b8f
5e264f3
798f788
b72ac60
a1de646
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,24 +18,83 @@ | |
|
|
||
| import json | ||
| import os | ||
| from dataclasses import field | ||
| from enum import Enum | ||
| from typing import Any, Dict, List, Literal, Optional, Union | ||
|
|
||
| import paddle | ||
| import paddle.distributed as dist | ||
| from paddleformers.transformers.configuration_utils import PretrainedConfig | ||
| from typing_extensions import assert_never | ||
|
|
||
| import fastdeploy | ||
| from fastdeploy import envs | ||
| from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase | ||
| from fastdeploy.multimodal.registry import MultimodalRegistry | ||
| from fastdeploy.platforms import current_platform | ||
| from fastdeploy.scheduler import SchedulerConfig | ||
| from fastdeploy.transformer_utils.config import get_pooling_config | ||
| from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger | ||
|
|
||
| logger = get_logger("config", "config.log") | ||
|
|
||
| TaskOption = Literal["generate"] | ||
| TaskOption = Literal["auto", "generate", "embedding", "embed"] | ||
|
|
||
| RunnerType = Literal["generate", "pooling"] | ||
|
|
||
| RunnerOption = Literal["auto", "generate", "pooling"] | ||
|
|
||
| ConvertOption = Literal["auto", "none", "embed"] | ||
|
|
||
| ConvertType = Literal["none", "embed"] | ||
|
|
||
| _ResolvedTask = Literal["generate", "encode", "embed"] | ||
|
|
||
| _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { | ||
| "generate": [], | ||
| "pooling": ["embed"], | ||
| } | ||
|
|
||
| # Some model suffixes are based on auto classes from Transformers: | ||
| # https://huggingface.co/docs/transformers/en/model_doc/auto | ||
| # NOTE: Items higher on this list take priority over lower ones | ||
| _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ | ||
| ("ForCausalLM", ("generate", "none")), | ||
| ("ForConditionalGeneration", ("generate", "none")), | ||
| ("ChatModel", ("generate", "none")), | ||
| ("LMHeadModel", ("generate", "none")), | ||
| ("ForTextEncoding", ("pooling", "embed")), | ||
| ("EmbeddingModel", ("pooling", "embed")), | ||
| ("ForSequenceClassification", ("pooling", "classify")), | ||
| ("ForAudioClassification", ("pooling", "classify")), | ||
| ("ForImageClassification", ("pooling", "classify")), | ||
| ("ForVideoClassification", ("pooling", "classify")), | ||
| ("ClassificationModel", ("pooling", "classify")), | ||
| ("ForRewardModeling", ("pooling", "reward")), | ||
| ("RewardModel", ("pooling", "reward")), | ||
| # Let other `*Model`s take priority | ||
| ("Model", ("pooling", "embed")), | ||
| ] | ||
|
|
||
|
|
||
| def iter_architecture_defaults(): | ||
| yield from _SUFFIX_TO_DEFAULTS | ||
|
|
||
|
|
||
| def try_match_architecture_defaults( | ||
| architecture: str, | ||
| *, | ||
| runner_type: Optional[RunnerType] = None, | ||
| convert_type: Optional[ConvertType] = None, | ||
| ): | ||
| for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults(): | ||
| if ( | ||
| (runner_type is None or runner_type == default_runner_type) | ||
| and (convert_type is None or convert_type == default_convert_type) | ||
| and architecture.endswith(suffix) | ||
| ): | ||
| return suffix, (default_runner_type, default_convert_type) | ||
| return None | ||
|
|
||
|
|
||
| class MoEPhase: | ||
|
|
@@ -133,6 +192,12 @@ def __init__( | |
| self.eos_tokens_lens: int = 2 | ||
| self.lm_head_fp32: bool = False | ||
| self.model_format = "auto" | ||
| self.runner = "auto" | ||
| self.convert = "auto" | ||
| self.pooler_config: Optional["PoolerConfig"] = field(init=False) | ||
| self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None | ||
| self.revision = None | ||
|
Comment on lines
+195
to
+199
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么不单独加一个PoolerConfig?把runner/convert等都加进去
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pooler.py中创建DisPatchPooler,里面的ResolvedPoolingConfig它的from_config中需要pooler_config |
||
|
|
||
| self.partial_rotary_factor: float = 1.0 | ||
| for key, value in args.items(): | ||
| if hasattr(self, key) and value != "None": | ||
|
|
@@ -160,6 +225,7 @@ def __init__( | |
| self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size) | ||
|
|
||
| architectures = self.architectures[0] | ||
|
|
||
| if MultimodalRegistry.contains_model(architectures): | ||
| self.enable_mm = True | ||
| else: | ||
|
|
@@ -170,6 +236,39 @@ def __init__( | |
| self.override_name_from_config() | ||
| self.read_from_env() | ||
| self.read_model_config() | ||
| self.runner_type = self._get_runner_type(self.architectures, self.runner) | ||
| self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert) | ||
|
|
||
| registry = self.registry | ||
| is_generative_model = registry.is_text_generation_model(self.architectures, self) | ||
| is_pooling_model = registry.is_pooling_model(self.architectures, self) | ||
|
|
||
| if self.runner_type == "generate" and not is_generative_model: | ||
| generate_converts = _RUNNER_CONVERTS["generate"] | ||
| if self.convert_type not in generate_converts: | ||
| raise ValueError("This model does not support `--runner generate`.") | ||
| if self.runner_type == "pooling" and not is_pooling_model: | ||
| pooling_converts = _RUNNER_CONVERTS["pooling"] | ||
| if self.convert_type not in pooling_converts: | ||
| convert_option = "<" + "|".join(pooling_converts) + ">" | ||
| raise ValueError( | ||
| "This model does not support `--runner pooling`. " | ||
| f"You can pass `--convert {convert_option}` to adapt " | ||
| "it into a pooling model." | ||
| ) | ||
|
|
||
| self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type) | ||
| model_info, arch = registry.inspect_model_cls(self.architectures, self) | ||
| self._model_info = model_info | ||
| self._architecture = arch | ||
|
|
||
| self.pooler_config = self._init_pooler_config() | ||
|
|
||
| @property | ||
| def registry(self): | ||
| from fastdeploy.model_executor.models.model_base import ModelRegistry | ||
|
|
||
| return ModelRegistry() | ||
|
Comment on lines
+273
to
+276
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. model_config怎么还返回了一个ModelRegistry呢
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我写的方法都不是类方法,旧的还是类方法 |
||
|
|
||
| def override_name_from_config(self): | ||
| """ | ||
|
|
@@ -193,7 +292,6 @@ def override_name_from_config(self): | |
| def read_from_env(self): | ||
| """ | ||
| Read configuration information from environment variables and update the object's attributes. | ||
|
|
||
| If an attribute is not present or is an empty string in the environment variables, use the default value. | ||
| """ | ||
| self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM) | ||
|
|
@@ -234,6 +332,165 @@ def read_model_config(self): | |
| f"Config file path: {config_path}" | ||
| ) | ||
|
|
||
| def _get_default_runner_type( | ||
| self, | ||
| architectures: list[str], | ||
| ) -> RunnerType: | ||
| registry = self.registry | ||
| if get_pooling_config(self.model, self.revision): | ||
| return "pooling" | ||
| for arch in architectures: | ||
| if arch in registry.get_supported_archs(): | ||
| if registry.is_pooling_model(architectures, self): | ||
| return "pooling" | ||
| if registry.is_text_generation_model(architectures, self): | ||
| return "generate" | ||
| match = try_match_architecture_defaults(arch) | ||
| if match: | ||
| _, (runner_type, _) = match | ||
| return runner_type | ||
| return "generate" | ||
|
|
||
| def _get_default_convert_type( | ||
| self, | ||
| architectures: list[str], | ||
| runner_type: RunnerType, | ||
| ) -> ConvertType: | ||
| registry = self.registry | ||
|
|
||
| for arch in architectures: | ||
| if arch in registry.get_supported_archs(): | ||
| if runner_type == "generate" and registry.is_text_generation_model(architectures, self): | ||
| return "none" | ||
| if runner_type == "pooling" and registry.is_pooling_model(architectures, self): | ||
| return "none" | ||
| match = try_match_architecture_defaults(arch, runner_type=runner_type) | ||
| if match: | ||
| _, (_, convert_type) = match | ||
| return convert_type | ||
|
|
||
| # This is to handle Sentence Transformers models that use *ForCausalLM | ||
| # and also multi-modal pooling models which are not defined as | ||
| # Sentence Transformers models | ||
| if runner_type == "pooling": | ||
| return "embed" | ||
|
|
||
| return "none" | ||
|
|
||
| def _get_runner_type( | ||
| self, | ||
| architectures: list[str], | ||
| runner: RunnerOption, | ||
| ) -> RunnerType: | ||
| if runner != "auto": | ||
| return runner | ||
|
|
||
| runner_type = self._get_default_runner_type(architectures) | ||
| if runner_type != "generate": | ||
| logger.info( | ||
| "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message.", | ||
| runner_type, | ||
| ) | ||
|
|
||
| return runner_type | ||
|
|
||
| def _get_convert_type( | ||
| self, | ||
| architectures: list[str], | ||
| runner_type: RunnerType, | ||
| convert: ConvertOption, | ||
| ) -> ConvertType: | ||
| if convert != "auto": | ||
| return convert | ||
|
|
||
| convert_type = self._get_default_convert_type(architectures, runner_type) | ||
|
|
||
| if convert_type != "none": | ||
| logger.info( | ||
| "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message.", | ||
| convert_type, | ||
| ) | ||
|
|
||
| return convert_type | ||
|
|
||
| def _get_supported_generation_tasks( | ||
| self, | ||
| architectures: list[str], | ||
| convert_type: ConvertType, | ||
| ) -> list[_ResolvedTask]: | ||
| registry = self.registry | ||
|
|
||
| supported_tasks = list[_ResolvedTask]() | ||
| if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]: | ||
| supported_tasks.append("generate") | ||
|
|
||
| # TODO: Temporarily does not support transcription. | ||
| return supported_tasks | ||
|
|
||
| def _get_default_pooling_task( | ||
| self, | ||
| architectures: list[str], | ||
| ) -> Literal["embed"]: | ||
| # Temporarily does not support classification and reward. | ||
| for arch in architectures: | ||
| match = try_match_architecture_defaults(arch, runner_type="pooling") | ||
| if match: | ||
| _, (_, convert_type) = match | ||
| assert convert_type != "none" | ||
| return convert_type | ||
|
|
||
| return "embed" | ||
|
|
||
| def _get_supported_pooling_tasks( | ||
| self, | ||
| architectures: list[str], | ||
| convert_type: ConvertType, | ||
| ) -> list[_ResolvedTask]: | ||
| registry = self.registry | ||
|
|
||
| supported_tasks = list[_ResolvedTask]() | ||
| if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]: | ||
| supported_tasks.append("encode") | ||
|
|
||
| extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type | ||
| supported_tasks.append(extra_task) | ||
|
|
||
| return supported_tasks | ||
|
|
||
| def _get_supported_tasks( | ||
| self, | ||
| architectures: list[str], | ||
| runner_type: RunnerType, | ||
| convert_type: ConvertType, | ||
| ) -> list[_ResolvedTask]: | ||
| if runner_type == "generate": | ||
| return self._get_supported_generation_tasks(architectures, convert_type) | ||
| if runner_type == "pooling": | ||
| return self._get_supported_pooling_tasks(architectures, convert_type) | ||
|
|
||
| assert_never(runner_type) | ||
|
|
||
| def _init_pooler_config(self) -> Optional["PoolerConfig"]: | ||
| if self.runner_type == "pooling": | ||
| if isinstance(self.override_pooler_config, dict): | ||
| self.override_pooler_config = PoolerConfig(**self.override_pooler_config) | ||
|
|
||
| pooler_config = self.override_pooler_config or PoolerConfig() | ||
|
|
||
| base_config = get_pooling_config(self.model, self.revision) | ||
| if base_config is not None: | ||
| for k, v in base_config.items(): | ||
| if getattr(pooler_config, k) is None: | ||
| setattr(pooler_config, k, v) | ||
|
|
||
| default_pooling_type = self._model_info.default_pooling_type | ||
| if pooler_config.pooling_type is None: | ||
| pooler_config.pooling_type = default_pooling_type | ||
|
|
||
| return pooler_config | ||
|
|
||
| return None | ||
|
|
||
| def _get_download_model(self, model_name, model_type="default"): | ||
| # TODO: Provide dynamic graph for self-downloading and save to the specified download directory. | ||
| pass | ||
|
|
@@ -850,6 +1107,41 @@ def __init__( | |
| setattr(self, key, value) | ||
|
|
||
|
|
||
| class PoolerConfig: | ||
| """Controls the behavior of output pooling in pooling models.""" | ||
|
|
||
| pooling_type: Optional[str] = None | ||
| """ | ||
| The pooling method of the pooling model. | ||
| """ | ||
| # for embeddings models | ||
| normalize: Optional[bool] = None | ||
| """ | ||
| Whether to normalize the embeddings outputs. Defaults to True. | ||
| """ | ||
| dimensions: Optional[int] = None | ||
| """ | ||
| Reduce the dimensions of embeddings if model | ||
| support matryoshka representation. Defaults to None. | ||
| """ | ||
| enable_chunked_processing: Optional[bool] = None | ||
| """ | ||
| Whether to enable chunked processing for long inputs that exceed the model's | ||
| maximum position embeddings. When enabled, long inputs will be split into | ||
| chunks, processed separately, and then aggregated using weighted averaging. | ||
| This allows embedding models to handle arbitrarily long text without CUDA | ||
| errors. Defaults to False. | ||
| """ | ||
| max_embed_len: Optional[int] = None | ||
| """ | ||
| Maximum input length allowed for embedding generation. When set, allows | ||
| inputs longer than max_embed_len to be accepted for embedding models. | ||
| When an input exceeds max_embed_len, it will be handled according to | ||
| the original max_model_len validation logic. | ||
| Defaults to None (i.e. set to max_model_len). | ||
| """ | ||
|
|
||
|
|
||
| class LoRAConfig: | ||
| """LoRA Config""" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里embedding 和embed的区别是什么