Skip to content

Commit 77acabd

Browse files
committed
Add rounding error check to _maybe_log_save_evaluate
1 parent aa798b7 commit 77acabd

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/transformers/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3077,7 +3077,11 @@ def _maybe_log_save_evaluate(
             # reset tr_loss to zero
             tr_loss -= tr_loss

-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            log_eval = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
+            if log_eval > 1e-4:
+                logs["loss"] = str(round(log_eval))
+            else:
+                logs["loss"] = "{:e}".format(log_eval)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
             if learning_rate is not None:

0 commit comments

Comments (0)