Skip to content

Commit 81af139

Browse files
committed
minor fix: 1. Add stacklevel=2 to warning. 2. Pass include_background to sub-losses. 3. let sub-losses handle include_background.
Signed-off-by: ytl0623 <[email protected]>
1 parent 836ce42 commit 81af139

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,11 @@ def __init__(
207207
self.gamma = gamma
208208
self.delta = delta
209209
self.weight: float = weight
210-
self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
210+
self.asy_focal_loss = AsymmetricFocalLoss(
211+
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
212+
)
211213
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
212-
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
214+
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
213215
)
214216
self.include_background = include_background
215217
self.use_softmax = use_softmax
@@ -251,14 +253,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
251253
else:
252254
y_true = one_hot(y_true, num_classes=n_pred_ch)
253255

254-
if not self.include_background:
255-
if n_pred_ch == 1:
256-
warnings.warn("single channel prediction, `include_background=False` ignored.")
257-
else:
258-
# if skipping background, removing first channel
259-
y_pred = y_pred[:, 1:]
260-
y_true = y_true[:, 1:]
261-
262256
if self.use_softmax:
263257
y_pred = torch.softmax(y_pred.float(), dim=1)
264258
else:

0 commit comments

Comments
 (0)