[V1][Core] Autotune encoder cache budget #11895
Changes from 10 commits: 495f669, 8c67ecd, 5938a1f, 0e4ab3c, bd1ccf1, 2a4b1d5, 9ee3f3d, 7614888, aaf3cef, e8f50f4, 767b0d6, eb125b5, f539470, 29ad359, 3103622
```diff
@@ -20,6 +20,8 @@
                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
+from vllm.v1.core.encoder_cache_manager import (
+    compute_encoder_cache_budget, compute_encoder_cache_budget_multimodal)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
```
```diff
@@ -88,8 +90,8 @@ def __init__(
         self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
-        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
-        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
+        self.encoder_cache_budget = compute_encoder_cache_budget(
+            self.model_config, self.scheduler_config)
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
```
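For context, here is a minimal sketch of what the new `compute_encoder_cache_budget` helper presumably does, inferred only from this diff: the helper lives in `vllm.v1.core.encoder_cache_manager` per the import above, and the `> 0` guard added in `profile_run` below suggests that text-only models get a zero budget. The actual implementation in the PR may differ.

```python
# Hypothetical sketch -- not the actual vLLM implementation.
from vllm.config import ModelConfig, SchedulerConfig
from vllm.v1.core.encoder_cache_manager import (
    compute_encoder_cache_budget_multimodal)


def compute_encoder_cache_budget(model_config: ModelConfig,
                                 scheduler_config: SchedulerConfig) -> int:
    """Return the number of encoder tokens the encoder cache should hold."""
    # Text-only models have no encoder outputs to cache; a zero budget is
    # what makes the `encoder_cache_budget > 0` guard in profile_run skip
    # the multimodal profiling branch.
    if not model_config.is_multimodal_model:
        return 0
    # For multimodal models, derive the budget from the largest possible
    # multimodal item and the scheduler limits (see the sketch of the
    # multimodal variant after the last hunk).
    encoder_cache_budget, _, _ = compute_encoder_cache_budget_multimodal(
        model_config, scheduler_config)
    return encoder_cache_budget
```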
```diff
@@ -721,7 +723,16 @@ def profile_run(self) -> None:
         ]
 
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
+        # TODO: handle encoder-decoder models once we support them.
+        if self.is_multimodal_model and self.encoder_cache_budget > 0:
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality with the max possible input tokens even when
+            # it supports multiple.
+            _, dummy_data_modality, max_num_mm_items = compute_encoder_cache_budget_multimodal(  # noqa: E501
+                self.model_config,
+                self.scheduler_config,
+            )
 
             # Create dummy batch of multimodal inputs.
             dummy_request_data = self.input_registry.dummy_data_for_profiling(
```
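For a rough sense of scale, here is an illustrative calculation of the values the multimodal helper might return. Every number below is made up for the example, and the budget formula itself is an assumption that mirrors the inline logic removed in the next hunk rather than anything stated in this diff.

```python
# Illustration only -- assumed numbers, not values taken from this PR.
max_tokens_per_mm_item = 576   # assumed: largest single image -> 576 encoder tokens
max_mm_items_per_req = 1       # assumed: per-prompt multimodal item limit
max_num_seqs = 256             # assumed: scheduler_config.max_num_seqs

# Decoder-side bound on how many multimodal items can be in flight at once
# (max_num_batched_tokens is ignored, per the NOTE in the removed code).
max_num_mm_items = max_num_seqs * max_mm_items_per_req            # 256

# Assumed autotuned budget: large enough to cache every in-flight item.
encoder_cache_budget = max_tokens_per_mm_item * max_num_mm_items  # 147456
```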
```diff
@@ -731,44 +742,6 @@ def profile_run(self) -> None:
             )
             dummy_mm_data = dummy_request_data.multi_modal_data
 
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
-                self.model_config)
-
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder cache budget.
-            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
-                                       self.encoder_cache_size)
-            max_num_mm_items_encoder_budget = encoder_cache_budget // \
-                max_tokens_per_mm_item
-
-            # TODO: Allow users to set encoder_cache_budget in case this
-            # happens.
-            assert max_num_mm_items_encoder_budget > 0, (
-                f"Encoder cache budget={encoder_cache_budget} is too small to "
-                f"support the maximum possible size of multimodal embeddings"
-                f"={max_tokens_per_mm_item}.")
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = max(
-                self.mm_registry.get_mm_limits_per_prompt(
-                    self.model_config).values())
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                                   max_num_mm_items_decoder_budget)
-
             # Dummy data definition in V0 may contain multiple multimodal items
             # (e.g, multiple images) for a single request, therefore here we
             # always replicate first item by max_num_mm_items times since in V1
```
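The inline computation deleted above is presumably what moved into `compute_encoder_cache_budget_multimodal`. Below is a hypothetical reconstruction pieced together from those removed lines; the registry methods are the same ones the removed code called on `self.mm_registry` (assumed here to be the global `MULTIMODAL_REGISTRY`), but the exact autotuning formula for the returned budget is an assumption and the merged helper may differ.

```python
# Hypothetical reconstruction -- assembled from the removed inline logic,
# not copied from the PR's encoder_cache_manager.py.
from typing import Tuple

from vllm.config import ModelConfig, SchedulerConfig
from vllm.multimodal import MULTIMODAL_REGISTRY


def compute_encoder_cache_budget_multimodal(
    model_config: ModelConfig,
    scheduler_config: SchedulerConfig,
) -> Tuple[int, str, int]:
    # Find the modality whose single largest item produces the most encoder
    # tokens (same registry call as the removed code).
    max_tokens_by_modality = (
        MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality(model_config))
    modality, max_tokens_per_mm_item = max(
        max_tokens_by_modality.items(), key=lambda item: item[1])

    # Decoder-side bound: max_num_seqs concurrent requests, each holding at
    # most the per-prompt multimodal limit. max_num_batched_tokens is
    # deliberately ignored because encoder outputs can be computed ahead of
    # time and consumed via chunked prefill.
    max_mm_items_per_req = max(
        MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config).values())
    max_num_mm_items = scheduler_config.max_num_seqs * max_mm_items_per_req

    # Assumed autotuning rule: size the cache so that every item the decoder
    # can have in flight fits, instead of relying on the user-configured
    # max_num_encoder_input_tokens / encoder_cache_size pair.
    encoder_cache_budget = max_tokens_per_mm_item * max_num_mm_items
    return encoder_cache_budget, modality, max_num_mm_items
```

Compared with the removed code, the notable shift is that the budget is derived from the model's own limits rather than from the user-configured `max_num_encoder_input_tokens` / `encoder_cache_size`, which also eliminates the assert that used to fire when that configured budget was too small to hold even one multimodal item.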