@@ -336,12 +336,15 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
     world_size = int(os.environ["WORLD_SIZE"])
 
     batch_size = args.effective_batch_size // grad_accum
-    args.save_samples = (args.save_samples // batch_size) * batch_size
-    (
-        print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m")
-        if local_rank == 0
-        else None
-    )
+
+    if args.save_samples > 0:
+        args.save_samples = (args.save_samples // batch_size) * batch_size
+        (
+            print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m")
+            if local_rank == 0
+            else None
+        )
+
     if args.save_samples_ds is not None:
         args.save_samples_ds = (args.save_samples_ds // batch_size) * batch_size
         (
@@ -439,7 +442,9 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
                     }
                 )
 
-            if global_step * batch_size % args.save_samples == 0:
+            if args.save_samples > 0 and (
+                global_step * batch_size % args.save_samples == 0
+            ):
                 save_hf_format_ds(
                     args,
                     model,
@@ -736,7 +741,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     )
     parser.add_argument("--num_warmup_steps", type=int, default=1000)
     # parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--save_samples", type=int)
+    parser.add_argument(
+        "--save_samples",
+        type=int,
+        help="The number of samples seen between each checkpoint save. If --save_samples<=0, this feature is disabled.",
+    )
     parser.add_argument(
         "--save_samples_ds",
         type=int,
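
Taken together, the three hunks make sample-based checkpointing optional: when `--save_samples` is positive it is rounded down to a multiple of the per-step batch size, and any non-positive value skips the periodic `save_hf_format_ds` call entirely. A minimal standalone sketch of that gating logic follows; the helper names `normalize_save_samples` and `should_save` are illustrative only and do not appear in the patch.

```python
# Illustrative sketch (not part of the patch): the gating behavior introduced
# for --save_samples, assuming a positive value means "save every N samples"
# and a non-positive value disables sample-based checkpointing.


def normalize_save_samples(save_samples: int, batch_size: int) -> int:
    # Mirrors: args.save_samples = (args.save_samples // batch_size) * batch_size,
    # which the patch now applies only when save_samples > 0.
    if save_samples <= 0:
        return save_samples  # feature disabled; leave the value untouched
    return (save_samples // batch_size) * batch_size


def should_save(global_step: int, batch_size: int, save_samples: int) -> bool:
    # Mirrors the new guard around save_hf_format_ds:
    #   args.save_samples > 0 and (global_step * batch_size % args.save_samples == 0)
    return save_samples > 0 and (global_step * batch_size) % save_samples == 0


if __name__ == "__main__":
    bs = 32
    assert normalize_save_samples(1000, bs) == 992  # rounded down to a multiple of 32
    assert should_save(global_step=31, batch_size=bs, save_samples=992)  # 31 * 32 == 992
    assert not should_save(global_step=31, batch_size=bs, save_samples=0)  # disabled
```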