1- from click import FLOAT
2- import torch
3- from typing import Callable , Optional , override
1+ from typing import override
42from transformers import AutoModel
5- from transformers .models .perceiver .modeling_perceiver import PerceiverMultimodalPreprocessor
63
74"""
85Code assisted by Cursor/Claude4
1411FLOAT8_BYTES_N : int = 1
1512ADAMW_PARAMS_N : int = 2
1613
17- # Helper lambda to do the rounding when printing
18- ROUNDER = lambda value : str (round (value / 1073741824 , 1 ))
14+ # Helper function to do the rounding when printing
15+ def ROUNDER ( value : int ) -> str : return str (round (value / 1073741824 , 1 ))
1916
2017class BasicEstimator :
2118 """
@@ -43,11 +40,12 @@ def __init__(
4340 num_gpus : int = 8 ,
4441 gpu_memory : int = 85899345920 ,
4542 model_path : str = "ibm-granite/granite-3.3-8b-instruct" ,
46- effective_batch_size : int = None ,
47- max_seq_len : int = None ,
48- max_tokens_per_gpu : int = None ,
43+ effective_batch_size : int | None = None ,
44+ max_seq_len : int | None = None ,
45+ max_tokens_per_gpu : int | None = None ,
4946 use_liger : bool = False ,
5047 verbose : int = 1 ,
48+ trust_remote_code : bool = False ,
5149 ):
5250 self .num_gpus = num_gpus
5351 self .gpu_memory = gpu_memory
@@ -56,7 +54,7 @@ def __init__(
5654 self .verbose = verbose
5755
5856 # Load model directly
59- self .model = AutoModel .from_pretrained (model_path , trust_remote_code = True )
57+ self .model = AutoModel .from_pretrained (model_path , trust_remote_code = trust_remote_code )
6058
6159 # Determine parameters needed for calculations
6260 self .num_params : int = self .model .num_parameters (only_trainable = False )
@@ -278,9 +276,9 @@ def __init__(
278276 num_gpus : int = 8 ,
279277 gpu_memory : int = 85899345920 ,
280278 model_path : str = "ibm-granite/granite-3.3-8b-instruct" ,
281- effective_batch_size : int = None ,
282- max_seq_len : int = None ,
283- max_tokens_per_gpu : int = None ,
279+ effective_batch_size : int | None = None ,
280+ max_seq_len : int | None = None ,
281+ max_tokens_per_gpu : int | None = None ,
284282 use_liger : bool = False ,
285283 verbose : int = 1 ,
286284 ):
@@ -348,11 +346,12 @@ def estimate(
348346 num_gpus : int = 8 ,
349347 gpu_memory : int = 85899345920 ,
350348 model_path : str = "ibm-granite/granite-3.3-8b-instruct" ,
351- effective_batch_size : int = None ,
352- max_seq_len : int = None ,
353- max_tokens_per_gpu : int = None ,
349+ effective_batch_size : int | None = None ,
350+ max_seq_len : int | None = None ,
351+ max_tokens_per_gpu : int | None = None ,
354352 use_liger : bool = False ,
355353 verbose : int = 1 ,
354+ trust_remote_code : bool = False
356355 ):
357356 """
358357 Convenience function for performing estimation
@@ -383,7 +382,25 @@ def estimate(
383382 """
384383
385384 if training_method .lower () == "osft" :
386- estimator = OSFTEstimator (num_gpus , gpu_memory , model_path , effective_batch_size , max_seq_len , max_tokens_per_gpu , use_liger , verbose )
385+ estimator = OSFTEstimator (num_gpus ,
386+ gpu_memory ,
387+ model_path ,
388+ effective_batch_size ,
389+ max_seq_len ,
390+ max_tokens_per_gpu ,
391+ use_liger ,
392+ verbose ,
393+ trust_remote_code ,
394+ )
387395 else :
388- estimator = BasicEstimator (num_gpus , gpu_memory , model_path , effective_batch_size , max_seq_len , max_tokens_per_gpu , use_liger , verbose )
396+ estimator = BasicEstimator (num_gpus ,
397+ gpu_memory ,
398+ model_path ,
399+ effective_batch_size ,
400+ max_seq_len ,
401+ max_tokens_per_gpu ,
402+ use_liger ,
403+ verbose ,
404+ trust_remote_code
405+ )
389406 return estimator .estimate ()
0 commit comments