
Add best_k_metrics parameter to the ModelCheckpoint #20457

Open · wants to merge 8 commits into base: master
23 changes: 18 additions & 5 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -36,12 +36,20 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.callbacks import Checkpoint
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
 from lightning.pytorch.utilities.types import STEP_OUTPUT
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)
+from lightning_fabric.utilities.types import _PATH
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -244,6 +252,7 @@ def __init__(
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
@@ -354,15 +364,16 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
 
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict["best_model_metrics"]
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
             self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
-                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`,"
+                " `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
             )
 
         self.best_model_path = state_dict["best_model_path"]
@@ -746,6 +757,8 @@ def _update_best_and_save(
         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
+        if self.best_model_path == filepath:
+            self.best_model_metrics = monitor_candidates
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
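Taken together, the `_update_best_and_save` change above captures the full `monitor_candidates` dict whenever the checkpoint just written is the new best, so every logged metric from that moment is retained, not just the monitored scalar. A minimal usage sketch under that reading (not part of the PR; `model` stands for any hypothetical LightningModule that logs "val_loss" and "val_acc", and `best_model_metrics` exists only with this branch applied):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# `model` is a hypothetical LightningModule that logs "val_loss" and "val_acc";
# `best_model_metrics` does not exist on released Lightning, only on this branch.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
trainer = Trainer(max_epochs=5, callbacks=[checkpoint_callback])
trainer.fit(model)

# Before this change only the monitored value survived; now every metric logged
# at the moment the best checkpoint was written is retained alongside it.
print(checkpoint_callback.best_model_score)    # tensor with the best "val_loss"
print(checkpoint_callback.best_model_metrics)  # dict also holding "val_acc", "epoch", "step", ...

The `if self.best_model_path == filepath` guard only refreshes the dict when the checkpoint just saved is itself the new best, which keeps `best_model_score` and `best_model_metrics` describing the same checkpoint.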
5 changes: 3 additions & 2 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -32,7 +32,6 @@
 from torch import optim
 from torch.utils.data.dataloader import DataLoader
 
-import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -703,6 +702,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
Review comment (Collaborator): "we need to add tests that exercise the new code" (a hedged sketch of such a test follows this hunk).
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
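Picking up the reviewer's request, a test along these lines could exercise the new attribute. This is a sketch only, not part of the diff; the test name and the tiny `MetricsModel` helper are invented here, and it assumes `BoringModel` from `lightning.pytorch.demos.boring_classes`:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


def test_model_checkpoint_best_model_metrics(tmp_path):
    """Hedged sketch of a test for `best_model_metrics` (not part of this PR)."""

    class MetricsModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            self.log("val_loss", out["x"])  # expose a monitorable metric
            return out

    model = MetricsModel()
    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=1)
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[checkpoint_callback],
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
    )
    trainer.fit(model)

    # the metrics dict captured at save time should exist and agree with the best score
    assert checkpoint_callback.best_model_metrics is not None
    assert torch.equal(checkpoint_callback.best_model_metrics["val_loss"], checkpoint_callback.best_model_score)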

@@ -809,6 +809,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
@@ -1074,7 +1075,7 @@ def assert_checkpoint_log_dir(idx):
 
     # load from checkpoint
     trainer_config["logger"] = TensorBoardLogger(tmp_path)
-    trainer = pl.Trainer(**trainer_config)
+    trainer = Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
     model = ExtendedBoringModel()
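Finally, a hedged sketch of what the `state_dict`/`load_state_dict` changes buy: the captured metrics survive a save/restore round trip when `dirpath` is unchanged, while a different `dirpath` skips them by design, per the warning above. Here `checkpoint_callback` refers to a fitted callback like the one in the first sketch:

from lightning.pytorch.callbacks import ModelCheckpoint

# `checkpoint_callback` is a fitted callback, e.g. from the first sketch above
state = checkpoint_callback.state_dict()
assert "best_model_metrics" in state  # newly persisted by this PR

# a fresh callback pointed at the same dirpath restores the captured metrics
restored = ModelCheckpoint(dirpath=checkpoint_callback.dirpath, monitor="val_loss")
restored.load_state_dict(state)
assert restored.best_model_metrics == state["best_model_metrics"]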