We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a97cf1f commit 5db9e58Copy full SHA for 5db9e58
1 file changed
monai/losses/focal_loss.py
@@ -212,8 +212,8 @@ def softmax_focal_loss(
212
where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
213
s_j is the unnormalized score for class j.
214
"""
215
- pt = input.softmax(1)
216
- loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target
+ input_ls = input.log_softmax(1)
+ loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target
217
218
if alpha is not None:
219
# (1-alpha) for the background class and alpha for the other classes
0 commit comments