Closed
Changes from all commits
60 commits
75ae5fd
spec draft
LiuXiaoxuanPKU Nov 6, 2023
46cd4c3
Merge branch 'vllm-project:main' into spec
LiuXiaoxuanPKU Nov 6, 2023
edeaec0
minor
LiuXiaoxuanPKU Nov 6, 2023
95a7e13
minor
LiuXiaoxuanPKU Nov 8, 2023
366fbb9
draft tokens
LiuXiaoxuanPKU Nov 8, 2023
3c7397e
minor
LiuXiaoxuanPKU Nov 8, 2023
9f35009
merge
LiuXiaoxuanPKU Nov 8, 2023
9b64276
Merge branch 'main' of github.com:LiuXiaoxuanPKU/vllm
LiuXiaoxuanPKU Nov 8, 2023
1525262
Merge branch 'main' into spec
LiuXiaoxuanPKU Nov 8, 2023
7e6224a
minor
LiuXiaoxuanPKU Nov 9, 2023
93901c8
Merge branch 'spec' of github.com:LiuXiaoxuanPKU/vllm into spec
LiuXiaoxuanPKU Nov 9, 2023
692328a
draft logits
LiuXiaoxuanPKU Nov 9, 2023
8b6d647
need to change draft token probs data structure
LiuXiaoxuanPKU Nov 9, 2023
675e1ae
rejection sampling
LiuXiaoxuanPKU Nov 9, 2023
32267f6
rejection sampling
LiuXiaoxuanPKU Nov 10, 2023
1aab040
format
LiuXiaoxuanPKU Nov 12, 2023
826b54a
get draft probs
LiuXiaoxuanPKU Nov 12, 2023
b2ec9aa
style
LiuXiaoxuanPKU Nov 12, 2023
6382396
combine draft_token_ids and output_token_ids in SequenceData
LiuXiaoxuanPKU Nov 13, 2023
89d8ba2
invalidate kv draft
LiuXiaoxuanPKU Nov 13, 2023
9594d08
fix
LiuXiaoxuanPKU Nov 13, 2023
6b1e94c
pass in multiple tokens for generation phase, kv_mqa
LiuXiaoxuanPKU Nov 13, 2023
2d5c379
pass scheduler to spec worker
LiuXiaoxuanPKU Nov 13, 2023
025bb89
mqa
LiuXiaoxuanPKU Nov 15, 2023
dd23ff7
separate sampler
LiuXiaoxuanPKU Nov 15, 2023
f1b3987
lots of fix, multi_qa_kv runnable
LiuXiaoxuanPKU Nov 16, 2023
9a85990
nan in hidden states
LiuXiaoxuanPKU Nov 16, 2023
54bfebd
lots of style fix, early break accepting tokens
LiuXiaoxuanPKU Nov 17, 2023
a904ac9
fix free bug
LiuXiaoxuanPKU Nov 18, 2023
0cb9326
bug fix
LiuXiaoxuanPKU Nov 18, 2023
4e9ae6c
minor fix get target probs in prefill phase
LiuXiaoxuanPKU Nov 18, 2023
0ff36e7
fix mismatch between logical and physical blocks!!
LiuXiaoxuanPKU Nov 24, 2023
d2d67f9
add alphas
LiuXiaoxuanPKU Nov 27, 2023
7d94cb2
tokenizer & bug fix
LiuXiaoxuanPKU Nov 30, 2023
b1a5a88
pass tests
LiuXiaoxuanPKU Nov 30, 2023
93c7956
add flag
LiuXiaoxuanPKU Dec 3, 2023
141da66
remove speculative decoding for prompt run
LiuXiaoxuanPKU Dec 5, 2023
439c88b
remove temperature, only support all greedy for now
LiuXiaoxuanPKU Dec 6, 2023
40ab8d4
clean
Dec 7, 2023
bf2ebe9
minor
Dec 7, 2023
179e968
merge
LiuXiaoxuanPKU Dec 7, 2023
664a256
fix & pass tests
LiuXiaoxuanPKU Dec 7, 2023
7f9a373
format
LiuXiaoxuanPKU Dec 7, 2023
0540142
remove old files
LiuXiaoxuanPKU Dec 7, 2023
993f2d4
remove untouched file
LiuXiaoxuanPKU Dec 8, 2023
c410cbe
format
LiuXiaoxuanPKU Dec 8, 2023
9f2d98b
format
LiuXiaoxuanPKU Dec 8, 2023
c7df07b
restore test
LiuXiaoxuanPKU Dec 15, 2023
8a3208f
remove cached_mqa kernel, remove spec dec
LiuXiaoxuanPKU Dec 15, 2023
a5e8ea1
restore input metadata, change sampler
LiuXiaoxuanPKU Dec 15, 2023
83775d0
restore
LiuXiaoxuanPKU Dec 15, 2023
7218b7f
remove test
LiuXiaoxuanPKU Dec 15, 2023
33767b9
sequence
LiuXiaoxuanPKU Dec 15, 2023
8295830
update comments in sequence
LiuXiaoxuanPKU Dec 15, 2023
35b8a5c
engine
LiuXiaoxuanPKU Dec 15, 2023
9bab2c6
restore runner
LiuXiaoxuanPKU Dec 15, 2023
e11933d
Merge branch 'main' into multi-token
LiuXiaoxuanPKU Dec 15, 2023
0635148
fix
LiuXiaoxuanPKU Dec 15, 2023
7f9ac80
minor
LiuXiaoxuanPKU Dec 15, 2023
9b7d454
format
LiuXiaoxuanPKU Dec 15, 2023
8 changes: 8 additions & 0 deletions vllm/block.py
@@ -46,6 +46,14 @@ def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]

# Delete num tokens from the end of this block.
def delete_last_tokens(self, num: int) -> None:
assert num > 0
assert num <= self.num_tokens
self.num_tokens -= num
for i in range(self.num_tokens, len(self.token_ids)):
self.token_ids[i] = _BLANK_TOKEN_ID


class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
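To make the new `delete_last_tokens` easier to follow outside the diff, here is a minimal, self-contained sketch of the same shrink-then-blank behavior. `MiniBlock` is a hypothetical stand-in for vLLM's `LogicalTokenBlock`, and `_BLANK_TOKEN_ID = -1` is assumed for illustration.

```python
# Hypothetical stand-in for LogicalTokenBlock, for illustration only.
_BLANK_TOKEN_ID = -1

class MiniBlock:
    def __init__(self, block_size: int):
        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def append(self, ids):
        for t in ids:
            self.token_ids[self.num_tokens] = t
            self.num_tokens += 1

    def delete_last_tokens(self, num: int) -> None:
        # Mirrors the diff: shrink the logical length, then blank the freed slots.
        assert 0 < num <= self.num_tokens
        self.num_tokens -= num
        for i in range(self.num_tokens, len(self.token_ids)):
            self.token_ids[i] = _BLANK_TOKEN_ID

block = MiniBlock(block_size=4)
block.append([11, 12, 13])
block.delete_last_tokens(2)
print(block.num_tokens, block.token_ids)  # 1 [11, -1, -1, -1]
```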
9 changes: 9 additions & 0 deletions vllm/core/block_manager.py
@@ -179,6 +179,15 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number

def free_tailing_blocks(self, seq: Sequence) -> None:
block_table = self.block_tables[seq.seq_id]
free_cnt = len(seq.logical_token_blocks) - len(block_table)
while free_cnt > 0:
block = block_table.pop()
self.gpu_allocator.free(block)
free_cnt -= 1
self.block_tables[seq.seq_id] = block_table

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
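The physical-side cleanup in `free_tailing_blocks` pairs with the logical trimming added to `Sequence`: once trailing logical blocks are gone, the physical blocks that no longer back any logical block are popped from the block table and returned to the allocator. Below is a rough sketch of that intent with hypothetical helpers; note the subtraction direction here assumes the physical table is the longer one after truncation.

```python
# Toy allocator and block table, illustrating the intent of free_tailing_blocks:
# return surplus physical blocks to the pool after the logical side was trimmed.
class ToyAllocator:
    def __init__(self):
        self.freed = []

    def free(self, block):
        self.freed.append(block)

def free_trailing_blocks(logical_blocks, block_table, allocator):
    surplus = len(block_table) - len(logical_blocks)  # assumed direction
    while surplus > 0:
        allocator.free(block_table.pop())
        surplus -= 1
    return block_table

alloc = ToyAllocator()
print(free_trailing_blocks(["L0"], ["P0", "P1", "P2"], alloc))  # ['P0']
print(alloc.freed)                                              # ['P2', 'P1']
```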
24 changes: 23 additions & 1 deletion vllm/core/scheduler.py
@@ -7,7 +7,8 @@
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
SequenceGroupMetadata, SequenceStatus,
SequenceOutput)

logger = init_logger(__name__)

@@ -309,6 +310,27 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)

def free_invalid_kv(self, seq: Sequence, seq_out: SequenceOutput):
# If all the draft tokens are accepted:
# draft_token_ids: [A, B, C], accepted_tokens: [A, B, C, D], invalid_token_cnt = 3 + 1 - 4 = 0
# If only part of the draft tokens are accepted:
# draft_token_ids: [A, B, C], accepted_tokens: [A, B, D], invalid_token_cnt = 3 + 1 - 3 = 1
invalid_token_cnt = len(seq.data.get_draft_token_ids()) + 1 - len(
seq_out.accepted_tokens)
assert invalid_token_cnt >= 0

if invalid_token_cnt == 0:
return invalid_token_cnt

# delete data
seq.data.output_token_ids = seq.data.output_token_ids[:
-invalid_token_cnt]
# delete from logical table
seq.delete_tailing_tokens(invalid_token_cnt)
# delete from physical table
self.block_manager.free_tailing_blocks(seq)
return invalid_token_cnt

def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
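The accounting in `free_invalid_kv` can be played out in isolation: the target model scores one position beyond the drafted tokens, so the number of stale KV slots is `len(draft) + 1 - len(accepted)`. A standalone sketch of that arithmetic (the helper below is hypothetical, not part of the PR):

```python
# Illustration of the invalid-token accounting used by free_invalid_kv above.
def invalid_token_count(draft_token_ids, accepted_tokens):
    # The target model scores len(draft) + 1 positions per step.
    return len(draft_token_ids) + 1 - len(accepted_tokens)

# All drafts accepted plus the bonus token from the target model: nothing to free.
print(invalid_token_count(["A", "B", "C"], ["A", "B", "C", "D"]))  # 0
# Third draft rejected and replaced by "D": one trailing KV slot is stale.
print(invalid_token_count(["A", "B", "C"], ["A", "B", "D"]))       # 1
```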
93 changes: 71 additions & 22 deletions vllm/engine/llm_engine.py
@@ -410,9 +410,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
parent.step_gen_token_ids = [last_child_sample.output_token]
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
self._truncate_sequence(seq, seq_group.sampling_params)
self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)

@@ -657,12 +659,77 @@ def _log_system_stats(
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now

def _truncate_step_gen_token_ids(self, seq: Sequence,
truncate_len: int) -> None:
if truncate_len > 0:
seq.step_gen_token_ids = seq.step_gen_token_ids[:-truncate_len]

def _truncate_sequence(self, seq: Sequence,
sampling_params: SamplingParams) -> None:

output_token_ids = seq.get_output_token_ids()
for stop_token_id in sampling_params.stop_token_ids:
if stop_token_id in seq.get_token_ids():
# e.g. seq: [p1, p2, p3, A, B, C] where p1, p2, p3 are prompt tokens and the stop_token is B
# truncated_output_len = 4 + 1 - 3 = 2, i.e. keep [A, B]
# the stop_token is intentionally kept in the output
truncated_output_len = seq.get_token_ids().index(
stop_token_id) + 1 - seq.get_prompt_len()
self._truncate_step_gen_token_ids(
seq,
len(output_token_ids) - truncated_output_len)
# we don't modify logical/physical block here
seq.data.output_token_ids = output_token_ids[:
truncated_output_len]
seq.status = SequenceStatus.FINISHED_STOPPED
return

# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
truncated_output_len = self.scheduler_config.max_model_len - seq.get_prompt_len(
)
self._truncate_step_gen_token_ids(
seq,
len(output_token_ids) - truncated_output_len)
seq.data.output_token_ids = output_token_ids[:truncated_output_len]
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() >= sampling_params.max_tokens:
truncated_output_len = sampling_params.max_tokens
self._truncate_step_gen_token_ids(
seq,
len(output_token_ids) - truncated_output_len)
seq.data.output_token_ids = output_token_ids[:truncated_output_len]
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and self.tokenizer.eos_token_id in seq.get_output_token_ids()):
truncated_output_len = output_token_ids.index(
self.tokenizer.eos_token_id) + 1
self._truncate_step_gen_token_ids(
seq,
len(output_token_ids) - truncated_output_len)
seq.data.output_token_ids = output_token_ids[:truncated_output_len]
seq.status = SequenceStatus.FINISHED_STOPPED
return

def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
if seq.tokens is None:
# prefill phase
new_token_ids = seq.get_token_ids()
else:
gen_len = len(seq.step_gen_token_ids)
new_token_ids = seq.get_output_token_ids()[-gen_len:]
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.tokenizer,
all_input_ids=seq.get_token_ids(),
prompt_len=seq.get_prompt_len(),
new_token_ids=new_token_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
@@ -681,31 +748,13 @@ def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
if stop_str in seq.output_text:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
seq.output_text = seq.output_text[:seq.output_text.
index(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED
return
if seq.get_last_token_id() in sampling_params.stop_token_ids:
seq.status = SequenceStatus.FINISHED_STOPPED
return

# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return

def _run_workers_in_batch(
self,
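The stop-token branch of `_truncate_sequence` is easiest to see with a toy sequence. This sketch uses a hypothetical helper that mirrors the comment in the diff; it is not the engine code itself.

```python
# Standalone illustration of the stop-token truncation in _truncate_sequence.
def truncated_output_len(token_ids, prompt_len, stop_token_id):
    # Keep everything up to and including the stop token, measured in output tokens.
    return token_ids.index(stop_token_id) + 1 - prompt_len

seq = ["p1", "p2", "p3", "A", "B", "C"]   # p1..p3 are prompt tokens
print(truncated_output_len(seq, 3, "B"))  # 2 -> keep ["A", "B"], drop ["C"]
```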
86 changes: 73 additions & 13 deletions vllm/model_executor/layers/sampler.py
@@ -39,6 +39,22 @@ def forward(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
prompt_run = sampling_metadata.num_prompts > 0
len_to_gen = hidden_states.shape[1]
if len_to_gen > 1 and (not prompt_run):
return self._multi_token_forward(embedding, hidden_states,
sampling_metadata, embedding_bias)
else:
return self._forward(embedding, hidden_states, sampling_metadata,
embedding_bias)

def _forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
@@ -97,6 +113,45 @@ def forward(
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)

def _multi_token_forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Sampler forward for speculative decoding.
# It is a simplified version of the original forward
# and only supports argmax sampling.
batch_size = hidden_states.shape[0]
len_to_gen = hidden_states.shape[1]
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)

# Do not apply temperature since we only support greedy sampling

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# sample_results = torch.argmax(logprobs, dim=-1).cpu().reshape(batch_size, -1)
sample_results = _greedy_sample(sampling_metadata.seq_groups, logprobs,
len_to_gen)
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)

probdis = probs.reshape(batch_size, len_to_gen, -1)
# change probs to a list of lists
probdis = [list(tensor.unbind(0)) for tensor in probdis.unbind(0)]
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)


def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
@@ -362,11 +417,12 @@ def _apply_min_p(
return logits


def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
def _greedy_sample(selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
len_to_gen: int = 1) -> List[Tuple[List[int], List[int]]]:
samples = torch.argmax(logprobs, dim=-1).cpu()
if len_to_gen > 1:
samples = samples.reshape(-1, len_to_gen)
sample_idx = 0
results = []
for seq_group in selected_seq_groups:
Expand All @@ -375,10 +431,13 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx].item()]
if len_to_gen > 1:
next_token_ids = samples[sample_idx].tolist()
else:
next_token_ids = [samples[sample_idx].item()]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
# assert sample_idx == logprobs.size(0)
return results


@@ -545,13 +604,14 @@ def _get_logprobs(
token_id for token_id in prompt_tokens[1:])
sample_idx += prompt_len - 1
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
[sample_idx + parent_id
for parent_id in parent_ids] * len(next_token_ids))
batched_logprobs_query_token_indices.extend(next_token_ids)
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
# assert sample_idx == logprobs.size(0)

# Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[
@@ -629,11 +689,10 @@ def _get_logprobs(


def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs]) -> SamplerOutput:
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
@@ -642,6 +701,7 @@ def _build_sampler_output(
seq_ids, _ = seq_group
next_token_ids, parent_ids = sample_result
seq_outputs = []

for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids,
group_sample_logprobs):
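A quick sketch of the multi-token greedy path: when each sequence scored `len_to_gen` positions in a single pass, the flat argmax results are regrouped per sequence, which is what the `len_to_gen` branch of `_greedy_sample` does. Shapes and the reshape follow the diff; the random logprobs below are placeholders.

```python
import torch

# batch of 2 sequences, each with 3 speculative positions, toy vocab of 8
batch_size, len_to_gen, vocab_size = 2, 3, 8
logprobs = torch.randn(batch_size * len_to_gen, vocab_size)

samples = torch.argmax(logprobs, dim=-1).cpu()   # shape: (batch * len_to_gen,)
samples = samples.reshape(-1, len_to_gen)        # one row of token ids per sequence
next_token_ids = [row.tolist() for row in samples]
print(next_token_ids)  # e.g. [[5, 1, 7], [0, 2, 2]]
```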
24 changes: 20 additions & 4 deletions vllm/sequence.py
@@ -119,6 +119,7 @@ def __init__(
self.block_size = block_size

self.data = SequenceData(prompt_token_ids)
self.step_gen_token_ids: List[int] = []
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

@@ -140,6 +141,9 @@ def _append_logical_block(self) -> None:
)
self.logical_token_blocks.append(block)

def _delete_logical_block(self, block: LogicalTokenBlock) -> None:
self.logical_token_blocks.remove(block)

def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
cursor = 0
while cursor < len(token_ids):
@@ -166,6 +170,18 @@ def append_token_id(
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id])

# Delete the last n tokens of the sequence, dropping whole blocks when needed.
def delete_tailing_tokens(self, n: int) -> None:
while n > 0:
assert len(self.logical_token_blocks) > 0
last_block = self.logical_token_blocks[-1]
if last_block.num_tokens < n:
n -= last_block.num_tokens
self._delete_logical_block(last_block)
else:
last_block.delete_last_tokens(n)
break

def get_len(self) -> int:
return self.data.get_len()

@@ -358,16 +374,16 @@ class SequenceOutput:
Args:
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
logprobs: The logprobs of the output token.
output_token: The output token ID(s).
logprobs: The logprobs of the output token(s).
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""

def __init__(
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float],
output_token: Union[int, List[int]],
logprobs: Union[Dict[int, float], List[Dict[int, float]]],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
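To see how `delete_tailing_tokens` peels trailing tokens off block by block, here is a list-based sketch of the same loop; plain Python lists stand in for `LogicalTokenBlock`, and the function name is hypothetical.

```python
def delete_trailing(blocks, n):
    # blocks: list of per-block token lists, newest block last
    while n > 0:
        assert blocks, "cannot delete more tokens than the sequence holds"
        last = blocks[-1]
        if len(last) < n:      # the whole last block is trailing: drop it and continue
            n -= len(last)
            blocks.pop()
        else:                  # trim inside the last block and stop
            del last[len(last) - n:]
            break
    return blocks

print(delete_trailing([[1, 2, 3, 4], [5, 6]], 3))  # [[1, 2, 3]]
```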