Merged
8 changes: 0 additions & 8 deletions vllm/config.py
@@ -1379,14 +1379,6 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens = 16384

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
29 changes: 24 additions & 5 deletions vllm/multimodal/registry.py
@@ -252,11 +252,8 @@ def get_max_tokens_per_item_by_modality(
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.

Note:
This is currently directly used only in V1.
Get the maximum number of tokens per data item from each modality based
on the underlying model configuration.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
@@ -269,6 +266,28 @@ def get_max_tokens_per_item_by_modality(
for key, plugin in self._plugins.items()
}

def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
on the underlying model configuration, excluding modalities that the user
has explicitly disabled via `limit_mm_per_prompt`.

Note:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
limits_per_plugin = self._limits_by_model[model_config]

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
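For illustration, here is a minimal, self-contained sketch of the filtering pattern the new get_max_tokens_per_item_by_nonzero_modality helper applies. The modality names and token counts are made-up placeholder values, not taken from any real model config:

# Illustrative sketch of the nonzero-modality filtering (placeholder values).
max_tokens_per_item = {"image": 2448, "video": 8192, "audio": 1024}
limit_mm_per_prompt = {"image": 4, "video": 0, "audio": 1}  # user disabled video

nonzero_modalities = {
    modality: max_tokens
    for modality, max_tokens in max_tokens_per_item.items()
    if limit_mm_per_prompt[modality] > 0
}
# nonzero_modalities == {"image": 2448, "audio": 1024}; "video" is excluded,
# so it is ignored when sizing the encoder cache budget.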
112 changes: 111 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -1,7 +1,15 @@
from typing import Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.request import Request

if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig

logger = init_logger(__name__)


class EncoderCacheManager:

@@ -46,3 +54,105 @@ def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed


def compute_encoder_cache_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> int:
"""Compute the encoder cache budget based on the model and scheduler
configurations.

Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.

Returns:
The encoder cache budget, measured in number of tokens in the
input sequence.
"""

encoder_cache_budget = 0

if not model_config.is_multimodal_model:
return encoder_cache_budget

# TODO: handle encoder-decoder models once we support them.
encoder_cache_budget, _, _ = compute_encoder_cache_budget_multimodal(
model_config, scheduler_config)

return encoder_cache_budget


def compute_encoder_cache_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
) -> tuple[int, Optional[str], int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.

Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.

Returns:
- The encoder cache budget, measured in number of tokens in the
input sequence.
- The modality of the multimodal item that requires the most tokens.
- The number of multimodal items used to compute the encoder cache
budget.
"""

encoder_cache_budget = 0
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)

if not max_tokens_by_modality_dict:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return encoder_cache_budget, None, 0

modality, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])

max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_reqs = scheduler_config.max_num_seqs

# The biggest possible multimodal item cannot be fully prefilled in a
# batch, so every batch can partially prefill at most one such item.
if max_tokens_per_mm_item > max_num_batched_tokens:
num_items = 1

# A batch can fully cover multiple of the biggest possible multimodal
# items, plus one that will be partially prefilled.
else:
num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item)

# NOTE: We need the encoder cache to be able to compute & hold ONE
# ADDITIONAL multimodal item, which is required only when:
# - Two requests in the current batch share the same prefix with such an
# item as part of the prefix.
# - AND the prefix length is divisible by the block size, triggering the
# recomputation of the last block.
# - AND part of the item's embeddings falls in this last block.

# This issue can be fundamentally resolved by supporting num_new_tokens=0
# on the model runner.
num_items += 1

# The number of items needed cannot exceed the max number of running
# requests * the max number of multimodal items per request.
max_mm_items_per_req = max(
MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())

num_items = min(num_items, max_num_reqs * max_mm_items_per_req)
encoder_cache_budget = num_items * max_tokens_per_mm_item

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_cache_budget, num_items, modality)

return encoder_cache_budget, modality, num_items
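To make the budget arithmetic concrete, a small worked example follows. All numbers are illustrative placeholders, and the sketch walks only the branch where the biggest item fits within a single batch:

# Worked example of the budget computation (illustrative numbers only).
def cdiv(a: int, b: int) -> int:  # same ceil-division helper as vllm.utils.cdiv
    return -(-a // b)

max_tokens_per_mm_item = 2448  # biggest item across enabled modalities
max_num_batched_tokens = 8192  # scheduler token budget per step
max_num_reqs = 256             # max number of running requests
max_mm_items_per_req = 1       # limit_mm_per_prompt for that modality

# cdiv rounds up: 3 full items plus 1 partially prefilled item per batch,
# plus 1 extra item for the shared-prefix recomputation case in the NOTE above.
num_items = cdiv(max_num_batched_tokens, max_tokens_per_mm_item) + 1  # -> 5

# Cap by what the running requests could actually contain.
num_items = min(num_items, max_num_reqs * max_mm_items_per_req)  # -> 5

encoder_cache_budget = num_items * max_tokens_per_mm_item  # -> 12240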
24 changes: 16 additions & 8 deletions vllm/v1/core/scheduler.py
@@ -3,10 +3,11 @@
from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_cache_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -24,6 +25,7 @@ class Scheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None:
@@ -68,16 +70,22 @@ def __init__(
self.running_reqs_data: Dict[str, RunningRequestData] = {}

# Encoder-related.
# Calculate the encoder cache size if applicable.
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when the encoder cache supports embedding
# caching across requests.
encoder_cache_budget = compute_encoder_cache_budget(
model_config, scheduler_config)

# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.max_num_encoder_input_tokens = encoder_cache_budget
# NOTE: For models without an encoder (e.g., text-only models),
# the encoder cache will not be initialized because the cache size
# is 0 for these models.
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)
cache_size=encoder_cache_budget)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
9 changes: 6 additions & 3 deletions vllm/v1/engine/core.py
@@ -58,9 +58,12 @@ def __init__(
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
vllm_config.lora_config)
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
)

self._last_logging_time = time.time()

55 changes: 14 additions & 41 deletions vllm/v1/worker/gpu_model_runner.py
@@ -20,6 +20,8 @@
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import (
compute_encoder_cache_budget, compute_encoder_cache_budget_multimodal)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -88,8 +90,8 @@ def __init__(
self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
self.mm_input_mapper_profiling.use_cache = False

self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size
self.encoder_cache_budget = compute_encoder_cache_budget(
self.model_config, self.scheduler_config)

# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -721,7 +723,16 @@ def profile_run(self) -> None:
]

# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:
# TODO: handle encoder-decoder models once we support them.
if self.is_multimodal_model and self.encoder_cache_budget > 0:
Review comment (Collaborator):
A dumb question: Is it required to check self.encoder_cache_budget > 0 here?

Reply (@ywang96, Member, Author — Jan 14, 2025):
It's possible that users will explicitly disable all modalities (e.g., --limit-mm-per-prompt image=0) for whatever reason they have; in that case the model is just a text model with a multimodal encoder loaded but not used (this can also be optimized in the future), so we should skip profiling with multimodal data entirely.

Not a required check, but running all the following code would mean building all the dummy data and then not using any of it.

# NOTE: Currently the model is profiled with a single non-text
# modality with the max possible input tokens, even when
# it supports multiple modalities.
_, dummy_data_modality, max_num_mm_items = compute_encoder_cache_budget_multimodal( # noqa: E501
self.model_config,
self.scheduler_config,
)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
Expand All @@ -731,44 +742,6 @@ def profile_run(self) -> None:
)
dummy_mm_data = dummy_request_data.multi_modal_data

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality( # noqa: E501
self.model_config)

dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])

# Check how many items of this modality can be supported by
# the encoder cache budget.
encoder_cache_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)
max_num_mm_items_encoder_budget = encoder_cache_budget // \
max_tokens_per_mm_item

# TODO: Allow users to set encoder_cache_budget in case this
# happens.
assert max_num_mm_items_encoder_budget > 0, (
f"Encoder cache budget={encoder_cache_budget} is too small to "
f"support the maximum possible size of multimodal embeddings"
f"={max_tokens_per_mm_item}.")

# Check how many items of this modality can be supported by
# the decoder budget.
max_mm_items_per_req = max(
self.mm_registry.get_mm_limits_per_prompt(
self.model_config).values())

# NOTE: We do not consider max_num_batched_tokens on purpose
# because the multimodal embeddings can be generated in advance
# and chunked prefilled.
max_num_mm_items_decoder_budget = self.max_num_reqs * \
max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
max_num_mm_items_decoder_budget)

# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
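As a follow-up to the review thread in profile_run above, here is a hedged sketch of how a zero budget short-circuits multimodal profiling when every non-text modality is disabled (e.g., --limit-mm-per-prompt image=0). budget_for and all values are illustrative stand-ins, not vLLM APIs:

# Illustrative only: a zero encoder cache budget skips multimodal profiling.
def budget_for(limit_mm_per_prompt: dict, max_tokens_per_item: dict) -> int:
    enabled = {m: t for m, t in max_tokens_per_item.items()
               if limit_mm_per_prompt.get(m, 1) > 0}
    if not enabled:
        return 0  # all modalities disabled -> encoder cache not initialized
    # The real computation also multiplies by an item count; simplified here.
    return max(enabled.values())

is_multimodal_model = True
encoder_cache_budget = budget_for({"image": 0}, {"image": 2448})  # image=0 set by user

if is_multimodal_model and encoder_cache_budget > 0:
    print("profile with dummy multimodal data")
else:
    # Skip building dummy multimodal inputs entirely, as discussed in the review.
    print("profile as a text-only model")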