
Conversation


@gholmes829 gholmes829 commented Sep 18, 2025

Purpose

This PR adds an opt-in option for proportional KV cache memory distribution. I use pipeline parallelism with 2 GPUs and load layers asymmetrically between them. With --kv-cache-memory-bytes, vLLM correctly allocates only as much VRAM as is needed to load each device's model layers.

However, the KV cache memory is allocated uniformly across devices. This means that if I split the layers, say 48/16, the first rank ends up with ~1x max concurrency per request while the second rank ends up with ~3x. This is extremely wasteful, as the second rank is mostly bottlenecked by the prior stage, which has the lower max concurrency.

In short, uniform allocation causes wildly inefficient memory usage (see the back-of-the-envelope sketch after this list), characterized by:

  • GPUs with fewer layers waste memory
  • GPUs with more layers are memory-starved
  • Overall system throughput is bottlenecked by the most constrained GPU
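
To make the imbalance concrete, here is a back-of-the-envelope sketch in plain Python; the per-token KV footprint below is an illustrative assumption, not a measured value:

# Uniform KV cache budget with an asymmetric 48/16 layer split.
KV_BUDGET_BYTES = 5 * 1024**3      # uniform 5 GiB KV budget on each GPU
BYTES_PER_TOKEN_PER_LAYER = 2560   # hypothetical per-layer KV footprint of one token
MAX_MODEL_LEN = 40_960             # tokens per request

for rank, layers in enumerate((48, 16)):
    tokens = KV_BUDGET_BYTES // (layers * BYTES_PER_TOKEN_PER_LAYER)
    print(f"rank {rank}: {layers} layers -> {tokens:,} cacheable tokens, "
          f"~{tokens / MAX_MODEL_LEN:.2f}x max concurrency")
# rank 0: 48 layers -> 43,690 cacheable tokens, ~1.07x max concurrency
# rank 1: 16 layers -> 131,072 cacheable tokens, ~3.20x max concurrency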

Solution:

Add an opt-in flag --enable-pp-prop-kv-cache which, when enabled, changes the behavior of --kv-cache-memory-bytes so that the budget is distributed per device in proportion to the number of layers assigned to that device (per VLLM_PP_LAYER_PARTITION), instead of being allocated uniformly on every device.
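
A minimal sketch of the intended split, written as standalone Python rather than the actual vLLM code (the helper name is mine; the partition string follows the VLLM_PP_LAYER_PARTITION format, e.g. "56,8"):

def proportional_kv_cache_bytes(total_kv_bytes: int, partition: str, pp_rank: int) -> int:
    """Split a global --kv-cache-memory-bytes budget across PP ranks by layer count."""
    layers_per_rank = [int(n) for n in partition.split(",")]
    return total_kv_bytes * layers_per_rank[pp_rank] // sum(layers_per_rank)

five_gib = 5 * 1024**3
print(proportional_kv_cache_bytes(five_gib, "56,8", 0) / 1024**3)  # 4.375 GiB for rank 0
print(proportional_kv_cache_bytes(five_gib, "56,8", 1) / 1024**3)  # 0.625 GiB for rank 1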

Test Plan

In addition to unit tests, I manually tested a bunch of cases, comparing the state before and after my changes. I am only including the relevant arguments and referring to "vllm serve" in the following; I list more details of my environment below and am happy to share the full engine arguments if anyone would like.

(A) Baseline run, prior to any of my changes (manual testing with asymmetric PP):

VLLM_PP_LAYER_PARTITION=56,8 vllm serve Qwen/Qwen3-32B \
    --pipeline-parallel-size 2 \
    --kv-cache-memory-bytes 5G

(B) Same command with the new --enable-pp-prop-kv-cache flag enabled:

VLLM_PP_LAYER_PARTITION=56,8 vllm serve Qwen/Qwen3-32B \
    --pipeline-parallel-size 2 \
    --kv-cache-memory-bytes 5G \
    --enable-pp-prop-kv-cache

Unit tests run with:

pytest tests/distributed/test_pp_prop_kv_cache.py -v
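
For context, the kind of invariant those tests are after can be sketched as a standalone pytest file; this uses a hypothetical local helper mirroring the split logic above, not the actual vLLM internals or the real test file:

import pytest

def proportional_kv_cache_bytes(total_kv_bytes: int, partition: str, pp_rank: int) -> int:
    # Hypothetical helper mirroring the proportional-split sketch from the Solution section.
    layers = [int(n) for n in partition.split(",")]
    return total_kv_bytes * layers[pp_rank] // sum(layers)

@pytest.mark.parametrize("partition", ["56,8", "48,16", "32,32"])
def test_split_tracks_layer_counts(partition):
    total = 5 * 1024**3
    layers = [int(n) for n in partition.split(",")]
    per_rank = [proportional_kv_cache_bytes(total, partition, r) for r in range(len(layers))]
    # Flooring may drop a few bytes, but the split must never exceed the global budget.
    assert sum(per_rank) <= total
    # A rank with more layers never receives less KV cache memory.
    assert all(
        (la - lb) * (ba - bb) >= 0
        for la, ba in zip(layers, per_rank)
        for lb, bb in zip(layers, per_rank)
    )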

Environment details:

  • Docker on WSL2
  • Online serving with vLLM V1
  • Base image: vllm/vllm-openai:v0.10.2
  • 2x NVIDIA RTX 3090 (CUDA)

Test Result

(A) Baseline (pre-implementation):

vllm  | (Worker_PP1 pid=119) INFO 09-17 22:46:33 [gpu_worker.py:284] Initial free memory 22.72 GiB, reserved 5.00 GiB memory for KV Cache as specified by kv_cache_memory_bytes config and skipped memory profiling. This does does not respect the gpu_memory_utilization config. Only use kv_cache_memory_bytes config when you want manual control of KV cache memory size. If OOM'ed, check the difference of initial free memory between the current run and the previous run where kv_cache_memory_bytes is suggested and update it correspondingly.
vllm  | (Worker_PP0 pid=118) INFO 09-17 22:48:14 [gpu_worker.py:284] Initial free memory 22.72 GiB, reserved 5.00 GiB memory for KV Cache as specified by kv_cache_memory_bytes config and skipped memory profiling. This does does not respect the gpu_memory_utilization config. Only use kv_cache_memory_bytes config when you want manual control of KV cache memory size. If OOM'ed, check the difference of initial free memory between the current run and the previous run where kv_cache_memory_bytes is suggested and update it correspondingly.
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:48:14 [kv_cache_utils.py:864] GPU KV cache size: 46,800 tokens
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:48:14 [kv_cache_utils.py:868] Maximum concurrency for 40,960 tokens per request: 1.14x
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:48:14 [kv_cache_utils.py:864] GPU KV cache size: 327,680 tokens
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:48:14 [kv_cache_utils.py:868] Maximum concurrency for 40,960 tokens per request: 8.00x

(B) With --enable-pp-prop-kv-cache enabled:

vllm  | (Worker_PP1 pid=131) INFO 09-17 22:30:38 [gpu_worker.py:270] Initial free memory 22.72 GiB, reserved 0.62 GiB for KV cache (proportionally split from 5.00 GiB based on 8/64 layers on rank 1). Skipping memory profiling.
vllm  | (Worker_PP0 pid=130) INFO 09-17 22:32:11 [gpu_worker.py:270] Initial free memory 22.72 GiB, reserved 4.38 GiB for KV cache (proportionally split from 5.00 GiB based on 56/64 layers on rank 0). Skipping memory profiling.
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:32:11 [kv_cache_utils.py:864] GPU KV cache size: 40,960 tokens
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:32:11 [kv_cache_utils.py:868] Maximum concurrency for 40,960 tokens per request: 1.00x
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:32:11 [kv_cache_utils.py:864] GPU KV cache size: 40,960 tokens
vllm  | (EngineCore_DP0 pid=88) INFO 09-17 22:32:11 [kv_cache_utils.py:868] Maximum concurrency for 40,960 tokens per request: 1.00x

Note how Scenario A allocates:

  • GPU0 -- 46,800 tokens @ 1.14x concurrency = 5.00 GiB
  • GPU1 -- 327,680 tokens @ 8.00x concurrency = 5.00 GiB

And Scenario B allocates:

  • GPU0 -- 40,960 tokens @ 1.00x concurrency = 4.38 GiB
  • GPU1 -- 40,960 tokens @ 1.00x concurrency = 0.62 GiB

Note that the flag doesn't normalize concurrency to 1.00x or anything; the math just happened to line up exactly in this case :)
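
A quick sanity check of that arithmetic in plain Python, using the 56/8 partition over 64 layers from the logs above:

total_gib, total_layers = 5.0, 64
for rank, layers in ((0, 56), (1, 8)):
    print(f"rank {rank}: 5 GiB * {layers}/{total_layers} = {total_gib * layers / total_layers:.2f} GiB")
# rank 0: 5 GiB * 56/64 = 4.38 GiB
# rank 1: 5 GiB * 8/64 = 0.62 GiB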

In terms of numbers, this has cut my KV cache memory usage roughly in half for the same performance! Other setups may benefit more or less depending on the partitioning and concurrency multiplier.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable optimization for distributing KV cache memory proportionally in asymmetric pipeline parallelism setups. The implementation is well-structured and includes comprehensive unit tests. My review focuses on simplifying the core calculation logic for better efficiency and maintainability. By removing an unnecessary loop, the code becomes cleaner and more direct. I've provided suggestions to refactor this logic in both vllm/worker/worker.py and vllm/v1/worker/gpu_worker.py.


mergify bot commented Sep 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gholmes829.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 21, 2025