
Commit 691d1b5

Fix/best model checkpoint fix (#35885)
* Set best_model_checkpoint only when the checkpoint exists. Rather than setting it unconditionally, without checking whether the checkpoint directory even exists (as before), the setting logic now lives inside _save_checkpoint and only runs when the directory exists.
* Added best_global_step to TrainerState.
* Added tests for best_model_checkpoint.
* Fixed hard-coded values in tests to prevent failures.
* Added a helper function and removed the hard-coded best_step.
* Added a side-effect patch generator for _evaluate.
* Added the evaluate side-effect function.
* Removed erroneous patching.
* Fixed a minor bug.
* Applied Ruff.
* Fixed a Ruff problem in make style.
* Used Trainer.set_initial_training_values.
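To see the practical effect, consider a run where the evaluation and save steps never line up, which is the scenario the fix targets (see Case 6 of the new test below). The following is an illustrative sketch, not code from this commit: my_model, my_train_dataset, my_eval_dataset, and my_compute_metrics are placeholders you would supply yourself.

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    eval_strategy="steps",
    eval_steps=3,    # evaluation at steps 3, 6, ...
    save_strategy="steps",
    save_steps=2,    # checkpoints at steps 2, 4, ...
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

# model, datasets, and compute_metrics are hypothetical placeholders for a real setup
trainer = Trainer(
    model=my_model,
    args=args,
    train_dataset=my_train_dataset,
    eval_dataset=my_eval_dataset,
    compute_metrics=my_compute_metrics,
)
trainer.train()

# The best eval step (3, 6, ...) never coincides with a save step (2, 4, ...), so the
# "best" checkpoint is never written. After this fix, best_model_checkpoint stays None
# instead of pointing at a directory that does not exist.
print(trainer.state.best_metric)
print(trainer.state.best_global_step)
print(trainer.state.best_model_checkpoint)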
1 parent 3bd1a0d commit 691d1b5

4 files changed, +231 -6 lines changed


src/transformers/testing_utils.py

Lines changed: 30 additions & 1 deletion
@@ -38,7 +38,7 @@
 from functools import wraps
 from io import StringIO
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
+from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
 from unittest import mock
 from unittest.mock import patch
 
@@ -48,6 +48,7 @@
 from huggingface_hub import delete_repo
 from packaging import version
 
+from transformers import Trainer
 from transformers import logging as transformers_logging
 
 from .integrations import (
@@ -1440,6 +1441,34 @@ def get_tests_dir(append_path=None):
     return tests_dir
 
 
+def get_steps_per_epoch(trainer: Trainer) -> int:
+    training_args = trainer.args
+    train_dataloader = trainer.get_train_dataloader()
+
+    initial_training_values = trainer.set_initial_training_values(
+        args=training_args,
+        dataloader=train_dataloader,
+        total_train_batch_size=training_args.per_device_train_batch_size,
+    )
+    steps_per_epoch = initial_training_values[1]
+
+    return steps_per_epoch
+
+
+def evaluate_side_effect_factory(
+    side_effect_values: List[Dict[str, float]],
+) -> Generator[Dict[str, float], None, None]:
+    """
+    Function that returns side effects for the _evaluate method.
+    Used when we're unsure of exactly how many times _evaluate will be called.
+    """
+    for side_effect_value in side_effect_values:
+        yield side_effect_value
+
+    while True:
+        yield side_effect_values[-1]
+
+
 #
 # Helper functions for dealing with testing text outputs
 # The original code came from:
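A quick illustration, not part of the diff, of why the generator shape is useful: once the prepared values run out it keeps repeating the last one, so a patched _evaluate can be called more times than there are prepared metric dicts. This assumes the helper added above is importable from transformers.testing_utils.

from transformers.testing_utils import evaluate_side_effect_factory

gen = evaluate_side_effect_factory([{"eval_accuracy": 0.9}, {"eval_accuracy": 0.8}])
print(next(gen))  # {'eval_accuracy': 0.9}
print(next(gen))  # {'eval_accuracy': 0.8}
print(next(gen))  # {'eval_accuracy': 0.8} -- the final value repeats indefinitely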

src/transformers/trainer.py

Lines changed: 10 additions & 5 deletions
@@ -3178,12 +3178,10 @@ def _determine_best_metric(self, metrics, trial):
                 self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
 
             if operator(metric_value, self.state.best_metric):
-                run_dir = self._get_output_dir(trial=trial)
-                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
-                output_dir = os.path.join(run_dir, checkpoint_folder)
-
                 self.state.best_metric = metric_value
-                self.state.best_model_checkpoint = output_dir
+
+                if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
+                    self.state.best_global_step = self.state.global_step
 
                 is_new_best_metric = True
 
@@ -3204,6 +3202,13 @@ def _save_checkpoint(self, model, trial):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir, _internal_call=True)
 
+        if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
+            best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
+            best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)
+
+            if os.path.exists(best_checkpoint_dir):
+                self.state.best_model_checkpoint = best_checkpoint_dir
+
         if not self.args.save_only_model:
             # Save optimizer and scheduler
             self._save_optimizer_and_scheduler(output_dir)
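The two hunks above split the old single assignment into a record step (_determine_best_metric notes the step of the new best metric) and a resolve step (_save_checkpoint turns that step into a path only if the directory exists). A standalone sketch of the resolve-side guard; resolve_best_checkpoint is a hypothetical helper for illustration, not Trainer API:

import os
from typing import Optional

PREFIX_CHECKPOINT_DIR = "checkpoint"  # same folder prefix the Trainer uses


def resolve_best_checkpoint(run_dir: str, best_global_step: Optional[int]) -> Optional[str]:
    # Mirror the _save_checkpoint guard: only report a best checkpoint that exists on disk.
    if best_global_step is None:
        return None
    best_dir = os.path.join(run_dir, f"{PREFIX_CHECKPOINT_DIR}-{best_global_step}")
    return best_dir if os.path.exists(best_dir) else None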

src/transformers/trainer_callback.py

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,9 @@ class TrainerState:
             The list of logs done since the beginning of training.
         best_metric (`float`, *optional*):
             When tracking the best model, the value of the best metric encountered so far.
+        best_global_step (`int`, *optional*):
+            When tracking the best model, the step at which the best metric was encountered.
+            Used for setting `best_model_checkpoint`.
         best_model_checkpoint (`str`, *optional*):
             When tracking the best model, the value of the name of the checkpoint for the best model encountered so
             far.
@@ -103,6 +106,7 @@ class TrainerState:
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
     best_metric: Optional[float] = None
+    best_global_step: Optional[int] = None
     best_model_checkpoint: Optional[str] = None
     is_local_process_zero: bool = True
     is_world_process_zero: bool = True
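Because TrainerState is a dataclass persisted to trainer_state.json inside every checkpoint, the new field round-trips alongside the existing ones. A small sketch; the checkpoint path is a placeholder for an actual run's output:

from transformers import TrainerState

state = TrainerState.load_from_json("out/checkpoint-6/trainer_state.json")
print(state.best_metric, state.best_global_step, state.best_model_checkpoint)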

tests/trainer/test_trainer.py

Lines changed: 187 additions & 0 deletions
@@ -62,8 +62,10 @@
     TemporaryHubRepo,
     TestCasePlus,
     backend_device_count,
+    evaluate_side_effect_factory,
     execute_subprocess_async,
     get_gpu_count,
+    get_steps_per_epoch,
     get_tests_dir,
     is_staging_test,
     require_accelerate,
@@ -4710,6 +4712,191 @@ def test_metric_for_best_model_behavior(self):
         )
         self.assertTrue(trainer.args.metric_for_best_model == "loss")
 
+    def test_best_model_checkpoint_behavior(self):
+        # Case 1. Never evaluated, save_total_limit > 1 and save_steps == 1.
+        # Both best_metric and best_model_checkpoint should be None.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="steps",
+                save_strategy="steps",
+                save_steps=1,
+                metric_for_best_model="accuracy",
+                greater_is_better=True,
+            )
+            trainer.train()
+
+            assert trainer.state.best_metric is None
+            assert trainer.state.best_model_checkpoint is None
+            assert len(os.listdir(tmpdir)) == trainer.state.global_step
+
+        # Case 2. Never evaluated and save_total_limit == 1.
+        # Both best_metric and best_model_checkpoint should be None.
+        # Only the last checkpoint should remain.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="steps",
+                save_strategy="steps",
+                save_steps=1,
+                metric_for_best_model="accuracy",
+                greater_is_better=True,
+                save_total_limit=1,
+            )
+            trainer.train()
+
+            num_steps = trainer.state.global_step
+
+            assert trainer.state.best_metric is None
+            assert trainer.state.best_model_checkpoint is None
+            assert len(os.listdir(tmpdir)) == 1
+
+            ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{num_steps}")
+            assert os.path.isdir(ckpt)
+            assert os.listdir(tmpdir)[0] == f"{PREFIX_CHECKPOINT_DIR}-{num_steps}"
+
+        # Case 3. eval_strategy == save_strategy.
+        # best_model_checkpoint should be at epoch 1.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="epoch",
+                save_strategy="epoch",
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+                greater_is_better=True,
+                load_best_model_at_end=False,
+            )
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=evaluate_side_effect_factory(
+                    [
+                        {"eval_accuracy": 0.59},
+                        {"eval_accuracy": 0.57},
+                        {"eval_accuracy": 0.55},
+                    ]
+                ),
+            ):
+                trainer.train()
+
+            steps_per_epoch = get_steps_per_epoch(trainer)
+
+            assert trainer.state.best_metric == 0.59
+            assert trainer.state.best_global_step == steps_per_epoch
+
+            best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
+            assert trainer.state.best_model_checkpoint == best_ckpt
+
+            assert len(os.listdir(tmpdir)) == trainer.state.num_train_epochs
+
+        # Case 4. eval_strategy != save_strategy.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="epoch",
+                save_strategy="steps",
+                save_steps=1,
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+                greater_is_better=True,
+                load_best_model_at_end=False,
+            )
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=evaluate_side_effect_factory(
+                    [
+                        {"eval_accuracy": 0.59},
+                        {"eval_accuracy": 0.57},
+                        {"eval_accuracy": 0.55},
+                    ]
+                ),
+            ):
+                trainer.train()
+
+            steps_per_epoch = get_steps_per_epoch(trainer)
+
+            assert trainer.state.best_metric == 0.59
+            assert trainer.state.best_global_step == steps_per_epoch
+
+            best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
+            assert trainer.state.best_model_checkpoint == best_ckpt
+
+            assert len(os.listdir(tmpdir)) == trainer.state.global_step
+
+        # Case 5. Multiple checkpoints, save_total_limit == 1.
+        # Best metric is found at step 1 and that checkpoint should be saved.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="steps",
+                eval_steps=1,
+                save_strategy="steps",
+                save_steps=1,
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+                greater_is_better=True,
+                save_total_limit=1,
+            )
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=evaluate_side_effect_factory(
+                    [
+                        {"eval_accuracy": 0.90},
+                        {"eval_accuracy": 0.80},
+                        {"eval_accuracy": 0.70},
+                    ]
+                ),
+            ):
+                trainer.train()
+
+            assert trainer.state.best_metric == 0.90
+            assert trainer.state.best_global_step == 1
+
+            best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
+            assert trainer.state.best_model_checkpoint == best_ckpt
+
+            assert len(os.listdir(tmpdir)) == 1
+
+        # Case 6. Saving happens more often and eval/save mismatch.
+        # `best_model_checkpoint` should be None due to a step mismatch.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(
+                output_dir=tmpdir,
+                eval_strategy="steps",
+                eval_steps=3,
+                save_strategy="steps",
+                save_steps=2,
+                metric_for_best_model="accuracy",
+                compute_metrics=AlmostAccuracy(),
+                greater_is_better=True,
+            )
+
+            with patch.object(
+                trainer,
+                "_evaluate",
+                side_effect=evaluate_side_effect_factory(
+                    [
+                        {"eval_accuracy": 0.90},
+                        {"eval_accuracy": 0.80},
+                        {"eval_accuracy": 0.70},
+                    ]
+                ),
+            ):
+                trainer.train()
+
+            assert trainer.state.best_metric == 0.90
+            assert trainer.state.best_global_step == 3
+
+            assert trainer.state.best_model_checkpoint is None
+
+            assert len(os.listdir(tmpdir)) == trainer.state.global_step // 2
+
 
 @require_torch
 @is_staging_test
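The new case can be run on its own with pytest, e.g. `pytest tests/trainer/test_trainer.py -k test_best_model_checkpoint_behavior`.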
