
[V1] Move usage stats to worker and start logging TPU hardware #16211

Merged Apr 25, 2025 (58 commits)

Changes from 9 commits

Commits
f9d82ea
Track TPU usages in vLLM's data dashboards
dyli-google Mar 27, 2025
731b68a
Merge branch 'vllm-project:main' into main
dyli-google Mar 27, 2025
d2d9b9e
Make the code more robust
dyli-google Mar 27, 2025
f168647
Merge branch 'main' of https://github.com/dyli-google/vllm
dyli-google Mar 27, 2025
ee00cf7
Merge branch 'vllm-project:main' into main
dyli-google Apr 7, 2025
39d610f
Your descriptive message about the changes you made
dyli-google Apr 7, 2025
558c60f
format
dyli-google Apr 7, 2025
639f77b
use new API
dyli-google Apr 7, 2025
d5e7533
Merge branch 'vllm-project:main' into main
dyli-google Apr 7, 2025
d9b9d61
Merge branch 'vllm-project:main' into main
dyli-google Apr 7, 2025
8f055c9
address Simon's comments
dyli-google Apr 7, 2025
63bea36
Silence ImportError
dyli-google Apr 7, 2025
25fa30b
Merge branch 'vllm-project:main' into main
dyli-google Apr 8, 2025
8124c99
Merge branch 'vllm-project:main' into main
dyli-google Apr 9, 2025
6a4eea4
Use torch_xla.tpu.get_tpu_type() to get TPU version
dyli-google Apr 9, 2025
ae2f5a6
Merge branch 'vllm-project:main' into main
dyli-google Apr 10, 2025
5d2f2b6
Merge branch 'vllm-project:main' into main
dyli-google Apr 11, 2025
9b3a67c
Merge branch 'vllm-project:main' into main
dyli-google Apr 14, 2025
35fb26b
Merge branch 'vllm-project:main' into main
dyli-google Apr 14, 2025
b0912f0
Merge branch 'vllm-project:main' into main
dyli-google Apr 20, 2025
88dd6c6
Merge branch 'vllm-project:main' into main
dyli-google Apr 22, 2025
727bed5
Add usage to more engines
dyli-google Apr 22, 2025
4f94631
Merge branch 'vllm-project:main' into main
dyli-google Apr 22, 2025
619e496
fix error
dyli-google Apr 22, 2025
a1ae7ff
format
dyli-google Apr 23, 2025
1667fab
Merge branch 'vllm-project:main' into main
dyli-google Apr 23, 2025
9f725f6
Revert "format"
dyli-google Apr 23, 2025
b17dbc9
format
dyli-google Apr 23, 2025
5286466
Merge branch 'vllm-project:main' into main
dyli-google Apr 23, 2025
3bd0c9b
Use import torch_xla
dyli-google Apr 23, 2025
625d21c
Merge branch 'main' of https://github.com/dyli-google/vllm
dyli-google Apr 23, 2025
718729a
format
dyli-google Apr 23, 2025
6e61fba
format
dyli-google Apr 23, 2025
737646d
format
dyli-google Apr 23, 2025
0e093cc
Merge branch 'vllm-project:main' into main
dyli-google Apr 23, 2025
9940dad
Merge branch 'vllm-project:main' into main
dyli-google Apr 23, 2025
f825349
Try Qiliang's idea
dyli-google Apr 23, 2025
7798bde
Merge branch 'vllm-project:main' into main
dyli-google Apr 23, 2025
bbd7f5a
Use Yarong's 2nd idea
dyli-google Apr 24, 2025
5bf9f34
Merge branch 'main' into main
dyli-google Apr 24, 2025
4e38e67
revert vllm/engine/async_llm_engine.py
dyli-google Apr 24, 2025
fc18a7a
simplify code
dyli-google Apr 24, 2025
cf7997a
simplify
dyli-google Apr 24, 2025
3bd5730
fix typo
dyli-google Apr 24, 2025
4374c3c
format
dyli-google Apr 24, 2025
6829371
simplify
dyli-google Apr 24, 2025
3c55fc7
silence error
dyli-google Apr 24, 2025
bbee546
Suppress all exceptions
dyli-google Apr 24, 2025
429b6aa
format
dyli-google Apr 24, 2025
8939235
remove comment
dyli-google Apr 24, 2025
bc284db
Merge branch 'vllm-project:main' into main
dyli-google Apr 24, 2025
bac067a
report usage of TPU and GPU during worker init time
dyli-google Apr 24, 2025
3ad33a2
remove useless import
dyli-google Apr 24, 2025
5b0ab6d
format
dyli-google Apr 24, 2025
1f592e4
Merge branch 'vllm-project:main' into main
dyli-google Apr 24, 2025
98e7ae0
Merge branch 'vllm-project:main' into main
dyli-google Apr 24, 2025
689d343
Merge branch 'vllm-project:main' into main
dyli-google Apr 25, 2025
4eea0a9
Merge branch 'vllm-project:main' into main
dyli-google Apr 25, 2025

16 changes: 16 additions & 0 deletions vllm/usage/usage_lib.py
@@ -137,6 +137,9 @@ def __init__(self) -> None:
         self.gpu_type: Optional[str] = None
         self.gpu_memory_per_device: Optional[int] = None
         self.env_var_json: Optional[str] = None
+        self.tpu_count: Optional[int] = None
+        self.tpu_type: Optional[str] = None
+        self.tpu_memory_per_device: Optional[int] = None

         # vLLM Information
         self.model_architecture: Optional[str] = None
@@ -174,6 +177,19 @@ def _report_usage_once(self, model_architecture: str,
             self.gpu_memory_per_device = device_property.total_memory
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
+        if current_platform.is_tpu():
+            try:
+                import torch_xla.runtime as xr
+                from torch_xla.core import xla_model as xm
+                self.tpu_count = xr.world_size()
+                self.tpu_type = xm.xla_device_hw(xm.xla_device())
+                self.tpu_memory_per_device = xm.get_memory_info().bytes_limit
+            except ImportError:
+                logging.warning(
+                    "torch_xla not found, skipping TPU usage statistics.")
+                self.tpu_count = None
+                self.tpu_type = None
+                self.tpu_memory_per_device = None

Collaborator:
Please just set gpu_count/gpu_type/gpu_memory_per_device. We can perform the disambiguation in backend processing. We can also silence the import error.

Please paste the output from ~/.config/vllm/usage_stats.json for verification.
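
For illustration, a minimal sketch of what this suggestion could look like inside _report_usage_once (illustrative only, not taken from the merged change). The torch_xla calls are the ones already used in the diff above, the broad except follows the "silence the error" suggestion, and the exact return shape of xm.get_memory_info() may vary across torch_xla versions:

```python
# Hypothetical sketch: reuse the existing gpu_* fields for TPU values and
# swallow any error so usage reporting can never break engine startup.
if current_platform.is_tpu():
    try:
        import torch_xla.runtime as xr
        from torch_xla.core import xla_model as xm

        # The backend can later disambiguate TPU vs GPU by inspecting gpu_type.
        self.gpu_count = xr.world_size()
        self.gpu_type = xm.xla_device_hw(xm.xla_device())
        self.gpu_memory_per_device = xm.get_memory_info().bytes_limit
    except Exception:
        # Best-effort: a missing torch_xla or any runtime failure is ignored.
        pass
```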

Contributor Author:
Done.

Where is ~/.config/vllm/usage_stats.json?

I cannot find it inside the Docker container:
root@t1v-n-a747908a-w-0:/workspace/vllm# cat ~/.config/vllm/usage_stats.json
cat: /root/.config/vllm/usage_stats.json: No such file or directory
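
For reference, a small standalone check of the default file location (a sketch; it assumes the default config root, which moves if VLLM_CONFIG_ROOT or XDG_CONFIG_HOME is set, and the file is only written when usage-stats reporting is enabled, i.e. not opted out via VLLM_NO_USAGE_STATS=1 or DO_NOT_TRACK=1):

```python
import json
import os

# Default path used by vLLM's usage reporting; adjust if VLLM_CONFIG_ROOT or
# XDG_CONFIG_HOME points somewhere else on your machine.
path = os.path.expanduser("~/.config/vllm/usage_stats.json")

if os.path.isfile(path):
    with open(path) as f:
        print(json.dumps(json.load(f), indent=2))
else:
    # The file only appears after an engine/server has actually started with
    # usage reporting enabled, so a missing file is not necessarily an error.
    print(f"{path} not found")
```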

Contributor Author:
Sorry, even without Docker, building from source, I still cannot find ~/.config/vllm/usage_stats.json. Not sure how it gets written.

Contributor Author:
Ah, after I started the server using python -m vllm.entrypoints.api_server, I got the JSON file (previously I used the llm serve command).

It seems my code is not working: "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null

(myenv) dyli_google_com@t1v-n-b4c4da81-w-0:~/.config/vllm$ cat usage_stats.json
{"uuid": "d895fe5b-ff4c-42ca-a65e-deabc113a731", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": false, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.8.3", "context": "API_SERVER", "log_time": 1744157778107180032, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.98, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": false, "enforce_eager": false, "disable_custom_all_reduce": true}

Collaborator:
xm.xla_device_hw(xm.xla_device()) is not null in my TPU VM.

Contributor Author:
Thanks.

  1. Did you start the server using python -m vllm.entrypoints.api_server or llm serve?
  2. Did you use Docker, or just build from source?

Collaborator:
I didn't test it in vLLM or Docker. I just used torch_xla directly in a plain Python environment.
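
For anyone reproducing that check, a quick standalone probe on a TPU VM could look like this (a sketch assuming torch_xla is installed; the calls mirror the ones in the diff under review):

```python
# Standalone TPU probe, independent of vLLM.
import torch_xla.runtime as xr
from torch_xla.core import xla_model as xm

device = xm.xla_device()                        # acquires the XLA (TPU) device
print("world size:", xr.world_size())           # world size reported by the XLA runtime
print("device hw :", xm.xla_device_hw(device))  # e.g. 'TPU' on a TPU VM
```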

Contributor Author:
It seems like my code doesn't work:

(vllm) dyli_google_com@t1v-n-b4c4da81-w-0:~/vllm$ cat ~/.config/vllm/usage_stats.json
{"uuid": "484b129d-3e9f-466d-86f0-e6f8088af5c1", "provider": "GCP", "num_cpu": 180, "cpu_type": "AMD EPYC 9B14", "cpu_family_model_stepping": "25,17,1", "total_memory": 1521841610752, "architecture": "x86_64", "platform": "Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35", "cuda_runtime": null, "gpu_count": null, "gpu_type": null, "gpu_memory_per_device": null, "env_var_json": "{"VLLM_USE_MODELSCOPE": false, "VLLM_USE_TRITON_FLASH_ATTN": true, "VLLM_ATTENTION_BACKEND": null, "VLLM_USE_FLASHINFER_SAMPLER": null, "VLLM_PP_LAYER_PARTITION": null, "VLLM_USE_TRITON_AWQ": false, "VLLM_USE_V1": false, "VLLM_ENABLE_V1_MULTIPROCESSING": true}", "model_architecture": "LlamaForCausalLM", "vllm_version": "0.6.6.dev1916+g6a4eea4ff", "context": "OPENAI_API_SERVER", "log_time": 1744244819493581056, "source": "production", "dtype": "torch.bfloat16", "tensor_parallel_size": 1, "block_size": 16, "gpu_memory_utilization": 0.95, "quantization": null, "kv_cache_dtype": "auto", "enable_lora": false, "enable_prompt_adapter": false, "enable_prefix_caching": null, "enforce_eager": false, "disable_custom_all_reduce": true}

         self.provider = _detect_cloud_provider()
         self.architecture = platform.machine()
         self.platform = platform.platform()