Skip to content

Commit 5986ed2

Browse files
authored
[distributed] add stage metrics - total params per stage, total size and present it in a nicely formatted manner (#1120)
* add stage metrics - total params per stage, total size * PR feedback * PR feedback, typing
1 parent 8b6aa07 commit 5986ed2

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

dist_run.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,15 @@
2525
load_safetensor_weights,
2626
)
2727

28-
from distributed.utils import Color as color, TrackTime, CUDATrackTime, GPUMemoryMonitor
28+
from distributed.utils import (
29+
Color as color,
30+
GPUMemoryMonitor,
31+
get_module_size,
32+
get_num_params,
33+
bytes_to_readable,
34+
TrackTime,
35+
CUDATrackTime,
36+
)
2937

3038
from distributed.verification_utils import find_cpu_tensors
3139
from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
@@ -193,11 +201,17 @@ def main():
193201
logger.info(f"Loading weights for {pp_rank=} on {device=}")
194202
with TrackTime("cuda") as timer:
195203
_load_model_weights(model, hf_model_name, device=device, model_config=config)
196-
197204
logger.info(
198205
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
199206
)
200-
207+
208+
# info on stage size and params
209+
stage_size = get_module_size(model)
210+
stage_size_formatted = bytes_to_readable(stage_size)
211+
stage_num_params = get_num_params(model)
212+
logger.info(
213+
f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}\n"
214+
)
201215

202216
# Setup input position
203217
# input_pos for prefill: a list of increasing integers from 0 to seqlen

distributed/utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import itertools
78
import os
89
from dataclasses import dataclass
910
from datetime import timedelta
@@ -54,7 +55,47 @@ def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> i
5455
num_params = sum(p.numel() for p in model.parameters())
5556
if exclude_embedding:
5657
num_params -= model.tok_embeddings.weight.numel()
57-
return num_params
58+
readable_num_params = format_model_params(num_params)
59+
return readable_num_params
60+
61+
62+
def get_module_size(stage: torch.nn.Module) -> int:
63+
"""get module (stage) size in bytes"""
64+
model_size = sum(
65+
[
66+
p.numel() * p.dtype.itemsize
67+
for p in itertools.chain(stage.parameters(), stage.buffers())
68+
]
69+
)
70+
return model_size
71+
72+
73+
def format_model_params(params):
    """turn the num_params into a readable formatted number"""
    # Walk the scale thresholds from largest to smallest; the first one
    # the count reaches determines the suffix (B = billions, M = millions).
    for threshold, suffix in ((1_000_000_000, "B"), (1_000_000, "M")):
        if params >= threshold:
            return f"{params / threshold:.2f}{suffix}"
    # Small counts: plain thousands-separated integer.
    return f"{params:,}"
81+
82+
83+
def bytes_to_readable(bytes_value: int, round_to: int = 2) -> str:
    """formatting function to make reading model (stage) sizes easy"""
    GiB = 1024**3  # 1 GiB in bytes
    MiB = 1024**2  # 1 MiB in bytes

    # Anything below one GiB is reported in MiB; no smaller units are
    # needed for model-stage sizes.
    unit, divisor = ("GiB", GiB) if bytes_value >= GiB else ("MiB", MiB)

    # round() drops trailing zeros, so e.g. 1.5 renders as "1.5 GiB";
    # precision is controlled by round_to (default 2 decimal places).
    return f"{round(bytes_value / divisor, round_to)} {unit}"
5899

59100

60101
@dataclass(frozen=True)

0 commit comments

Comments
 (0)