Skip to content

Commit c349450

Browse files
authored
fix: add score_function in saving best models (#145)
1 parent a0e9ce4 commit c349450

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

src/templates/template-vision-classification/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def setup_handlers(
169169
filename_prefix="best",
170170
n_saved=config.n_saved,
171171
global_step_transform=global_step_transform,
172+
score_name="eval_accuracy",
173+
score_function=Checkpoint.get_default_score_fn("eval_accuracy"),
172174
)
173175
evaluator.add_event_handler(
174176
Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval

src/templates/template-vision-dcgan/trainers.py

-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def eval_function(engine: Engine, batch: Any):
158158
metrics = {
159159
"epoch": engine.state.epoch,
160160
"errD": errD.item(),
161-
"eval_loss": errD.item(),
162161
"errG": errG.item(),
163162
"D_x": D_x,
164163
"D_G_z1": D_G_z1,

src/templates/template-vision-dcgan/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def setup_handlers(
169169
filename_prefix="best",
170170
n_saved=config.n_saved,
171171
global_step_transform=global_step_transform,
172+
score_name="model_d_error",
173+
score_function=Checkpoint.get_default_score_fn("errD", -1),
172174
)
173175
evaluator.add_event_handler(
174176
Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval
@@ -179,7 +181,7 @@ def setup_handlers(
179181
#::: if (it.patience) { :::#
180182
# early stopping
181183
def score_fn(engine: Engine):
182-
return -engine.state.metrics["eval_loss"]
184+
return -engine.state.metrics["errD"]
183185

184186
es = EarlyStopping(config.patience, score_fn, trainer)
185187
evaluator.add_event_handler(Events.EPOCH_COMPLETED, es)

0 commit comments

Comments
 (0)