Skip to content

Commit 8f6936d

Browse files
authored
Fix segmentation Dice + GeneralizedDice for 2d index tensors (#2832)
1 parent e2543c8 commit 8f6936d

File tree

6 files changed

+22
-9
lines changed

6 files changed

+22
-9
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5050

5151
- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
5252

53+
5354
### Fixed
5455

56+
- Fixed segmentation `Dice` + `GeneralizedDice` for 2d index tensors ([#2832](https://github.com/Lightning-AI/torchmetrics/pull/2832))
57+
58+
5559
- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))
5660

5761

src/torchmetrics/functional/segmentation/dice.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ def _dice_score_update(
4949
) -> tuple[Tensor, Tensor, Tensor]:
5050
"""Update the state with the current prediction and target."""
5151
_check_same_shape(preds, target)
52-
if preds.ndim < 3:
53-
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
5452

5553
if input_format == "index":
5654
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
5755
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
5856

57+
if preds.ndim < 3:
58+
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
59+
5960
if not include_background:
6061
preds, target = _ignore_background(preds, target)
6162

src/torchmetrics/functional/segmentation/generalized_dice.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Tuple
15+
1416
import torch
1517
from torch import Tensor
1618
from typing_extensions import Literal
@@ -49,16 +51,17 @@ def _generalized_dice_update(
4951
include_background: bool,
5052
weight_type: Literal["square", "simple", "linear"] = "square",
5153
input_format: Literal["one-hot", "index"] = "one-hot",
52-
) -> Tensor:
54+
) -> Tuple[Tensor, Tensor]:
5355
"""Update the state with the current prediction and target."""
5456
_check_same_shape(preds, target)
55-
if preds.ndim < 3:
56-
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
5757

5858
if input_format == "index":
5959
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
6060
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
6161

62+
if preds.ndim < 3:
63+
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
64+
6265
if not include_background:
6366
preds, target = _ignore_background(preds, target)
6467

@@ -67,7 +70,6 @@ def _generalized_dice_update(
6770
target_sum = torch.sum(target, dim=reduce_axis)
6871
pred_sum = torch.sum(preds, dim=reduce_axis)
6972
cardinality = target_sum + pred_sum
70-
7173
if weight_type == "simple":
7274
weights = 1.0 / target_sum
7375
elif weight_type == "linear":
@@ -89,7 +91,7 @@ def _generalized_dice_update(
8991

9092
numerator = 2.0 * intersection * weights
9193
denominator = cardinality * weights
92-
return numerator, denominator # type:ignore[return-value]
94+
return numerator, denominator
9395

9496

9597
def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:

tests/unittests/segmentation/inputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,7 @@
3434
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
3535
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
3636
)
37+
_input4 = _Input(
38+
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
39+
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
40+
)

tests/unittests/segmentation/test_dice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from unittests import NUM_CLASSES
2323
from unittests._helpers import seed_all
2424
from unittests._helpers.testers import MetricTester
25-
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
25+
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3
2626

2727
seed_all(42)
2828

@@ -55,6 +55,7 @@ def _reference_dice_score(
5555
(_inputs1.preds, _inputs1.target, "one-hot"),
5656
(_inputs2.preds, _inputs2.target, "one-hot"),
5757
(_inputs3.preds, _inputs3.target, "index"),
58+
(_input4.preds, _input4.target, "index"),
5859
],
5960
)
6061
@pytest.mark.parametrize("include_background", [True, False])

tests/unittests/segmentation/test_generalized_dice_score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from unittests import NUM_CLASSES
2525
from unittests._helpers import seed_all
2626
from unittests._helpers.testers import MetricTester
27-
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
27+
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3
2828

2929
seed_all(42)
3030

@@ -53,6 +53,7 @@ def _reference_generalized_dice(
5353
(_inputs1.preds, _inputs1.target, "one-hot"),
5454
(_inputs2.preds, _inputs2.target, "one-hot"),
5555
(_inputs3.preds, _inputs3.target, "index"),
56+
(_input4.preds, _input4.target, "index"),
5657
],
5758
)
5859
@pytest.mark.parametrize("include_background", [True, False])

0 commit comments

Comments
 (0)