diff --git a/tests/v1/core/test_prefill_chunk_alignment.py b/tests/v1/core/test_prefill_chunk_alignment.py
new file mode 100644
index 000000000000..9208c8e7ea86
--- /dev/null
+++ b/tests/v1/core/test_prefill_chunk_alignment.py
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.v1.core.sched.prefill_chunk_alignment import (
+    DefaultPrefillChunkAlignmentPolicy,
+    MambaPrefillChunkAlignmentPolicy,
+    create_prefill_chunk_alignment_policy,
+)
+
+
+class _FakeRequest:
+    def __init__(
+        self,
+        num_prompt_tokens: int,
+        num_tokens: int,
+        num_computed_tokens: int = 0,
+    ) -> None:
+        self.num_prompt_tokens = num_prompt_tokens
+        self.num_tokens = num_tokens
+        self.num_computed_tokens = num_computed_tokens
+
+
+@pytest.mark.parametrize("use_eagle", [False, True])
+def test_mamba_align_scheduled_tokens_block_aligns(use_eagle: bool):
+    policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=use_eagle)
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
+    # Far from the tail: must be a multiple of block_size.
+    assert policy.align_scheduled_tokens(req, 30) == 16
+
+
+def test_mamba_align_scheduled_tokens_forces_last_chunk_no_eagle():
+    policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=80)
+    # last_cache_position == 96; chunk crosses it -> snap to 96 - 80 == 16.
+    assert policy.align_scheduled_tokens(req, 20) == 16
+
+
+def test_mamba_align_scheduled_tokens_eagle_pulls_back_one_block():
+    policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=True)
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=64)
+    # Eagle: last_cache_position drops from 96 to 80; snap to 80 - 64 == 16.
+    assert policy.align_scheduled_tokens(req, 20) == 16
+
+
+def test_mamba_align_scheduled_tokens_passes_through_tail():
+    policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=96)
+    # Past last_cache_position == 96: prefill the last few tokens unchanged.
+    assert policy.align_scheduled_tokens(req, 4) == 4
+
+
+def test_mamba_align_external_cached_tokens_rejects_nonzero():
+    policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
+    with pytest.raises(AssertionError, match="External KV connector"):
+        policy.align_external_cached_tokens(
+            req, num_local_cached_tokens=0, num_external_cached_tokens=16
+        )
+
+
+def test_default_policy_is_noop():
+    policy = DefaultPrefillChunkAlignmentPolicy()
+    req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
+    assert policy.align_scheduled_tokens(req, 30) == 30
+    assert (
+        policy.align_external_cached_tokens(
+            req, num_local_cached_tokens=8, num_external_cached_tokens=24
+        )
+        == 24
+    )
+
+
+def test_factory_returns_mamba_policy_only_when_align_mode():
+    mamba_policy = create_prefill_chunk_alignment_policy(
+        has_mamba_layers=True,
+        mamba_cache_mode="align",
+        block_size=16,
+        use_eagle=False,
+    )
+    assert isinstance(mamba_policy, MambaPrefillChunkAlignmentPolicy)
+
+    for has_mamba, mode in [
+        (False, "align"),
+        (True, "all"),
+        (True, "none"),
+        (False, "none"),
+    ]:
+        default_policy = create_prefill_chunk_alignment_policy(
+            has_mamba_layers=has_mamba,
+            mamba_cache_mode=mode,
+            block_size=16,
+            use_eagle=False,
+        )
+        assert isinstance(default_policy, DefaultPrefillChunkAlignmentPolicy)
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 42f4825e2b3b..61d19efe5d62 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -11,6 +11,7 @@
     ECTransferConfig,
     KVTransferConfig,
     ModelConfig,
+    ParallelConfig,
     SchedulerConfig,
     SpeculativeConfig,
     VllmConfig,
@@ -31,6 +32,7 @@
     FullAttentionSpec,
     KVCacheConfig,
     KVCacheGroupSpec,
+    MambaSpec,
 )
 from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -684,6 +686,217 @@ def test_schedule_order(enable_chunked_prefill: bool):
     assert len(scheduler_output1.scheduled_new_reqs) == 1
 
 
+def _create_mamba_align_scheduler(
+    *,
+    block_size: int = 16,
+    max_num_batched_tokens: int = 64,
+    max_num_scheduled_tokens: int | None = None,
+    max_model_len: int = 256,
+    num_blocks: int = 10000,
+    kv_connector_matched_tokens: int = 0,
+    kv_connector_is_async: bool = False,
+) -> Scheduler:
+    model_config = ModelConfig(
+        model="facebook/opt-125m",
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+        skip_tokenizer_init=True,
+    )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=8,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_scheduled_tokens=max_num_scheduled_tokens,
+        max_model_len=max_model_len,
+        disable_chunked_mm_input=False,
+        enable_chunked_prefill=True,
+        async_scheduling=False,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
+    cache_config = CacheConfig(
+        block_size=block_size,
+        gpu_memory_utilization=0.9,
+        cache_dtype="auto",
+        enable_prefix_caching=True,
+        mamba_block_size=block_size,
+        mamba_cache_mode="align",
+    )
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+        parallel_config=ParallelConfig(),
+        kv_transfer_config=(
+            KVTransferConfig(
+                kv_connector="MockKVConnector",
+                kv_role="kv_both",
+                kv_connector_extra_config={
+                    "matched_tokens": kv_connector_matched_tokens,
+                    "is_async": kv_connector_is_async,
+                },
+            )
+            if kv_connector_matched_tokens > 0
+            else None
+        ),
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer"],
+                MambaSpec(
+                    block_size=block_size,
+                    shapes=((1,),),
+                    dtypes=(torch.float32,),
+                    mamba_cache_mode="align",
+                ),
+            )
+        ],
+    )
+    cache_config.num_gpu_blocks = num_blocks
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        block_size=block_size,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+        log_stats=False,
+    )
+
+
+def test_mamba_align_prefill_chunk_alignment_clamps_to_block_boundary():
+    scheduler = _create_mamba_align_scheduler(
+        block_size=16,
+        max_num_batched_tokens=30,
+        max_model_len=256,
+    )
+    (request,) = create_requests(
+        num_requests=1,
+        num_tokens=100,
+        block_size=16,
+        req_ids=["long"],
+    )
+
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+
+    assert output.num_scheduled_tokens["long"] == 16
+
+
+def test_mamba_align_prefill_chunk_alignment_blocks_waiting_queue_on_tiny_budget():
+    scheduler = _create_mamba_align_scheduler(
+        block_size=16,
+        max_num_batched_tokens=32,
+        max_num_scheduled_tokens=8,
+        max_model_len=256,
+    )
+    (long_request,) = create_requests(
+        num_requests=1,
+        num_tokens=100,
+        block_size=16,
+        req_ids=["long"],
+    )
+    (short_request,) = create_requests(
+        num_requests=1,
+        num_tokens=4,
+        block_size=16,
+        req_ids=["short"],
+    )
+
+    scheduler.add_request(long_request)
+    scheduler.add_request(short_request)
+    output = scheduler.schedule()
+
+    assert output.num_scheduled_tokens == {}
+    assert [request.request_id for request in scheduler.waiting] == ["long", "short"]
+
+
+def test_prefill_chunk_alignment_policy_is_noop_without_mamba_align():
+    scheduler = create_scheduler(
+        max_num_batched_tokens=30,
+        block_size=16,
+        max_model_len=256,
+    )
+    (request,) = create_requests(
+        num_requests=1,
+        num_tokens=100,
+        block_size=16,
+        req_ids=["long"],
+    )
+
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+
+    assert output.num_scheduled_tokens["long"] == 30
+
+
+@pytest.mark.parametrize("is_async", [False, True])
+def test_mamba_align_prefill_chunk_alignment_rejects_external_kv_hits(is_async: bool):
+    scheduler = _create_mamba_align_scheduler(
+        block_size=16,
+        max_num_batched_tokens=30,
+        max_model_len=256,
+        kv_connector_matched_tokens=16,
+        kv_connector_is_async=is_async,
+    )
+    (request,) = create_requests(
+        num_requests=1,
+        num_tokens=100,
+        block_size=16,
+        req_ids=["external"],
+    )
+    scheduler.add_request(request)
+
+    with pytest.raises(AssertionError, match="External KV connector"):
+        scheduler.schedule()
+
+
+def test_prefill_chunk_alignment_policy_can_adjust_external_kv_hits():
+    scheduler = _create_mamba_align_scheduler(
+        block_size=16,
+        max_num_batched_tokens=96,
+        max_model_len=256,
+        kv_connector_matched_tokens=24,
+    )
+
+    class ExternalKVAlignmentPolicy:
+        def align_external_cached_tokens(
+            self,
+            request: Request,
+            *,
+            num_local_cached_tokens: int,
+            num_external_cached_tokens: int,
+        ) -> int:
+            assert num_local_cached_tokens == 0
+            assert num_external_cached_tokens == 24
+            return 16
+
+        def align_scheduled_tokens(
+            self,
+            request: Request,
+            num_scheduled_tokens: int,
+            *,
+            num_local_cached_tokens: int = 0,
+            num_external_cached_tokens: int = 0,
+        ) -> int:
+            assert num_external_cached_tokens == 16
+            return num_scheduled_tokens
+
+    # Inject a test-only policy to exercise the scheduler's policy boundary.
+    scheduler.prefill_chunk_alignment_policy = ExternalKVAlignmentPolicy()
+    (request,) = create_requests(
+        num_requests=1,
+        num_tokens=96,
+        block_size=16,
+        req_ids=["external"],
+    )
+    scheduler.add_request(request)
+
+    output = scheduler.schedule()
+
+    assert output.num_scheduled_tokens["external"] == 80
+
+
 def test_preempt_during_execution():
     # NOTE(woosuk): The actual number of available blocks is 10 instead of 11
     # because block 0 is reserved as the null block.
diff --git a/vllm/v1/core/sched/prefill_chunk_alignment.py b/vllm/v1/core/sched/prefill_chunk_alignment.py
new file mode 100644
index 000000000000..e4557fa00e52
--- /dev/null
+++ b/vllm/v1/core/sched/prefill_chunk_alignment.py
@@ -0,0 +1,152 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Protocol
+
+from vllm.config.cache import MambaCacheMode
+from vllm.v1.request import Request
+
+
+class PrefillChunkAlignmentPolicy(Protocol):
+    def align_external_cached_tokens(
+        self,
+        request: Request,
+        *,
+        num_local_cached_tokens: int,
+        num_external_cached_tokens: int,
+    ) -> int:
+        """Return externally cached tokens usable after policy alignment.
+
+        The scheduler treats this value as authoritative for connector
+        prefix-cache hit metrics and async-load gating.
+        """
+        ...
+
+    def align_scheduled_tokens(
+        self,
+        request: Request,
+        num_scheduled_tokens: int,
+        *,
+        num_local_cached_tokens: int = 0,
+        num_external_cached_tokens: int = 0,
+    ) -> int:
+        """Return the number of tokens to schedule after policy alignment.
+
+        `num_local_cached_tokens` and `num_external_cached_tokens` are 0 on
+        the running-queue path (already reflected in
+        `request.num_computed_tokens`) and non-zero on the waiting-queue path.
+        """
+        ...
+
+
+class DefaultPrefillChunkAlignmentPolicy:
+    """Default policy for models without prefill chunk alignment constraints."""
+
+    def align_external_cached_tokens(
+        self,
+        request: Request,
+        *,
+        num_local_cached_tokens: int,
+        num_external_cached_tokens: int,
+    ) -> int:
+        return num_external_cached_tokens
+
+    def align_scheduled_tokens(
+        self,
+        request: Request,
+        num_scheduled_tokens: int,
+        *,
+        num_local_cached_tokens: int = 0,
+        num_external_cached_tokens: int = 0,
+    ) -> int:
+        return num_scheduled_tokens
+
+
+class MambaPrefillChunkAlignmentPolicy:
+    """Align Mamba align-mode prefill chunks to cache block boundaries.
+
+    In EAGLE mode, the last cacheable position is pulled back by one block
+    because FullAttn prunes the last matching block.
+ """ + + def __init__(self, *, block_size: int, use_eagle: bool) -> None: + self.block_size = block_size + self.use_eagle = use_eagle + + def align_external_cached_tokens( + self, + request: Request, + *, + num_local_cached_tokens: int, + num_external_cached_tokens: int, + ) -> int: + assert num_external_cached_tokens == 0, ( + "External KV connector is not verified yet" + ) + return num_external_cached_tokens + + def align_scheduled_tokens( + self, + request: Request, + num_scheduled_tokens: int, + *, + num_local_cached_tokens: int = 0, + num_external_cached_tokens: int = 0, + ) -> int: + num_computed_tokens = ( + request.num_computed_tokens + + num_local_cached_tokens + + num_external_cached_tokens + ) + # Perform block-aligned splitting at prefill phase, including: + # * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0 + # * resumed requests: num_computed_tokens < ( + # num_prompt_tokens + num_output_tokens + # ) + # NOTE: Use `request.num_tokens - 1` to bypass normal decoding. + if num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1): + # To enable block-aligned caching of the Mamba state, scheduled tokens + # must be a multiple of `block_size`. + # As an exception, if the scheduled token count is less than + # `block_size`, the state is simply not cached, requiring no special + # handling. + # Additionally, when Eagle mode is enabled, FullAttn prunes the last + # matching block. To prevent this from causing a Mamba cache miss, the + # last chunk must be not smaller than `block_size`. + last_cache_position = ( + request.num_tokens - request.num_tokens % self.block_size + ) + # eagle prune + if self.use_eagle: + last_cache_position = max(last_cache_position - self.block_size, 0) + num_computed_tokens_after_sched = num_computed_tokens + num_scheduled_tokens + if num_computed_tokens_after_sched < last_cache_position: + # align to block_size + num_scheduled_tokens = ( + num_scheduled_tokens // self.block_size * self.block_size + ) + elif ( + num_computed_tokens + < last_cache_position + < num_computed_tokens_after_sched + ): + # force to cache the last chunk + num_scheduled_tokens = last_cache_position - num_computed_tokens + else: + # prefill the last few tokens + pass + return num_scheduled_tokens + + +def create_prefill_chunk_alignment_policy( + *, + has_mamba_layers: bool, + mamba_cache_mode: MambaCacheMode, + block_size: int, + use_eagle: bool, +) -> PrefillChunkAlignmentPolicy: + if has_mamba_layers and mamba_cache_mode == "align": + return MambaPrefillChunkAlignmentPolicy( + block_size=block_size, + use_eagle=use_eagle, + ) + return DefaultPrefillChunkAlignmentPolicy() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 032767cdf3b0..de636d4e24c4 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -45,6 +45,9 @@ NewRequestData, SchedulerOutput, ) +from vllm.v1.core.sched.prefill_chunk_alignment import ( + create_prefill_chunk_alignment_policy, +) from vllm.v1.core.sched.request_queue import ( RequestQueue, SchedulingPolicy, @@ -253,8 +256,11 @@ def __init__( self.has_mamba_layers = kv_cache_config.has_mamba_layers self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing - self.need_mamba_block_aligned_split = ( - self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align" + self.prefill_chunk_alignment_policy = create_prefill_chunk_alignment_policy( + has_mamba_layers=self.has_mamba_layers, + mamba_cache_mode=self.cache_config.mamba_cache_mode, + 
+            block_size=self.cache_config.block_size,
+            use_eagle=self.use_eagle,
         )
         self.perf_metrics: ModelMetrics | None = None
         if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
@@ -299,56 +305,6 @@ def __init__(
 
         self._pause_state: PauseState = PauseState.UNPAUSED
 
-    def _mamba_block_aligned_split(
-        self,
-        request: Request,
-        num_new_tokens: int,
-        num_new_local_computed_tokens: int = 0,
-        num_external_computed_tokens: int = 0,
-    ) -> int:
-        assert num_external_computed_tokens == 0, (
-            "External KV connector is not verified yet"
-        )
-        num_computed_tokens = (
-            request.num_computed_tokens
-            + num_new_local_computed_tokens
-            + num_external_computed_tokens
-        )
-        # Perform block-aligned splitting at prefill phase, including:
-        # * non-resumed requests: num_computed_tokens < num_prompt_tokens + 0
-        # * resumed requests: num_computed_tokens < (
-        #       num_prompt_tokens + num_output_tokens
-        #   )
-        # NOTE: Use `request.num_tokens - 1` to bypass normal decoding.
-        if num_computed_tokens < max(request.num_prompt_tokens, request.num_tokens - 1):
-            # To enable block-aligned caching of the Mamba state, `num_new_tokens`
-            # must be a multiple of `block_size`.
-            # As an exception, if `num_new_tokens` is less than `block_size`, the
-            # state is simply not cached, requiring no special handling.
-            # Additionally, when Eagle mode is enabled, FullAttn prunes the last
-            # matching block. To prevent this from causing a Mamba cache miss, the
-            # last chunk must be not smaller than `block_size`.
-            block_size = self.cache_config.block_size
-            last_cache_position = request.num_tokens - request.num_tokens % block_size
-            # eagle prune
-            if self.use_eagle:
-                last_cache_position = max(last_cache_position - block_size, 0)
-            num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
-            if num_computed_tokens_after_sched < last_cache_position:
-                # align to block_size
-                num_new_tokens = num_new_tokens // block_size * block_size
-            elif (
-                num_computed_tokens
-                < last_cache_position
-                < num_computed_tokens_after_sched
-            ):
-                # force to cache the last chunk
-                num_new_tokens = last_cache_position - num_computed_tokens
-            else:
-                # prefill the last few tokens
-                pass
-        return num_new_tokens
-
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -438,10 +394,12 @@ def schedule(self) -> SchedulerOutput:
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )
 
-            if self.need_mamba_block_aligned_split:
-                num_new_tokens = self._mamba_block_aligned_split(
-                    request, num_new_tokens
-                )
+            # Running-queue path: cached state is already reflected in
+            # `request.num_computed_tokens`, so policy kwargs default to 0.
+            num_new_tokens = self.prefill_chunk_alignment_policy.align_scheduled_tokens(
+                request,
+                num_new_tokens,
+            )
 
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
@@ -453,8 +411,7 @@
                 # its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
-                # 4. Insufficient budget for a block-aligned chunk in hybrid
-                #    models with mamba cache mode "align".
+                # 4. Insufficient budget after applying the prefill alignment policy.
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
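Reviewer note: to see the align-mode chunking end to end, the sketch below drives `MambaPrefillChunkAlignmentPolicy` through a full 100-token prefill with a 30-token budget. It assumes only the module added by this PR; `_Req` is a hypothetical stand-in mirroring the tests' `_FakeRequest`, not a real `Request`.

```python
from vllm.v1.core.sched.prefill_chunk_alignment import (
    MambaPrefillChunkAlignmentPolicy,
)


class _Req:
    # Hypothetical stand-in exposing only the attributes the policy reads.
    def __init__(self, num_prompt_tokens: int, num_tokens: int) -> None:
        self.num_prompt_tokens = num_prompt_tokens
        self.num_tokens = num_tokens
        self.num_computed_tokens = 0


policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
req = _Req(num_prompt_tokens=100, num_tokens=100)
budget = 30  # per-step token budget, deliberately not block-aligned

while req.num_computed_tokens < req.num_tokens:
    n = min(budget, req.num_tokens - req.num_computed_tokens)
    n = policy.align_scheduled_tokens(req, n)
    if n == 0:
        # Budget below block_size: nothing block-aligned can be scheduled.
        break
    print(f"{req.num_computed_tokens} -> {req.num_computed_tokens + n}")
    req.num_computed_tokens += n
```

This prints five 16-token chunks, then a chunk forced to end exactly at the last cache position (96), then the 4-token tail passed through unchanged; with `use_eagle=True` the last cache position pulls back to 80 and the tail grows to 20 tokens.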
@@ -638,6 +595,19 @@ def schedule(self) -> SchedulerOutput:
                     connector_prefix_cache_queries = (
                         request.num_tokens - num_new_local_computed_tokens
                     )
+
+                    chunk_align_policy = self.prefill_chunk_alignment_policy
+                    num_external_computed_tokens = (
+                        chunk_align_policy.align_external_cached_tokens(
+                            request,
+                            num_local_cached_tokens=num_new_local_computed_tokens,
+                            num_external_cached_tokens=num_external_computed_tokens,
+                        )
+                    )
+                    # Stats and async-load gating follow the usable
+                    # (post-alignment) hits.
+                    load_kv_async = (
+                        load_kv_async and num_external_computed_tokens > 0
+                    )
                     connector_prefix_cache_hits = num_external_computed_tokens
 
                 # Total computed tokens (local + external).
@@ -710,12 +680,13 @@
                     # The request cannot be scheduled.
                     break
 
-                if self.need_mamba_block_aligned_split:
-                    num_new_tokens = self._mamba_block_aligned_split(
-                        request,
-                        num_new_tokens,
-                        num_new_local_computed_tokens,
-                        num_external_computed_tokens,
+                num_new_tokens = (
+                    self.prefill_chunk_alignment_policy.align_scheduled_tokens(
+                        request,
+                        num_new_tokens,
+                        num_local_cached_tokens=num_new_local_computed_tokens,
+                        num_external_cached_tokens=num_external_computed_tokens,
+                    )
                 )
                 if num_new_tokens == 0:
                     break
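For downstream policies, the `PrefillChunkAlignmentPolicy` Protocol is the extension point the last two hunks exercise. As a sketch (not part of this PR; `BlockAlignedExternalPolicy` is a hypothetical name), a policy could floor external connector hits to a block boundary instead of asserting them away as the Mamba policy currently does:

```python
from vllm.v1.core.sched.prefill_chunk_alignment import (
    PrefillChunkAlignmentPolicy,
)
from vllm.v1.request import Request


class BlockAlignedExternalPolicy:
    """Hypothetical policy: keep only whole blocks of an external KV hit."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size

    def align_external_cached_tokens(
        self,
        request: Request,
        *,
        num_local_cached_tokens: int,
        num_external_cached_tokens: int,
    ) -> int:
        # The scheduler treats the return value as authoritative for
        # connector prefix-cache stats and for gating `load_kv_async`.
        return num_external_cached_tokens // self.block_size * self.block_size

    def align_scheduled_tokens(
        self,
        request: Request,
        num_scheduled_tokens: int,
        *,
        num_local_cached_tokens: int = 0,
        num_external_cached_tokens: int = 0,
    ) -> int:
        return num_scheduled_tokens


# Structural typing: any object with both methods satisfies the Protocol.
policy: PrefillChunkAlignmentPolicy = BlockAlignedExternalPolicy(block_size=16)
```

Because `align_external_cached_tokens` runs before `load_kv_async` is gated and before `connector_prefix_cache_hits` is recorded, stats and async-load decisions follow the clamped value with no extra wiring.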