2 changes: 1 addition & 1 deletion tests/compile/piecewise/test_simple.py
@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
         use_cudagraph=True,
         splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
-        cudagraph_capture_sizes=[1, 2],
+        capture_sizes=[1, 2],
     ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
6 changes: 3 additions & 3 deletions tests/compile/piecewise/test_toy_llama.py
@@ -265,7 +265,7 @@ def run_model(llama_config,
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            cudagraph_capture_sizes=[1, 2],
+            capture_sizes=[1, 2],
         )
         if split_attn:
             compilation_config.splitting_ops = ["silly.attention"]
@@ -389,12 +389,12 @@ def benchmark():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
             splitting_ops=["silly.attention"],
-            cudagraph_capture_sizes=cudagraph_sizes,
+            capture_sizes=cudagraph_sizes,
         )
     else:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
-            cudagraph_capture_sizes=cudagraph_sizes,
+            capture_sizes=cudagraph_sizes,
         )
 
     vllm_config = VllmConfig(compilation_config=compilation_config)
54 changes: 26 additions & 28 deletions vllm/config.py
@@ -2687,7 +2687,7 @@ class CompilationConfig(BaseModel):
         outside of compilation.
         TODO: move outside cudagraph logic into compilation.
         torch.compile will handle cudagraph capture logic in the future.
-    - cudagraph_capture_sizes: sizes to capture cudagraph.
+    - capture_sizes: sizes to capture cudagraph.
         - None (default): capture sizes are inferred from vllm config.
         - List[int]: capture sizes are specified as given.
     - cudagraph_num_of_warmups: number of warmup runs for cudagraph.
@@ -2703,10 +2703,11 @@
     - use_inductor: whether to use inductor compilation.
         - False: inductor compilation is not used. graph runs in eager.
         - True: inductor compilation is used. one graph for symbolic shape
-            is compiled. In addition, compile for cudagraph sizes that are
-            in candidate_compile_sizes, using configurations
-            in inductor_compile_config.
-    - candidate_compile_sizes: sizes to compile for inductor.
+            is compiled. In addition, compile for compile_sizes,
+            using configurations in inductor_compile_config.
+    - compile_sizes: sizes to compile for inductor. In addition
+        to integers, it also supports "cudagraph" to
+        specify the sizes for cudagraph capture.
     - inductor_compile_config: additional configurations for inductor.
         - None: use default configurations.
     - inductor_passes: additional passes for inductor. It is a dictionary
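
Note (illustrative, not part of the diff): under the renamed fields, a user-facing config might combine capture and compile sizes as sketched below. The import path mirrors the tests above, the concrete sizes (1, 2, 8192) are made-up values, and the literal string "cudagraph" in compile_sizes follows the docstring entry added here.

from vllm.config import CompilationConfig, CompilationLevel

# Hypothetical usage: capture cudagraphs for batch sizes 1 and 2, and ask
# inductor to compile those captured sizes plus an extra size 8192
# (e.g. a chunked-prefill token budget that is never cudagraph-captured).
compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    capture_sizes=[1, 2],
    compile_sizes=["cudagraph", 8192],
)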
@@ -2734,13 +2735,13 @@ class CompilationConfig(BaseModel):
     splitting_ops: List[str] = Field(default=None)  # type: ignore
 
     use_inductor: bool = True
-    candidate_compile_sizes: Optional[List[int]] = Field(default=None)
+    compile_sizes: Optional[List[Union[int, str]]] = Field(default=None)
     inductor_compile_config: Dict = Field(default_factory=dict)
     inductor_passes: Dict[str, str] = Field(default_factory=dict)
 
     use_cudagraph: bool = False
     cudagraph_num_of_warmups: int = 0
-    cudagraph_capture_sizes: Optional[List[int]] = None
+    capture_sizes: Optional[List[int]] = None
     cudagraph_copy_inputs: bool = False
 
     class PassConfig(BaseModel):
@@ -2782,8 +2783,6 @@ def model_post_init(self, __context: Any) -> None:
     pass_config: PassConfig = Field(default_factory=PassConfig)
 
     # not configurable, computed after init
-    compile_sizes: List[int] = PrivateAttr
-    capture_sizes: List[int] = PrivateAttr
     max_capture_size: int = PrivateAttr
     # optimization:
     # Intuitively, bs_to_padded_graph_size should be Dict[int, int].
@@ -2909,31 +2908,30 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(vllm_config)
 
-    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
+    def init_with_cudagraph_sizes(self,
+                                  cudagraph_capture_sizes: List[int]) -> None:
         """To complete the initialization of config,
         we need to know the cudagraph sizes."""
 
-        if self.cudagraph_capture_sizes is None:
-            self.capture_sizes = sizes_to_specialize
+        if self.capture_sizes is None:
+            self.capture_sizes = cudagraph_capture_sizes
         else:
-            self.capture_sizes = self.cudagraph_capture_sizes
             logger.info(("cudagraph sizes specified by model runner"
                          " %s is overridden by config %s"),
-                        sizes_to_specialize, self.cudagraph_capture_sizes)
-
-        if self.candidate_compile_sizes is None:
-            self.candidate_compile_sizes = []
-        self.compile_sizes = [
-            x for x in self.candidate_compile_sizes if x in self.capture_sizes
-        ]
-        ignored_sizes = [
-            x for x in self.candidate_compile_sizes
-            if x not in self.capture_sizes
-        ]
-        if ignored_sizes:
-            logger.warning(("candidate_compile_sizes %s are ignored "
-                            "because they are not cudagraph capture sizes."),
-                           ignored_sizes)
+                        cudagraph_capture_sizes, self.capture_sizes)
+
+        computed_compile_sizes = []
+        if self.compile_sizes is not None:
+            for x in self.compile_sizes:
+                if isinstance(x, str):
+                    assert x == "cudagraph", \
+                        "Unrecognized size type in compile_sizes, " \
+                        f"expect 'cudagraph', got {x}"
+                    computed_compile_sizes.extend(self.capture_sizes)
+                else:
+                    assert isinstance(x, int)
+                    computed_compile_sizes.append(x)
+        self.compile_sizes = computed_compile_sizes
 
         # sort to make sure cudagraph capture sizes are in descending order
         self.capture_sizes.sort(reverse=True)
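
Note (illustrative, not part of the diff): a standalone sketch of the resolution step above; resolve_compile_sizes is an invented helper name used only for this example.

from typing import List, Union

def resolve_compile_sizes(compile_sizes: List[Union[int, str]],
                          capture_sizes: List[int]) -> List[int]:
    # Expand the "cudagraph" placeholder into the concrete capture sizes,
    # mirroring init_with_cudagraph_sizes above.
    resolved: List[int] = []
    for x in compile_sizes:
        if isinstance(x, str):
            assert x == "cudagraph", f"unrecognized size: {x}"
            resolved.extend(capture_sizes)
        else:
            resolved.append(x)
    return resolved

# compile_sizes=["cudagraph", 8192] with capture_sizes=[1, 2]
print(resolve_compile_sizes(["cudagraph", 8192], [1, 2]))  # -> [1, 2, 8192]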
15 changes: 8 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 import torch
@@ -826,10 +826,12 @@ def load_model(self) -> None:
     @torch.inference_mode()
     def _dummy_run(
         self,
-        model: nn.Module,
         num_tokens: int,
-        kv_caches: List[torch.Tensor],
+        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
+        model = self.model
+        if kv_caches is None:
+            kv_caches = self.kv_caches
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -955,8 +957,7 @@ def profile_run(self) -> None:
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
         # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                        dummy_kv_caches)
+        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
@@ -982,8 +983,8 @@ def capture_model(self) -> None:
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(self.model, num_tokens, self.kv_caches)
-                self._dummy_run(self.model, num_tokens, self.kv_caches)
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
12 changes: 12 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -170,6 +170,18 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes
+                if x not in self.vllm_config.compilation_config.capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
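
Note (illustrative, not part of the diff): a simplified sketch of the warm-up filtering added in this hunk, with plain lists standing in for the config objects; all values are made up.

compile_sizes = [1, 2, 8192]   # sizes the user asked to compile
capture_sizes = [1, 2]         # sizes that will be cudagraph-captured
enforce_eager = False

warmup_sizes = list(compile_sizes)
if not enforce_eager:
    # Captured sizes are already exercised during cudagraph capture,
    # so only the remaining sizes need an explicit dummy run.
    warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]

for size in sorted(warmup_sizes, reverse=True):
    print(f"would warm up size {size}")   # stands in for _dummy_run(size)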
12 changes: 9 additions & 3 deletions vllm/worker/model_runner.py
@@ -1244,13 +1244,19 @@ def set_in_profile_run(self):
 
     @torch.inference_mode()
     def profile_run(self) -> None:
+        max_num_batched_tokens = \
+            self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        self._dummy_run(max_num_batched_tokens, max_num_seqs)
+
+    def _dummy_run(self,
+                   max_num_batched_tokens: int,
+                   max_num_seqs: int = 1) -> None:
         with self.set_in_profile_run():
             # Enable top-k sampling to reflect the accurate memory usage.
             sampling_params = \
                 SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-            max_num_batched_tokens = \
-                self.scheduler_config.max_num_batched_tokens
-            max_num_seqs = self.scheduler_config.max_num_seqs
 
             # This represents the maximum number of different requests
             # that will have unique loras, an therefore the max amount of memory
             # consumption create dummy lora request copies from the lora request
12 changes: 12 additions & 0 deletions vllm/worker/worker.py
@@ -288,6 +288,18 @@ def _init_cache_engine(self):
                             self.gpu_cache)
 
     def _warm_up_model(self) -> None:
+        # warm up sizes that are not in cudagraph capture sizes,
+        # but users still want to compile for better performance,
+        # e.g. for the max-num-batched token size in chunked prefill.
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+        if not self.model_config.enforce_eager:
+            warmup_sizes = [
+                x for x in warmup_sizes
+                if x not in self.vllm_config.compilation_config.capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compile and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model(self.gpu_cache)
         # Reset the seed to ensure that the random state is not affected by