Skip to content

Commit 3a91777

Browse files
authored
adds epoch-boundary checkpoint saving (#160)
Currently, we save checkpoints (1) whenever we have processed enough samples, and (2) sometimes at the end of training. This change adds per-epoch saving, so one could set save_samples very high and save ONLY at epoch boundaries. Signed-off-by: James Kunstle <[email protected]>
1 parent 3170300 commit 3a91777

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/instructlab/training/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class TrainingArgs(BaseModel):
139139
warmup_steps: int
140140
is_padding_free: bool
141141
random_seed: int = 42
142+
checkpoint_at_epoch: bool = False
142143

143144
mock_data: Optional[bool] = False
144145
mock_data_len: int = 0

src/instructlab/training/main_ds.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,15 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
463463
if local_rank == 0:
464464
inner_pb.update(1)
465465
torch.cuda.empty_cache()
466+
467+
if args.checkpoint_at_epoch:
468+
save_hf_format_ds(
469+
args,
470+
model,
471+
tokenizer,
472+
global_step * args.samples_per_gpu * world_size,
473+
is_lora=bool(args.lora_r),
474+
)
466475
if args.save_last:
467476
save_hf_format_ds(
468477
args,
@@ -615,6 +624,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
615624
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
616625
]
617626

627+
if train_args.checkpoint_at_epoch:
628+
command.append("--checkpoint_at_epoch")
629+
618630
if train_args.mock_data:
619631
command.append("--mock_data")
620632
if train_args.mock_len:
@@ -734,6 +746,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
734746
parser.add_argument(
735747
"--save_last", action="store_true", help="save after finishing training"
736748
)
749+
parser.add_argument(
750+
"--checkpoint_at_epoch",
751+
action="store_true",
752+
help="Save a model checkpoint after finishing an epoch.",
753+
)
737754
parser.add_argument("--log_level", type=str, default="INFO")
738755
parser.add_argument("--seed", type=int, default=42)
739756
parser.add_argument("--mock_data", action="store_true")

0 commit comments

Comments
 (0)