
Implement prompt logprobs & Batched topk for computing logprobs #1328


Merged: 19 commits merged into main from prompt-logprobs on Oct 16, 2023

Conversation

@zhuohan123 (Member) commented Oct 11, 2023

This PR:

  • Added prompt_logprobs to SamplingParams and RequestOutput. This lets vLLM return the log probabilities of prompt tokens, which is required to support echo in the OpenAI server (a usage sketch follows the TODO list below).
  • Refactored the logprobs logic so that the top-k logits query is performed in a batched fashion.

This PR will have merge conflicts with #1337. I think a good plan is to perform the optimization in #1337 along with the refactoring of InputMetadata after this PR is merged.

TODOs:

  • Test the correctness of this PR.
  • Test the performance with the new logprobs implementation.
  • Maybe in a future PR: refactor the bloated references to InputMetadata.
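
As a usage illustration (a minimal sketch, not code from this PR's diff; the model name, prompt, and parameter values are placeholders), the new field is requested through SamplingParams and read back from RequestOutput:

```python
from vllm import LLM, SamplingParams

# Placeholder model; any model supported by vLLM should behave the same way.
llm = LLM(model="huggyllama/llama-7b")

# prompt_logprobs=1: for every prompt position, return the top-1 logprob plus
# the logprob of the actual prompt token; logprobs=1 does the same for
# generated tokens.
params = SamplingParams(max_tokens=16, logprobs=1, prompt_logprobs=1)

outputs = llm.generate(["Hello, my name is"], params)
for out in outputs:
    print(out.prompt_logprobs)       # list over prompt positions; first entry is None
    print(out.outputs[0].logprobs)   # list over generated tokens
```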

@zhuohan123 zhuohan123 changed the title from "[WIP] Implement prompt logprobs" to "Implement prompt logprobs & Batched topk for computing logprobs" on Oct 12, 2023
@WoosukKwon WoosukKwon mentioned this pull request Oct 13, 2023
@zhuohan123 zhuohan123 requested a review from WoosukKwon October 13, 2023 18:49
@zhuohan123 (Member, Author)

@WoosukKwon @Yard1 This PR is ready for review.

@zhuohan123 (Member, Author)

Profiling results:

# main without logprobs:
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False, dtype='auto')
INFO 10-13 20:28:36 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:29:24 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
INFO 10-13 20:29:24 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:29:29 llm_engine.py:207] # GPU blocks: 7455, # CPU blocks: 512
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [03:51<00:00,  8.65it/s]
Throughput: 8.65 requests/s, 4186.71 tokens/s

# this branch without logprobs:
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False, dtype='auto')
INFO 10-13 18:56:14 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 18:56:57 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
INFO 10-13 18:56:57 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 18:57:02 llm_engine.py:207] # GPU blocks: 7455, # CPU blocks: 512
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [03:51<00:00,  8.65it/s]
Throughput: 8.65 requests/s, 4185.26 tokens/s

# main with logprobs 5
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False, dtype='auto')
INFO 10-13 20:42:18 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:43:02 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
INFO 10-13 20:43:02 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:43:07 llm_engine.py:207] # GPU blocks: 7455, # CPU blocks: 512
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [05:11<00:00,  6.42it/s]
Throughput: 6.41 requests/s, 3102.15 tokens/s

# this branch with logprobs 5
(vllm) zhuohan@zhuohan-1:~/vllm/vllm/benchmarks$ python benchmark_throughput.py --backend vllm --model huggyllama/llama-7b --dataset ../../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
Namespace(backend='vllm', dataset='../../data/ShareGPT_V3_unfiltered_cleaned_split.json', model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=2000, seed=0, hf_max_batch_size=None, trust_remote_code=False, dtype='auto')
INFO 10-13 20:48:56 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:49:44 llm_engine.py:72] Initializing an LLM engine with config: model='huggyllama/llama-7b', tokenizer='huggyllama/llama-7b', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
INFO 10-13 20:49:44 tokenizer.py:31] For some LLaMA V1 models, initializing the fast tokenizer may take a long time. To reduce the initialization time, consider using 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
INFO 10-13 20:49:49 llm_engine.py:207] # GPU blocks: 7455, # CPU blocks: 512
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [04:02<00:00,  8.24it/s]
Throughput: 8.23 requests/s, 3980.15 tokens/s

@zhuohan123 zhuohan123 requested a review from Yard1 October 13, 2023 20:54
@Yard1 (Collaborator) left a comment

This looks good to me! I will update my PR once this is merged. We should definitely consider a broader refactor here to precompute as many things as possible in _prepare_inputs, to avoid multiple loops and unnecessary CPU-GPU syncing.
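
For context on the batching point (a generic PyTorch illustration, not the actual sampler or _prepare_inputs code from this PR; the shapes and tensor names are made up):

```python
import torch

logits = torch.randn(8, 32000)                # [num_tokens, vocab_size], made-up shapes
logprobs = torch.log_softmax(logits, dim=-1)
k = 5

# Per-row top-k: one small kernel launch per token, and any .item()/.tolist()
# inside the loop forces a CPU-GPU synchronization on every iteration.
per_row = [torch.topk(logprobs[i], k) for i in range(logprobs.shape[0])]

# Batched top-k: a single kernel over all rows, moved to the CPU once.
topk_values, topk_ids = torch.topk(logprobs, k, dim=-1)
topk_values, topk_ids = topk_values.cpu(), topk_ids.cpu()
```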

@WoosukKwon (Collaborator) left a comment

@zhuohan123 Awesome! Thanks for the hard work. Please check my comments.

Comment on lines 571 to 572:

prompt_logprobs: List[Optional[List[Optional[Dict[int, int]]]]],
sample_logprobs: List[List[Optional[Dict[int, int]]]],

Collaborator

Suggested change:

- prompt_logprobs: List[Optional[List[Optional[Dict[int, int]]]]],
- sample_logprobs: List[List[Optional[Dict[int, int]]]],
+ prompt_logprobs: List[Optional[List[Optional[Dict[int, float]]]]],
+ sample_logprobs: List[List[Optional[Dict[int, float]]]],

Collaborator

A dumb question: Why do we need Optional here? In which case is it used?

@zhuohan123 (Member, Author) commented Oct 16, 2023

For sample_logprobs, there should be no Optional; I have fixed the code. For prompt_logprobs, there are two cases (see the sketch after this list):

  1. If a request does not query prompt logprobs, the prompt_logprobs for that request will be None.
  2. The first token of the prompt will never have a log probability, so its entry will always be None. This is the same behavior as the OpenAI endpoint.
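
A minimal sketch of consuming this structure, e.g. for prompt perplexity, assuming each non-None entry is a Dict[int, float] mapping token id to logprob (as in the signatures above); `llm` and the prompt are placeholders:

```python
import math

from vllm import SamplingParams

# `llm` is assumed to be an already-constructed vllm.LLM instance.
params = SamplingParams(max_tokens=1, prompt_logprobs=1)
out = llm.generate(["The quick brown fox jumps over the lazy dog"], params)[0]

token_ids = out.prompt_token_ids
# Entry 0 is always None (case 2); each later entry contains the logprob of
# the actual prompt token at that position, so we can sum them directly.
log_likelihood = sum(
    out.prompt_logprobs[i][tok] for i, tok in enumerate(token_ids) if i > 0
)
perplexity = math.exp(-log_likelihood / (len(token_ids) - 1))
```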

Collaborator

Got it. Thanks for the explanation!

@WoosukKwon (Collaborator)

BTW, I got 2 logprobs per token when running examples/llm_engine_example.py where prompt_logprobs=1. Is this expected? It's a bit confusing because logprobs=1 returns 1 log prob per token.

..., prompt_logprobs=[None, {250: -3.594587802886963, 100: -1.414900302886963}, {9916: -8.579404830932617, 319: -3.582822799682617},  ...

@wanmok (Contributor) commented Oct 16, 2023

> BTW, I got 2 logprobs per token when running examples/llm_engine_example.py where prompt_logprobs=1. Is this expected? It's a bit confusing because logprobs=1 returns 1 log prob per token.
>
> ..., prompt_logprobs=[None, {250: -3.594587802886963, 100: -1.414900302886963}, {9916: -8.579404830932617, 319: -3.582822799682617},  ...

A similar question here: does the design support logprobs=0? In that case, we would like to know only the log probs of the selected prompt tokens rather than the top-k. This is required to implement the OpenAI API's echo.

@zhuohan123 (Member, Author)

The logprobs behavior of vLLM follows the OpenAI API's specification:

Include the log probabilities on the logprobs most likely tokens, as well the chosen tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response.
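
To make the quoted rule concrete (a hedged reading of the example output a few comments above, not new output): with prompt_logprobs=1, a position gets two entries exactly when the actual prompt token is not the model's top-1 prediction.

```python
# Second non-None entry from the example output above.
entry = {9916: -8.579404830932617, 319: -3.582822799682617}

top1_token = max(entry, key=entry.get)   # 319, presumably the model's top-1 token
# 9916 is then the actual prompt token, whose logprob is always included
# regardless of k, giving up to logprobs + 1 entries per position.
```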

@zhuohan123 zhuohan123 requested a review from WoosukKwon October 16, 2023 07:42
@zhuohan123 (Member, Author)

@WoosukKwon This PR is ready for review again.

@RanchiZhao

Is this available now? I am eager to use the perplexity (ppl) method to run evals!

@WoosukKwon (Collaborator) left a comment

LGTM! Thanks for the hard work!

@zhuohan123 zhuohan123 merged commit 9d9072a into main Oct 16, 2023
@zhuohan123 zhuohan123 deleted the prompt-logprobs branch October 16, 2023 21:02
@WoosukKwon WoosukKwon mentioned this pull request Oct 17, 2023
@wheel-is commented Dec 1, 2023

Has this been implemented? It doesn't seem to be returning prompt logits when I specify it.

@ToSev7en

Why do I come across some logprob=-inf values?

yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request May 29, 2025
When contiguous_pa is enabled, the decode graph is not warmed up for the max block_id. See the example below: when the total number of HPU blocks is 1974, the decode graph should be warmed up for (bs, 1974).

> INFO 05-28 03:29:33 executor_base.py:110] # HPU blocks: 1974, # CPU blocks: 954

Need to work with HabanaAI/vllm-hpu-extension#201.

In habana_main, this code has been updated.