Skip to content

Commit b1a54c1

Browse files
makes saving every save_samples an optional feature (#165)
* makes saving every save_samples an optional feature Signed-off-by: James Kunstle <[email protected]> * Update help text Signed-off-by: Mustafa Eyceoz <[email protected]> --------- Signed-off-by: James Kunstle <[email protected]> Signed-off-by: Mustafa Eyceoz <[email protected]> Co-authored-by: Mustafa Eyceoz <[email protected]>
1 parent 3a91777 commit b1a54c1

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/instructlab/training/main_ds.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,12 +336,15 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
336336
world_size = int(os.environ["WORLD_SIZE"])
337337

338338
batch_size = args.effective_batch_size // grad_accum
339-
args.save_samples = (args.save_samples // batch_size) * batch_size
340-
(
341-
print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m")
342-
if local_rank == 0
343-
else None
344-
)
339+
340+
if args.save_samples > 0:
341+
args.save_samples = (args.save_samples // batch_size) * batch_size
342+
(
343+
print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m")
344+
if local_rank == 0
345+
else None
346+
)
347+
345348
if args.save_samples_ds is not None:
346349
args.save_samples_ds = (args.save_samples_ds // batch_size) * batch_size
347350
(
@@ -439,7 +442,9 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
439442
}
440443
)
441444

442-
if global_step * batch_size % args.save_samples == 0:
445+
if args.save_samples > 0 and (
446+
global_step * batch_size % args.save_samples == 0
447+
):
443448
save_hf_format_ds(
444449
args,
445450
model,
@@ -736,7 +741,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
736741
)
737742
parser.add_argument("--num_warmup_steps", type=int, default=1000)
738743
# parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
739-
parser.add_argument("--save_samples", type=int)
744+
parser.add_argument(
745+
"--save_samples",
746+
type=int,
747+
help="The number of samples seen between each checkpoint save. If --save_samples<=0, this feature is disabled.",
748+
)
740749
parser.add_argument(
741750
"--save_samples_ds",
742751
type=int,

0 commit comments

Comments
 (0)