
Commit 9d9072a

zhuohan123 and wanmok authored
Implement prompt logprobs & Batched topk for computing logprobs (#1328)
Co-authored-by: Yunmo Chen <[email protected]>
1 parent 928de46 commit 9d9072a

File tree

14 files changed: +371 −132 lines changed

examples/llm_engine_example.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def main(args: argparse.Namespace):
     # Test the following prompts.
     test_prompts = [
         ("A robot may not injure a human being",
-         SamplingParams(temperature=0.0)),
+         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",

tests/async_engine/test_request_tracker.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def test_request_tracker():
     stream_5 = tracker.add_request("5")
     assert tracker.new_requests_event.flag
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], finished=True))
+        RequestOutput("2", "output", [], [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
     assert not tracker.new_requests_event.flag
     assert len(finished) == 1
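
The extra empty list accounts for a new positional field on RequestOutput. A hedged annotation of the updated call (the argument names are my guess at the new signature, not part of the diff):

from vllm.outputs import RequestOutput

RequestOutput(
    "2",            # request_id
    "output",       # prompt (dummy value in this test)
    [],             # prompt_token_ids
    [],             # prompt_logprobs -- presumably the field added by this commit
    [],             # outputs
    finished=True,
)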

tests/conftest.py

Lines changed: 33 additions & 0 deletions
@@ -107,6 +107,39 @@ def generate_beam_search(
             outputs[i] = (output_ids, output_str)
         return outputs

+    def generate_greedy_logprobs(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+    ) -> List[List[torch.Tensor]]:
+        all_logprobs = []
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            seq_logprobs = []
+            for hidden_states in output.hidden_states:
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if self.model.get_output_embeddings().bias is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+            all_logprobs.append(seq_logprobs)
+        return all_logprobs
+

 @pytest.fixture
 def hf_runner():
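
How the returned reference tensors line up with tokens is what the new test below relies on. A short sketch, assuming an HfRunner instance hf_model created by the hf_runner fixture:

# ref[p][s] is a log-softmax over the vocabulary at decoding step s of
# prompt p. Step 0 runs the whole prompt, so ref[p][0] has shape
# (prompt_len, vocab_size); every later step has shape (1, vocab_size).
ref = hf_model.generate_greedy_logprobs(["Hello world"], max_tokens=3)

step0 = ref[0][0]
dist_for_prompt_token_1 = step0[0]       # position 0 predicts prompt token 1
dist_for_generated_token_1 = step0[-1]   # last prompt position predicts the first new token
dist_for_generated_token_2 = ref[0][1][-1]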

tests/samplers/test_logprobs.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import pytest
+import torch
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_get_prompt_logprobs(
+    hf_runner,
+    vllm_runner,
+    model,
+    dtype,
+    example_prompts,
+):
+    max_tokens = 5
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        example_prompts,
+        max_tokens=max_tokens,
+    )
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
+                                          logprobs=5,
+                                          prompt_logprobs=5,
+                                          temperature=0.0)
+    vllm_results = vllm_model.model.generate(
+        example_prompts, sampling_params=vllm_sampling_params)
+
+    # Test whether logprobs are included in the results.
+    for result in vllm_results:
+        assert result.prompt_logprobs is not None
+        assert result.outputs[0].logprobs is not None
+
+    # Test whether prompt logprobs are consistent with HF
+    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
+        # Check prompt logprobs
+        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
+        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
+            for token_id, logprob in vllm_prompt_logprob_dict.items():
+                torch.testing.assert_close(logprob,
+                                           hf_logprob[0][i][token_id].item(),
+                                           atol=1e-2,
+                                           rtol=1e-2)
+        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
+        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
+            for token_id, logprob in vllm_sample_logprob_dict.items():
+                torch.testing.assert_close(logprob,
+                                           hf_logprob[i][-1][token_id].item(),
+                                           atol=1e-2,
+                                           rtol=1e-2)

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ def get_head_size(self) -> int:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
-        # Note: for falcon, when new_decoder_architecture is True, the
+        # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
         falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]

vllm/engine/llm_engine.py

Lines changed: 13 additions & 7 deletions
@@ -12,8 +12,8 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupOutputs,
+                           SequenceOutputs, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -350,9 +350,15 @@ def _check_beam_search_early_stopping(
                 eos_token_id=self.tokenizer.eos_token_id))
         return current_worst_score >= highest_attainable_score

-    def _process_sequence_group_samples(
-            self, seq_group: SequenceGroup,
-            samples: List[SequenceOutputs]) -> None:
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutputs) -> None:
+        # Process prompt logprobs
+        prompt_logprobs = outputs.prompt_logprobs
+        if prompt_logprobs is not None:
+            seq_group.prompt_logprobs = prompt_logprobs
+
+        # Process samples
+        samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
         parent_child_dict = {
@@ -520,8 +526,8 @@ def _process_model_outputs(
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-        for seq_group, samples in zip(scheduled_seq_groups, output):
-            self._process_sequence_group_samples(seq_group, samples)
+        for seq_group, outputs in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_outputs(seq_group, outputs)

         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()
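
The sampler-side half of this commit (in files not shown on this page) batches the top-k logprob computation instead of looping over sequences in Python. A minimal sketch of the idea with made-up names, not the actual vllm.model_executor.layers.sampler code:

import torch

def batched_topk_logprobs(logits: torch.Tensor, k: int):
    """Compute top-k log-probabilities for a whole batch in one shot.

    logits: (num_query_tokens, vocab_size), stacked across all sequences in
    the step, so a single log_softmax + topk call replaces a per-sequence loop.
    """
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
    topk_logprobs, topk_token_ids = logprobs.topk(k, dim=-1)
    # One {token_id: logprob} dict per query token, matching the shape of the
    # prompt_logprobs / sample logprobs entries consumed above.
    return [
        dict(zip(ids.tolist(), vals.tolist()))
        for ids, vals in zip(topk_token_ids, topk_logprobs)
    ]

# Example: 4 query tokens over a toy vocabulary of 10.
print(batched_topk_logprobs(torch.randn(4, 10), k=2))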

vllm/model_executor/layers/attention.py

Lines changed: 1 addition & 1 deletion
@@ -420,7 +420,7 @@ def set_attn_bias(self, input_metadata: InputMetadata,
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len, dtype=dtype)
-            # Note(zhuohan): HF uses
+            # NOTE(zhuohan): HF uses
             # `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
             # the bias below more accurately follows the original ALiBi
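
The comment contrasts two ways of building the ALiBi position bias: the HF-style bias assigns each key its absolute position j, while the paper-style bias uses the relative offset j - i. They differ only by a constant per query row, which softmax cancels under a causal mask. A toy check of that claim (illustrative only, not the vLLM implementation):

import torch

prompt_len, slope = 5, 0.5          # toy length and ALiBi slope
pos = torch.arange(prompt_len, dtype=torch.float32)

hf_bias = pos[None, :].repeat(prompt_len, 1)   # bias[i, j] = j
rel_bias = pos[None, :] - pos[:, None]         # bias[i, j] = j - i

scores = torch.randn(prompt_len, prompt_len)
mask = torch.triu(torch.full_like(scores, float("-inf")), diagonal=1)

a = torch.softmax(scores + slope * hf_bias + mask, dim=-1)
b = torch.softmax(scores + slope * rel_bias + mask, dim=-1)
torch.testing.assert_close(a, b)   # row-wise constant offsets cancel in softmax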
