|
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | import torch |
| 8 | +from vllm_test_utils import monitor |
8 | 9 |
|
9 | 10 | from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, |
10 | 11 | get_open_port, memory_profiling, merge_async_iterators, |
@@ -289,16 +290,32 @@ def test_memory_profiling(): |
289 | 290 |
|
290 | 291 | weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB |
291 | 292 |
|
| 293 | + def measure_current_non_torch(): |
| 294 | + free, total = torch.cuda.mem_get_info() |
| 295 | + current_used = total - free |
| 296 | + current_torch = torch.cuda.memory_reserved() |
| 297 | + current_non_torch = current_used - current_torch |
| 298 | + return current_non_torch |
| 299 | + |
292 | 300 | with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes, |
293 | | - weights_memory_in_bytes=weights_memory_in_bytes) as result: |
| 301 | + weights_memory_in_bytes=weights_memory_in_bytes) as result, \ |
| 302 | + monitor(measure_current_non_torch) as monitored_values: |
294 | 303 | # make a memory spike, 1 GiB |
295 | 304 | spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) |
296 | 305 | del spike |
297 | 306 |
|
298 | 307 | # Add some extra non-torch memory 256 MiB (simulate NCCL) |
299 | 308 | handle2 = lib.cudaMalloc(256 * 1024 * 1024) |
300 | 309 |
|
| 310 | +    # this is an analytic value; it is exact:
| 311 | +    # we only added a 256 MiB non-torch memory allocation
| 312 | + measured_diff = monitored_values.values[-1] - monitored_values.values[0] |
| 313 | + assert measured_diff == 256 * 1024 * 1024 |
| 314 | + |
301 | 315 | # Check that the memory usage is within 5% of the expected values |
| 316 | + # 5% tolerance is caused by PyTorch caching allocator, |
| 317 | + # we cannot control PyTorch's behavior of its internal buffers, |
| 318 | + # which causes a small error (<10 MiB in practice) |
302 | 319 | non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa |
303 | 320 | torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa |
304 | 321 | assert abs(non_torch_ratio - 1) <= 0.05 |
|
0 commit comments