Skip to content

Commit 5db9e58

Browse files
committed
improve algorithm
1 parent a97cf1f commit 5db9e58

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

monai/losses/focal_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ def softmax_focal_loss(
212212
where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
213213
s_j is the unnormalized score for class j.
214214
"""
215-
pt = input.softmax(1)
216-
loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target
215+
input_ls = input.log_softmax(1)
216+
loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target
217217

218218
if alpha is not None:
219219
# (1-alpha) for the background class and alpha for the other classes

0 commit comments

Comments
 (0)