Commit 88ad308 (parent 4f9ebee)

Addressing coderabbit comments
2 files changed: +37 −20 lines

examples/notebooks/memory_estimator_example.ipynb

Lines changed: 2 additions & 2 deletions
@@ -73,13 +73,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": null,
    "id": "70462895",
    "metadata": {},
    "outputs": [],
    "source": [
     "num_gpus = 2\n",
-    "gpu_memory = 48 * (2**30) # 80 GB in bytes"
+    "gpu_memory = 48 * (2**30) # 48 GB in bytes"
    ]
   },
   {
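Note: the corrected comment now matches the arithmetic in the cell; 2**30 bytes is 1 GiB, so 48 * (2**30) is 48 GiB, not 80 GB. A minimal sketch of the same calculation outside the notebook (variable names mirror the cell):

num_gpus = 2
gpu_memory = 48 * (2**30)   # 48 GiB expressed in bytes
print(gpu_memory)           # 51539607552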

src/training_hub/profiling/memory_estimator.py

Lines changed: 35 additions & 18 deletions
@@ -1,8 +1,5 @@
-from click import FLOAT
-import torch
-from typing import Callable, Optional, override
+from typing import override
 from transformers import AutoModel
-from transformers.models.perceiver.modeling_perceiver import PerceiverMultimodalPreprocessor
 
 """
 Code assisted by Cursor/Claude4
@@ -14,8 +11,8 @@
 FLOAT8_BYTES_N: int = 1
 ADAMW_PARAMS_N: int = 2
 
-# Helper lambda to do the rounding when printing
-ROUNDER = lambda value : str(round(value / 1073741824, 1))
+# Helper function to do the rounding when printing
+def ROUNDER(value: int) -> str: return str(round(value / 1073741824, 1))
 
 class BasicEstimator:
     """
@@ -43,11 +40,12 @@ def __init__(
         num_gpus: int = 8,
         gpu_memory: int = 85899345920,
         model_path: str = "ibm-granite/granite-3.3-8b-instruct",
-        effective_batch_size: int = None,
-        max_seq_len: int = None,
-        max_tokens_per_gpu: int = None,
+        effective_batch_size: int | None = None,
+        max_seq_len: int | None = None,
+        max_tokens_per_gpu: int | None = None,
         use_liger: bool = False,
         verbose: int = 1,
+        trust_remote_code: bool = False,
     ):
         self.num_gpus = num_gpus
         self.gpu_memory = gpu_memory
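As an aside, the `int | None` annotations adopted here use PEP 604 union syntax, which requires Python 3.10+ at runtime unless postponed evaluation of annotations is enabled. A minimal, self-contained sketch of the pattern (the configure function is hypothetical, not part of this module):

from __future__ import annotations  # lets `int | None` annotations parse on older interpreters

def configure(max_seq_len: int | None = None) -> int:
    # Fall back to a default when the caller does not supply a value
    return max_seq_len if max_seq_len is not None else 2048

print(configure())      # 2048
print(configure(4096))  # 4096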
@@ -56,7 +54,7 @@ def __init__(
         self.verbose = verbose
 
         # Load model directly
-        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
         # Determine parameters needed for calculations
         self.num_params: int = self.model.num_parameters(only_trainable=False)
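With this change, Transformers only executes a repository's custom model code when the caller opts in. A hedged usage sketch based on the constructor signature shown in this diff (argument values are illustrative, and the import path assumes the package installs from src/ as training_hub):

from training_hub.profiling.memory_estimator import BasicEstimator

estimator = BasicEstimator(
    num_gpus=2,
    gpu_memory=48 * (2**30),
    model_path="ibm-granite/granite-3.3-8b-instruct",
    trust_remote_code=False,  # default; set True only for checkpoints whose custom code you trust
)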
@@ -278,9 +276,9 @@ def __init__(
         num_gpus: int = 8,
         gpu_memory: int = 85899345920,
         model_path: str = "ibm-granite/granite-3.3-8b-instruct",
-        effective_batch_size: int = None,
-        max_seq_len: int = None,
-        max_tokens_per_gpu: int = None,
+        effective_batch_size: int | None = None,
+        max_seq_len: int | None = None,
+        max_tokens_per_gpu: int | None = None,
         use_liger: bool = False,
         verbose: int = 1,
     ):
@@ -348,11 +346,12 @@ def estimate(
     num_gpus: int = 8,
     gpu_memory: int = 85899345920,
     model_path: str = "ibm-granite/granite-3.3-8b-instruct",
-    effective_batch_size: int = None,
-    max_seq_len: int = None,
-    max_tokens_per_gpu: int = None,
+    effective_batch_size: int | None = None,
+    max_seq_len: int | None = None,
+    max_tokens_per_gpu: int | None = None,
     use_liger: bool = False,
     verbose: int = 1,
+    trust_remote_code: bool = False
 ):
     """
     Convenience function for performing estimation
@@ -383,7 +382,25 @@ def estimate(
     """
 
     if training_method.lower() == "osft":
-        estimator = OSFTEstimator(num_gpus, gpu_memory, model_path, effective_batch_size, max_seq_len, max_tokens_per_gpu, use_liger, verbose)
+        estimator = OSFTEstimator(num_gpus,
+                                  gpu_memory,
+                                  model_path,
+                                  effective_batch_size,
+                                  max_seq_len,
+                                  max_tokens_per_gpu,
+                                  use_liger,
+                                  verbose,
+                                  trust_remote_code,
+                                  )
     else:
-        estimator = BasicEstimator(num_gpus, gpu_memory, model_path, effective_batch_size, max_seq_len, max_tokens_per_gpu, use_liger, verbose)
+        estimator = BasicEstimator(num_gpus,
+                                   gpu_memory,
+                                   model_path,
+                                   effective_batch_size,
+                                   max_seq_len,
+                                   max_tokens_per_gpu,
+                                   use_liger,
+                                   verbose,
+                                   trust_remote_code
+                                   )
     return estimator.estimate()
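For completeness, a sketch of calling the reformatted convenience function after this change (argument values are illustrative; training_method appears earlier in the signature and is not shown in this hunk, and the import path assumes the src/ layout):

from training_hub.profiling.memory_estimator import estimate

estimate(
    training_method="osft",              # anything other than "osft" falls back to BasicEstimator
    num_gpus=2,
    gpu_memory=48 * (2**30),
    model_path="ibm-granite/granite-3.3-8b-instruct",
    trust_remote_code=False,             # now forwarded to the chosen estimator
)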
