Skip to content
This repository was archived by the owner on Mar 20, 2026. It is now read-only.

Commit 1bb218f

Browse files
MultiPath authored and facebook-github-bot committed
Split from PR#968. add --keep-best-checkpoints (#990)
Summary: Fixes fairinternal/fairseq-py#968. Split --keep-best-checkpoints from the original request. Use scores as the names to save the checkpoints.
Pull Request resolved: fairinternal/fairseq-py#990
Differential Revision: D19411250
Pulled By: MultiPath
fbshipit-source-id: 82b0db614208eee54c9c0e470ad7faa6481747d5
1 parent cf8676f commit 1bb218f

2 files changed

Lines changed: 23 additions & 3 deletions

File tree

fairseq/checkpoint_utils.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,15 @@ def is_better(a, b):
5353
not hasattr(save_checkpoint, "best")
5454
or is_better(val_loss, save_checkpoint.best)
5555
)
56+
checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
57+
args.best_checkpoint_metric, val_loss)] = (
58+
val_loss is not None
59+
and args.keep_best_checkpoints > 0
60+
and (
61+
not hasattr(save_checkpoint, "best")
62+
or is_better(val_loss, save_checkpoint.best)
63+
)
64+
)
5665
checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints
5766

5867
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
@@ -69,8 +78,8 @@ def is_better(a, b):
6978

7079
write_timer.stop()
7180
print(
72-
"| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)".format(
73-
checkpoints[0], epoch, updates, write_timer.sum
81+
"| saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
82+
checkpoints[0], epoch, updates, val_loss, write_timer.sum
7483
)
7584
)
7685

@@ -90,6 +99,15 @@ def is_better(a, b):
9099
if os.path.lexists(old_chk):
91100
os.remove(old_chk)
92101

102+
if args.keep_best_checkpoints > 0:
103+
# only keep the best N checkpoints according to validation metric
104+
checkpoints = checkpoint_paths(
105+
args.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(args.best_checkpoint_metric))
106+
if not args.maximize_best_checkpoint_metric:
107+
checkpoints = checkpoints[::-1]
108+
for old_chk in checkpoints[args.keep_best_checkpoints:]:
109+
if os.path.lexists(old_chk):
110+
os.remove(old_chk)
93111

94112
def load_checkpoint(args, trainer, **passthrough_args):
95113
"""
@@ -202,7 +220,7 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
202220
for i, f in enumerate(files):
203221
m = pt_regexp.fullmatch(f)
204222
if m is not None:
205-
idx = int(m.group(1)) if len(m.groups()) > 0 else i
223+
idx = float(m.group(1)) if len(m.groups()) > 0 else i
206224
entries.append((idx, m.group(0)))
207225
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
208226

fairseq/options.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -403,6 +403,8 @@ def add_checkpoint_args(parser):
403403
help='keep the last N checkpoints saved with --save-interval-updates')
404404
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
405405
help='keep last N epoch checkpoints')
406+
group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N',
407+
help='keep best N checkpoints based on scores')
406408
group.add_argument('--no-save', action='store_true',
407409
help='don\'t save models or checkpoints')
408410
group.add_argument('--no-epoch-checkpoints', action='store_true',

0 commit comments

Comments
 (0)