95 changes: 95 additions & 0 deletions tests/v1/core/test_prefill_chunk_alignment.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.v1.core.sched.prefill_chunk_alignment import (
DefaultPrefillChunkAlignmentPolicy,
MambaPrefillChunkAlignmentPolicy,
create_prefill_chunk_alignment_policy,
)


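# Minimal stand-in for a Request, exposing only the attributes the
# alignment policies read.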
class _FakeRequest:
def __init__(
self,
num_prompt_tokens: int,
num_tokens: int,
num_computed_tokens: int = 0,
) -> None:
self.num_prompt_tokens = num_prompt_tokens
self.num_tokens = num_tokens
self.num_computed_tokens = num_computed_tokens


@pytest.mark.parametrize("use_eagle", [False, True])
def test_mamba_align_scheduled_tokens_block_aligns(use_eagle: bool):
policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=use_eagle)
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
# Far from the tail: must be a multiple of block_size.
assert policy.align_scheduled_tokens(req, 30) == 16


def test_mamba_align_scheduled_tokens_forces_last_chunk_no_eagle():
policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=80)
# last_cache_position == 96; chunk crosses it -> snap to 96 - 80 == 16.
assert policy.align_scheduled_tokens(req, 20) == 16


def test_mamba_align_scheduled_tokens_eagle_pulls_back_one_block():
policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=True)
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=64)
# Eagle: last_cache_position drops from 96 to 80; snap to 80 - 64 == 16.
assert policy.align_scheduled_tokens(req, 20) == 16


def test_mamba_align_scheduled_tokens_passes_through_tail():
policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100, num_computed_tokens=96)
# Past last_cache_position == 96: prefill the last few tokens unchanged.
assert policy.align_scheduled_tokens(req, 4) == 4


def test_mamba_align_external_cached_tokens_rejects_nonzero():
policy = MambaPrefillChunkAlignmentPolicy(block_size=16, use_eagle=False)
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
with pytest.raises(AssertionError, match="External KV connector"):
policy.align_external_cached_tokens(
req, num_local_cached_tokens=0, num_external_cached_tokens=16
)


def test_default_policy_is_noop():
policy = DefaultPrefillChunkAlignmentPolicy()
req = _FakeRequest(num_prompt_tokens=100, num_tokens=100)
assert policy.align_scheduled_tokens(req, 30) == 30
assert (
policy.align_external_cached_tokens(
req, num_local_cached_tokens=8, num_external_cached_tokens=24
)
== 24
)


def test_factory_returns_mamba_policy_only_when_align_mode():
mamba_policy = create_prefill_chunk_alignment_policy(
has_mamba_layers=True,
mamba_cache_mode="align",
block_size=16,
use_eagle=False,
)
assert isinstance(mamba_policy, MambaPrefillChunkAlignmentPolicy)

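    # Every other combination falls back to the no-op default policy.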
for has_mamba, mode in [
(False, "align"),
(True, "all"),
(True, "none"),
(False, "none"),
]:
default_policy = create_prefill_chunk_alignment_policy(
has_mamba_layers=has_mamba,
mamba_cache_mode=mode,
block_size=16,
use_eagle=False,
)
assert isinstance(default_policy, DefaultPrefillChunkAlignmentPolicy)
213 changes: 213 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -11,6 +11,7 @@
ECTransferConfig,
KVTransferConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
@@ -31,6 +32,7 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -684,6 +686,217 @@ def test_schedule_order(enable_chunked_prefill: bool):
assert len(scheduler_output1.scheduled_new_reqs) == 1


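# Builds a Scheduler with mamba_cache_mode="align" and a single MambaSpec KV
# cache group; when kv_connector_matched_tokens > 0, a mock KV connector is
# wired in so the external-cache code paths can be exercised.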
def _create_mamba_align_scheduler(
*,
block_size: int = 16,
max_num_batched_tokens: int = 64,
max_num_scheduled_tokens: int | None = None,
max_model_len: int = 256,
num_blocks: int = 10000,
kv_connector_matched_tokens: int = 0,
kv_connector_is_async: bool = False,
) -> Scheduler:
model_config = ModelConfig(
model="facebook/opt-125m",
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=True,
)
scheduler_config = SchedulerConfig(
max_num_seqs=8,
max_num_batched_tokens=max_num_batched_tokens,
max_num_scheduled_tokens=max_num_scheduled_tokens,
max_model_len=max_model_len,
disable_chunked_mm_input=False,
enable_chunked_prefill=True,
async_scheduling=False,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
cache_dtype="auto",
enable_prefix_caching=True,
mamba_block_size=block_size,
mamba_cache_mode="align",
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
parallel_config=ParallelConfig(),
kv_transfer_config=(
KVTransferConfig(
kv_connector="MockKVConnector",
kv_role="kv_both",
kv_connector_extra_config={
"matched_tokens": kv_connector_matched_tokens,
"is_async": kv_connector_is_async,
},
)
if kv_connector_matched_tokens > 0
else None
),
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
MambaSpec(
block_size=block_size,
shapes=((1,),),
dtypes=(torch.float32,),
mamba_cache_mode="align",
),
)
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
block_size=block_size,
structured_output_manager=StructuredOutputManager(vllm_config),
log_stats=False,
)


def test_mamba_align_prefill_chunk_alignment_clamps_to_block_boundary():
scheduler = _create_mamba_align_scheduler(
block_size=16,
max_num_batched_tokens=30,
max_model_len=256,
)
(request,) = create_requests(
num_requests=1,
num_tokens=100,
block_size=16,
req_ids=["long"],
)

scheduler.add_request(request)
output = scheduler.schedule()

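    # The 30-token budget is rounded down to a single 16-token block.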
assert output.num_scheduled_tokens["long"] == 16


def test_mamba_align_prefill_chunk_alignment_blocks_waiting_queue_on_tiny_budget():
scheduler = _create_mamba_align_scheduler(
block_size=16,
max_num_batched_tokens=32,
max_num_scheduled_tokens=8,
max_model_len=256,
)
(long_request,) = create_requests(
num_requests=1,
num_tokens=100,
block_size=16,
req_ids=["long"],
)
(short_request,) = create_requests(
num_requests=1,
num_tokens=4,
block_size=16,
req_ids=["short"],
)

scheduler.add_request(long_request)
scheduler.add_request(short_request)
output = scheduler.schedule()

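    # A per-step cap of 8 tokens is below block_size, so the long prefill
    # cannot be block-aligned; nothing is scheduled and both requests stay
    # in the waiting queue.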
assert output.num_scheduled_tokens == {}
assert [request.request_id for request in scheduler.waiting] == ["long", "short"]


def test_prefill_chunk_alignment_policy_is_noop_without_mamba_align():
scheduler = create_scheduler(
max_num_batched_tokens=30,
block_size=16,
max_model_len=256,
)
(request,) = create_requests(
num_requests=1,
num_tokens=100,
block_size=16,
req_ids=["long"],
)

scheduler.add_request(request)
output = scheduler.schedule()

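    # Without the Mamba align policy, the full 30-token budget is scheduled.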
assert output.num_scheduled_tokens["long"] == 30


@pytest.mark.parametrize("is_async", [False, True])
def test_mamba_align_prefill_chunk_alignment_rejects_external_kv_hits(is_async: bool):
scheduler = _create_mamba_align_scheduler(
block_size=16,
max_num_batched_tokens=30,
max_model_len=256,
kv_connector_matched_tokens=16,
kv_connector_is_async=is_async,
)
(request,) = create_requests(
num_requests=1,
num_tokens=100,
block_size=16,
req_ids=["external"],
)
scheduler.add_request(request)

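    # The Mamba align policy rejects externally cached tokens outright.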
with pytest.raises(AssertionError, match="External KV connector"):
scheduler.schedule()


def test_prefill_chunk_alignment_policy_can_adjust_external_kv_hits():
scheduler = _create_mamba_align_scheduler(
block_size=16,
max_num_batched_tokens=96,
max_model_len=256,
kv_connector_matched_tokens=24,
)

class ExternalKVAlignmentPolicy:
def align_external_cached_tokens(
self,
request: Request,
*,
num_local_cached_tokens: int,
num_external_cached_tokens: int,
) -> int:
assert num_local_cached_tokens == 0
assert num_external_cached_tokens == 24
return 16

def align_scheduled_tokens(
self,
request: Request,
num_scheduled_tokens: int,
*,
num_local_cached_tokens: int = 0,
num_external_cached_tokens: int = 0,
) -> int:
assert num_external_cached_tokens == 16
return num_scheduled_tokens

# Inject a test-only policy to exercise the scheduler's policy boundary.
scheduler.prefill_chunk_alignment_policy = ExternalKVAlignmentPolicy()
(request,) = create_requests(
num_requests=1,
num_tokens=96,
block_size=16,
req_ids=["external"],
)
scheduler.add_request(request)

output = scheduler.schedule()

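    # The injected policy clamps the 24 external-cache hits down to 16, so the
    # remaining 96 - 16 == 80 prompt tokens are scheduled.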
assert output.num_scheduled_tokens["external"] == 80


def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.