|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import itertools |
7 | 8 | import os
|
8 | 9 | from dataclasses import dataclass
|
9 | 10 | from datetime import timedelta
|
@@ -54,7 +55,47 @@ def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> i
|
54 | 55 | num_params = sum(p.numel() for p in model.parameters())
|
55 | 56 | if exclude_embedding:
|
56 | 57 | num_params -= model.tok_embeddings.weight.numel()
|
57 |
| - return num_params |
| 58 | + readable_num_params = format_model_params(num_params) |
| 59 | + return readable_num_params |
| 60 | + |
| 61 | + |
def get_module_size(stage: torch.nn.Module) -> int:
    """Return the size of a module (pipeline stage) in bytes.

    The size is the sum, over every parameter and every buffer of
    ``stage``, of ``element count * bytes per element``.

    Args:
        stage: The module whose parameters and buffers are measured.

    Returns:
        Total number of bytes occupied by the module's parameters and buffers.
    """
    # element_size() is equivalent to dtype.itemsize but is also available on
    # torch releases that predate the dtype.itemsize attribute.
    return sum(
        t.numel() * t.element_size()
        for t in itertools.chain(stage.parameters(), stage.buffers())
    )
| 71 | + |
| 72 | + |
def format_model_params(params: int) -> str:
    """Turn a parameter count into a short human-readable string.

    Args:
        params: Number of parameters.

    Returns:
        ``"<x.xx>B"`` for counts of at least one billion, ``"<x.xx>M"`` for
        counts of at least one million, otherwise the integer with thousands
        separators (e.g. ``"999,999"``).
    """
    if params >= 1_000_000_000:
        return f"{params / 1_000_000_000:.2f}B"
    if params >= 1_000_000:
        millions = f"{params / 1_000_000:.2f}"
        # Values just below one billion round up to "1000.00"; report them in
        # the next unit instead of emitting "1000.00M".
        if millions == "1000.00":
            return "1.00B"
        return f"{millions}M"
    return f"{params:,}"
| 81 | + |
| 82 | + |
def bytes_to_readable(bytes_value: int, round_to: int = 2) -> str:
    """Format a byte count as a readable ``GiB`` / ``MiB`` string.

    Args:
        bytes_value: Size in bytes.
        round_to: Number of decimal places passed to ``round``.

    Returns:
        ``"<value> GiB"`` when the size is at least 1 GiB (or rounds up to
        it), otherwise ``"<value> MiB"``.
    """
    gib = 1024**3  # 1 GiB in bytes
    mib = 1024**2  # 1 MiB in bytes

    if bytes_value >= gib:
        value, unit = bytes_value / gib, "GiB"
    else:
        value, unit = bytes_value / mib, "MiB"
        # A size just under 1 GiB can round up to 1024 MiB at the requested
        # precision; report it in GiB rather than printing "1024.0 MiB".
        if round(value, round_to) >= 1024:
            value, unit = bytes_value / gib, "GiB"

    return f"{round(value, round_to)} {unit}"
58 | 99 |
|
59 | 100 |
|
60 | 101 | @dataclass(frozen=True)
|
|
0 commit comments