diff --git a/src/anomalib/metrics/per_image/__init__.py b/src/anomalib/metrics/per_image/__init__.py index 51d27f1d49..b98ea9fae6 100644 --- a/src/anomalib/metrics/per_image/__init__.py +++ b/src/anomalib/metrics/per_image/__init__.py @@ -7,7 +7,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .binclf_curve import per_image_binclf_curve, per_image_fpr, per_image_tpr from .binclf_curve_numpy import BinclfThreshsChoice from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult, aupimo_scores, pimo_curves from .utils import ( @@ -27,9 +26,6 @@ "PIMOResult", "AUPIMOResult", # functional interfaces - "per_image_binclf_curve", - "per_image_fpr", - "per_image_tpr", "pimo_curves", "aupimo_scores", # torchmetrics interfaces diff --git a/src/anomalib/metrics/per_image/binclf_curve.py b/src/anomalib/metrics/per_image/binclf_curve.py deleted file mode 100644 index 4635a641c9..0000000000 --- a/src/anomalib/metrics/per_image/binclf_curve.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Binary classification curve (torch interface). - -This module implements torch interfaces to access the numpy code in `binclf_curve_numpy.py`. - -Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. - -Tensors are build with `torch.from_numpy` and so the returned tensors will share the same memory as the numpy arrays. - -Validations will preferably happen in ndarray so the numpy code can be reused without torch, -so often times the Tensor arguments will be converted to ndarray and then validated. -""" - -# Original Code -# https://github.com/jpcbertoldo/aupimo -# -# Modified -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import torch -from torch import Tensor - -from . import _validate, binclf_curve_numpy -from .binclf_curve_numpy import BinclfThreshsChoice - -# =========================================== ARGS VALIDATION =========================================== - - -def _validate_is_threshs(threshs: Tensor) -> None: - _validate.is_tensor(threshs, argname="threshs") - _validate.is_threshs(threshs.numpy()) - - -def _validate_is_binclf_curves(binclf_curves: Tensor, valid_threshs: Tensor | None = None) -> None: - _validate.is_tensor(binclf_curves, argname="binclf_curves") - if valid_threshs is not None: - _validate_is_threshs(valid_threshs) - _validate.is_binclf_curves( - binclf_curves.detach().cpu().numpy(), - valid_threshs=valid_threshs.numpy() if valid_threshs is not None else None, - ) - - -# =========================================== FUNCTIONAL =========================================== - - -def per_image_binclf_curve( - anomaly_maps: Tensor, - masks: Tensor, - threshs_choice: BinclfThreshsChoice | str = BinclfThreshsChoice.MINMAX_LINSPACE, - threshs_given: Tensor | None = None, - num_threshs: int | None = None, -) -> tuple[Tensor, Tensor]: - """Compute the binary classification matrix of each image in the batch for multiple thresholds (shared). - - Note: tensors are converted to numpy arrays and then converted back to tensors (same device as `anomaly_maps`). - - Args: - anomaly_maps (Tensor): Anomaly score maps of shape (N, H, W [, D, ...]) - masks (Tensor): Binary ground truth masks of shape (N, H, W [, D, ...]) - threshs_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. - return_result_object (bool, optional): Whether to return a `PerImageBinClfCurveResult` object. Defaults to True. - - *** `threshs_choice`-dependent arguments *** - - THRESH_SEQUENCE_GIVEN - threshs_given (Tensor, optional): Sequence of thresholds to use. - - THRESH_SEQUENCE_MINMAX_LINSPACE - num_threshs (int, optional): Number of thresholds between the min and max of the anomaly maps. - - Returns: - tuple[Tensor, Tensor]: - [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. - - [1] Binary classification matrices of shape (N, K, 2, 2) - - N: number of images/instances - K: number of thresholds - - The last two dimensions are the confusion matrix (ground truth, predictions) - So for each thresh it gives: - - `tp`: `[... , 1, 1]` - - `fp`: `[... , 0, 1]` - - `fn`: `[... , 1, 0]` - - `tn`: `[... , 0, 0]` - - `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: - - `tp` stands for `true positive` - - `fp` stands for `false positive` - - `fn` stands for `false negative` - - `tn` stands for `true negative` - - The numbers in each confusion matrix are the counts of pixels in the image (not the ratios). - - Thresholds are shared across all images, so all confusion matrices, for instance, - at position [:, 0, :, :] are relative to the 1st threshold in `threshs`. - - Thresholds are sorted in ascending order. - """ - _validate.is_tensor(anomaly_maps, argname="anomaly_maps") - anomaly_maps_array = anomaly_maps.detach().cpu().numpy() - - _validate.is_tensor(masks, argname="masks") - masks_array = masks.detach().cpu().numpy() - - if threshs_given is not None: - _validate.is_tensor(threshs_given, argname="threshs_given") - threshs_given_array = threshs_given.detach().cpu().numpy() - else: - threshs_given_array = None - - threshs_array, binclf_curves_array = binclf_curve_numpy.per_image_binclf_curve( - anomaly_maps=anomaly_maps_array, - masks=masks_array, - threshs_choice=threshs_choice, - threshs_given=threshs_given_array, - num_threshs=num_threshs, - ) - threshs = torch.from_numpy(threshs_array).to(anomaly_maps.device) - binclf_curves = torch.from_numpy(binclf_curves_array).to(anomaly_maps.device).long() - - return threshs, binclf_curves - - -# =========================================== RATE METRICS =========================================== - - -def per_image_tpr(binclf_curves: Tensor) -> Tensor: - """Compute the true positive rates (TPR) for each image in the batch. - - Args: - binclf_curves (Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. - - Returns: - Tensor: True positive rates (TPR) of shape (N, K) - - N: number of images/instances - K: number of thresholds - - The last dimension is the TPR for each threshold. - - Thresholds are sorted in ascending order, so TPR is in descending order. - """ - _validate_is_binclf_curves(binclf_curves) - binclf_curves_array = binclf_curves.detach().cpu().numpy() - tprs_array = binclf_curve_numpy.per_image_tpr(binclf_curves_array) - return torch.from_numpy(tprs_array).to(binclf_curves.device) - - -def per_image_fpr(binclf_curves: Tensor) -> Tensor: - """Compute the false positive rates (FPR) for each image in the batch. - - Args: - binclf_curves (Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. - - Returns: - Tensor: False positive rates (FPR) of shape (N, K) - - N: number of images/instances - K: number of thresholds - - The last dimension is the FPR for each threshold. - - Thresholds are sorted in ascending order, so FPR is in descending order. - """ - _validate_is_binclf_curves(binclf_curves) - binclf_curves_array = binclf_curves.detach().cpu().numpy() - fprs_array = binclf_curve_numpy.per_image_fpr(binclf_curves_array) - return torch.from_numpy(fprs_array).to(binclf_curves.device) diff --git a/src/anomalib/metrics/per_image/pimo.py b/src/anomalib/metrics/per_image/pimo.py index 6c329f211a..bda8d800b6 100644 --- a/src/anomalib/metrics/per_image/pimo.py +++ b/src/anomalib/metrics/per_image/pimo.py @@ -39,26 +39,20 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import json import logging from collections.abc import Sequence from dataclasses import dataclass, field -from pathlib import Path import torch from torch import Tensor from torchmetrics import Metric -from anomalib.data.utils.image import duplicate_filename from anomalib.data.utils.path import validate_path -from . import _validate, pimo_numpy, utils -from .utils import StatsOutliersPolicy, StatsRepeatedPolicy +from . import _validate, pimo_numpy logger = logging.getLogger(__name__) -# =========================================== AUX =========================================== - def _images_classes_from_masks(masks: Tensor) -> Tensor: masks = torch.concat(masks, dim=0) @@ -256,60 +250,6 @@ def thresh_at(self, fpr_level: float) -> tuple[int, float, float]: fpr_level, ) - def to_dict(self) -> dict[str, Tensor | str]: - """Return a dictionary with the result object's attributes.""" - dic = { - "threshs": self.threshs, - "shared_fpr": self.shared_fpr, - "per_image_tprs": self.per_image_tprs, - } - if self.paths is not None: - dic["paths"] = self.paths - return dic - - @classmethod - def from_dict(cls: type["PIMOResult"], dic: dict[str, Tensor | str | list[str]]) -> "PIMOResult": - """Return a result object from a dictionary.""" - try: - return cls(**dic) # type: ignore[arg-type] - - except TypeError as ex: - msg = f"Invalid input dictionary for {cls.__name__} object. Cause: {ex}." - raise TypeError(msg) from ex - - def save(self, file_path: str | Path) -> None: - """Save to a `.pt` file. - - Args: - file_path: path to the `.pt` file where to save the PIMO result. - If the file already exists, a numerical suffix is added to the filename. - """ - validate_path(file_path, should_exist=False, accepted_extensions=(".pt",)) - file_path = duplicate_filename(file_path) - payload = self.to_dict() - torch.save(payload, file_path) - - @classmethod - def load(cls: type["PIMOResult"], file_path: str | Path) -> "PIMOResult": - """Load from a `.pt` file. - - Args: - file_path: path to the `.pt` file where to load the PIMO result. - """ - validate_path(file_path, accepted_extensions=(".pt",)) - payload = torch.load(file_path) - if not isinstance(payload, dict): - msg = f"Invalid content in file {file_path}. Must be a dictionary." - raise TypeError(msg) - # for compatibility with the original code - if "shared_fpr_metric" in payload: - del payload["shared_fpr_metric"] - try: - return cls.from_dict(payload) - except TypeError as ex: - msg = f"Invalid content in file {file_path}. Cause: {ex}." - raise TypeError(msg) from ex - @dataclass class AUPIMOResult: @@ -448,97 +388,8 @@ def from_pimoresult( paths=paths, ) - def to_dict(self) -> dict[str, Tensor | str | float | int]: - """Return a dictionary with the result object's attributes.""" - dic = { - "fpr_lower_bound": self.fpr_lower_bound, - "fpr_upper_bound": self.fpr_upper_bound, - "num_threshs": self.num_threshs, - "thresh_lower_bound": self.thresh_lower_bound, - "thresh_upper_bound": self.thresh_upper_bound, - "aupimos": self.aupimos, - } - if self.paths is not None: - dic["paths"] = self.paths - return dic - - @classmethod - def from_dict(cls: type["AUPIMOResult"], dic: dict[str, Tensor | str | float | int | list[str]]) -> "AUPIMOResult": - """Return a result object from a dictionary.""" - try: - return cls(**dic) # type: ignore[arg-type] - - except TypeError as ex: - msg = f"Invalid input dictionary for {cls.__name__} object. Cause: {ex}." - raise TypeError(msg) from ex - - def save(self, file_path: str | Path) -> None: - """Save to a `.json` file. - - Args: - file_path: path to the `.json` file where to save the AUPIMO result. - If the file already exists, a numerical suffix is added to the filename. - """ - validate_path(file_path, should_exist=False, accepted_extensions=(".json",)) - file_path = duplicate_filename(file_path) - file_path = Path(file_path) - payload = self.to_dict() - aupimos: Tensor = payload["aupimos"] - payload["aupimos"] = aupimos.numpy().tolist() - with file_path.open("w") as f: - json.dump(payload, f, indent=4) - - @classmethod - def load(cls: type["AUPIMOResult"], file_path: str | Path) -> "AUPIMOResult": - """Load from a `.json` file. - - Args: - file_path: path to the `.json` file where to load the AUPIMO result. - """ - validate_path(file_path, accepted_extensions=(".json",)) - file_path = Path(file_path) - with file_path.open("r") as f: - payload = json.load(f) - if not isinstance(payload, dict): - file_path = str(file_path) - msg = f"Invalid payload in file {file_path}. Must be a dictionary." - raise TypeError(msg) - payload["aupimos"] = torch.tensor(payload["aupimos"], dtype=torch.float64) - # for compatibility with the original code - if "shared_fpr_metric" in payload: - del payload["shared_fpr_metric"] - try: - return cls.from_dict(payload) - except (TypeError, ValueError) as ex: - msg = f"Invalid payload in file {file_path}. Cause: {ex}." - raise TypeError(msg) from ex - - def stats( - self, - outliers_policy: str | StatsOutliersPolicy = StatsOutliersPolicy.NONE.value, - repeated_policy: str | StatsRepeatedPolicy = StatsRepeatedPolicy.AVOID.value, - repeated_replacement_atol: float = 1e-2, - ) -> list[dict[str, str | int | float]]: - """Return the AUPIMO statistics. - - See `anomalib.metrics.per_image.utils.per_image_scores_stats` for details. - - Returns: - list[dict[str, str | int | float]]: AUPIMO statistics - """ - return utils.per_image_scores_stats( - self.aupimos, - self.image_classes, - only_class=1, - outliers_policy=outliers_policy, - repeated_policy=repeated_policy, - repeated_replacement_atol=repeated_replacement_atol, - ) - # =========================================== FUNCTIONAL =========================================== - - def pimo_curves( anomaly_maps: Tensor, masks: Tensor, @@ -857,23 +708,6 @@ def normalizing_factor(fpr_bounds: tuple[float, float]) -> float: """ return pimo_numpy.aupimo_normalizing_factor(fpr_bounds) - @staticmethod - def random_model_score(fpr_bounds: tuple[float, float]) -> float: - """AUPIMO of a theoretical random model. - - "Random model" means that there is no discrimination between normal and anomalous pixels/patches/images. - It corresponds to assuming the functions T = F. - - For the FPR bounds (1e-5, 1e-4), the random model AUPIMO is ~4e-5. - - Args: - fpr_bounds: lower and upper bounds of the FPR integration range. - - Returns: - float: the AUPIMO score. - """ - return pimo_numpy.aupimo_random_model_score(fpr_bounds) - def __repr__(self) -> str: """Show the metric name and its integration bounds.""" lower, upper = self.fpr_bounds diff --git a/src/anomalib/metrics/per_image/pimo_numpy.py b/src/anomalib/metrics/per_image/pimo_numpy.py index d979cf1a53..2bd5c0cb89 100644 --- a/src/anomalib/metrics/per_image/pimo_numpy.py +++ b/src/anomalib/metrics/per_image/pimo_numpy.py @@ -11,7 +11,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from enum import Enum import numpy as np from numpy import ndarray @@ -21,14 +20,6 @@ logger = logging.getLogger(__name__) -# =========================================== CONSTANTS =========================================== - - -class PIMOSharedFPRMetric(Enum): - """Shared FPR metric (x-axis of the PIMO curve).""" - - MEAN_PERIMAGE_FPR: str = "mean-per-image-fpr" - # =========================================== AUX =========================================== @@ -381,23 +372,3 @@ def aupimo_normalizing_factor(fpr_bounds: tuple[float, float]) -> float: fpr_lower_bound, fpr_upper_bound = fpr_bounds # the log's base must be the same as the one used in the integration! return float(np.log(fpr_upper_bound / fpr_lower_bound)) - - -def aupimo_random_model_score(fpr_bounds: tuple[float, float]) -> float: - """AUPIMO of a theoretical random model. - - "Random model" means that there is no discrimination between normal and anomalous pixels/patches/images. - It corresponds to assuming the functions T = F. - - For the FPR bounds (1e-5, 1e-4), the random model AUPIMO is ~4e-5. - - Args: - fpr_bounds: lower and upper bounds of the FPR integration range. - - Returns: - float: the AUPIMO score. - """ - _validate.is_rate_range(fpr_bounds) - fpr_lower_bound, fpr_upper_bound = fpr_bounds - integral_value = fpr_upper_bound - fpr_lower_bound - return float(integral_value / aupimo_normalizing_factor(fpr_bounds)) diff --git a/src/anomalib/metrics/per_image/utils.py b/src/anomalib/metrics/per_image/utils.py index 1d47674e2c..927ae4989f 100644 --- a/src/anomalib/metrics/per_image/utils.py +++ b/src/anomalib/metrics/per_image/utils.py @@ -245,8 +245,8 @@ def per_image_scores_stats( Outliers are handled according to `outliers_policy`: - None | "none": do not include outliers. - - "hi": only include high outliers. - - "lo": only include low outliers. + - "high": only include high outliers. + - "low": only include low outliers. - "both": include both high and low outliers. ** IMAGE INDEX ** diff --git a/src/anomalib/metrics/per_image/utils_numpy.py b/src/anomalib/metrics/per_image/utils_numpy.py index 7eb5413346..619e7c1677 100644 --- a/src/anomalib/metrics/per_image/utils_numpy.py +++ b/src/anomalib/metrics/per_image/utils_numpy.py @@ -29,14 +29,14 @@ class StatsOutliersPolicy(Enum): from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. None | "none": do not include outliers. - "hi": only include high outliers. - "lo": only include low outliers. + "high": only include high outliers. + "low": only include low outliers. "both": include both high and low outliers. """ NONE: str = "none" - HI: str = "hi" - LO: str = "lo" + HIGH: str = "high" + LOW: str = "low" BOTH: str = "both" @@ -158,8 +158,8 @@ def per_image_scores_stats( Outliers are handled according to `outliers_policy`: - None | "none": do not include outliers. - - "hi": only include high outliers. - - "lo": only include low outliers. + - "high": only include high outliers. + - "low": only include low outliers. - "both": include both high and low outliers. ** IMAGE INDEX ** @@ -244,13 +244,13 @@ def per_image_scores_stats( outliers_lo = outliers[outliers < boxplot_stats["med"]] outliers_hi = outliers[outliers > boxplot_stats["med"]] - if outliers_policy in {StatsOutliersPolicy.HI, StatsOutliersPolicy.BOTH}: + if outliers_policy in {StatsOutliersPolicy.HIGH, StatsOutliersPolicy.BOTH}: boxplot_stats = { **boxplot_stats, **{f"outhi_{idx:06}": value for idx, value in enumerate(outliers_hi)}, } - if outliers_policy in {StatsOutliersPolicy.LO, StatsOutliersPolicy.BOTH}: + if outliers_policy in {StatsOutliersPolicy.LOW, StatsOutliersPolicy.BOTH}: boxplot_stats = { **boxplot_stats, **{f"outlo_{idx:06}": value for idx, value in enumerate(outliers_lo)}, diff --git a/tests/unit/metrics/per_image/test_binclf_curve.py b/tests/unit/metrics/per_image/test_binclf_curve.py index 62112bc257..cd7c0cdd98 100644 --- a/tests/unit/metrics/per_image/test_binclf_curve.py +++ b/tests/unit/metrics/per_image/test_binclf_curve.py @@ -11,11 +11,9 @@ import numpy as np import pytest -import torch from numpy import ndarray -from torch import Tensor -from anomalib.metrics.per_image import binclf_curve, binclf_curve_numpy +from anomalib.metrics.per_image import binclf_curve_numpy def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: @@ -233,24 +231,6 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: argvalues=per_image_binclf_curves_numpy_argvalues, ) - # the test with the torch interface are the same we just convert ndarray to Tensor - if metafunc.function is test_per_image_binclf_curve_torch: - metafunc.parametrize( - argnames=( - "anomaly_maps", - "masks", - "threshs_choice", - "threshs_given", - "num_threshs", - "expected_threshs", - "expected_binclf_curves", - ), - argvalues=[ - tuple(torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in argvals) - for argvals in per_image_binclf_curves_numpy_argvalues - ], - ) - if metafunc.function is test_per_image_binclf_curve_numpy_validations: metafunc.parametrize( argnames=("args", "exception"), @@ -305,14 +285,6 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ], ) - if metafunc.function is test_rate_metrics_torch: - metafunc.parametrize( - argnames=("binclf_curves", "expected_fprs", "expected_tprs"), - argvalues=[ - (torch.from_numpy(binclf_curves), torch.from_numpy(expected_fprs), torch.from_numpy(expected_tprs)), - ], - ) - # ================================================================================================== # LOW-LEVEL FUNCTIONS (PYTHON) @@ -429,48 +401,3 @@ def test_rate_metrics_numpy(binclf_curves: ndarray, expected_fprs: ndarray, expe assert np.allclose(tprs, expected_tprs, equal_nan=True) assert np.allclose(fprs, expected_fprs, equal_nan=True) - - -# ================================================================================================== -# API FUNCTIONS (TORCH) - - -def test_per_image_binclf_curve_torch( - anomaly_maps: Tensor, - masks: Tensor, - threshs_choice: str, - threshs_given: Tensor | None, - num_threshs: int | None, - expected_threshs: Tensor, - expected_binclf_curves: Tensor, -) -> None: - """Test if `per_image_binclf_curve()` returns the expected values.""" - computed_threshs, computed_binclf_curves = binclf_curve.per_image_binclf_curve( - anomaly_maps, - masks, - threshs_choice=threshs_choice, - threshs_given=threshs_given, - num_threshs=num_threshs, - ) - - # threshs - assert computed_threshs.shape == expected_threshs.shape - assert computed_threshs.dtype == computed_threshs.dtype - assert (computed_threshs == expected_threshs).all() - - # binclf_curves - assert computed_binclf_curves.shape == expected_binclf_curves.shape - assert computed_binclf_curves.dtype == expected_binclf_curves.dtype - assert (computed_binclf_curves == expected_binclf_curves).all() - - -def test_rate_metrics_torch(binclf_curves: Tensor, expected_fprs: Tensor, expected_tprs: Tensor) -> None: - """Test if rate metrics are computed correctly.""" - tprs = binclf_curve.per_image_tpr(binclf_curves) - fprs = binclf_curve.per_image_fpr(binclf_curves) - - assert tprs.shape == expected_tprs.shape - assert fprs.shape == expected_fprs.shape - - assert torch.allclose(tprs, expected_tprs, equal_nan=True) - assert torch.allclose(fprs, expected_fprs, equal_nan=True) diff --git a/tests/unit/metrics/per_image/test_pimo.py b/tests/unit/metrics/per_image/test_pimo.py index 9e818346fb..d0092fb616 100644 --- a/tests/unit/metrics/per_image/test_pimo.py +++ b/tests/unit/metrics/per_image/test_pimo.py @@ -7,9 +7,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import tempfile -from pathlib import Path - import numpy as np import pytest import torch @@ -19,8 +16,6 @@ from anomalib.metrics.per_image import pimo, pimo_numpy from anomalib.metrics.per_image.pimo import AUPIMOResult, PIMOResult -from .test_utils import assert_statsdict_stuff - def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: """Generate tests for all functions in this module. @@ -186,12 +181,6 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: ], ) - if metafunc.function is test_pimoresult_object or metafunc.function is test_aupimoresult_object: - anomaly_maps = torch.from_numpy(anomaly_maps) - masks = torch.from_numpy(masks) - metafunc.parametrize(argnames=("anomaly_maps", "masks"), argvalues=[(anomaly_maps, masks)]) - metafunc.parametrize(argnames=("paths",), argvalues=[(None,), (["/path/to/a", "/path/to/b", "/path/to/c"],)]) - def _do_test_pimo_outputs( threshs: ndarray | Tensor, @@ -488,104 +477,3 @@ def test_aupimo_edge( force=False, **fpr_bounds, ) - - -def test_pimoresult_object( - anomaly_maps: Tensor, - masks: Tensor, - paths: list[str] | None, -) -> None: - """Test if `PIMOResult` can be converted to other formats and back.""" - optional_kwargs = {} - if paths is not None: - optional_kwargs["paths"] = paths - - pimoresult = pimo.pimo_curves( - anomaly_maps, - masks, - num_threshs=7, - **optional_kwargs, - ) - - _ = pimoresult.num_threshs - _ = pimoresult.num_images - _ = pimoresult.image_classes - - # object -> dict -> object - dic = pimoresult.to_dict() - assert isinstance(dic, dict) - pimoresult_from_dict = PIMOResult.from_dict(dic) - assert isinstance(pimoresult_from_dict, PIMOResult) - # values should be the same - assert torch.allclose(pimoresult_from_dict.threshs, pimoresult.threshs) - assert torch.allclose(pimoresult_from_dict.shared_fpr, pimoresult.shared_fpr) - assert torch.allclose(pimoresult_from_dict.per_image_tprs, pimoresult.per_image_tprs, equal_nan=True) - - # object -> file -> object - with tempfile.TemporaryDirectory() as tmpdir: - file_path = Path(tmpdir) / "pimo.pt" - pimoresult.save(str(file_path)) - assert file_path.exists() - pimoresult_from_load = PIMOResult.load(str(file_path)) - assert isinstance(pimoresult_from_load, PIMOResult) - # values should be the same - assert torch.allclose(pimoresult_from_load.threshs, pimoresult.threshs) - assert torch.allclose(pimoresult_from_load.shared_fpr, pimoresult.shared_fpr) - assert torch.allclose(pimoresult_from_load.per_image_tprs, pimoresult.per_image_tprs, equal_nan=True) - - -def test_aupimoresult_object( - anomaly_maps: Tensor, - masks: Tensor, - paths: list[str] | None, -) -> None: - """Test if `AUPIMOResult` can be converted to other formats and back.""" - optional_kwargs = {} - if paths is not None: - optional_kwargs["paths"] = paths - - _, aupimoresult = pimo.aupimo_scores( - anomaly_maps, - masks, - num_threshs=7, - fpr_bounds=(1e-5, 1e-4), - force=True, - **optional_kwargs, - ) - - # call properties - _ = aupimoresult.num_images - _ = aupimoresult.image_classes - _ = aupimoresult.fpr_bounds - _ = aupimoresult.thresh_bounds - - # object -> dict -> object - dic = aupimoresult.to_dict() - assert isinstance(dic, dict) - aupimoresult_from_dict = AUPIMOResult.from_dict(dic) - assert isinstance(aupimoresult_from_dict, AUPIMOResult) - # values should be the same - assert aupimoresult_from_dict.fpr_bounds == aupimoresult.fpr_bounds - assert aupimoresult_from_dict.num_threshs == aupimoresult.num_threshs - assert aupimoresult_from_dict.thresh_bounds == aupimoresult.thresh_bounds - assert torch.allclose(aupimoresult_from_dict.aupimos, aupimoresult.aupimos, equal_nan=True) - - # object -> file -> object - with tempfile.TemporaryDirectory() as tmpdir: - file_path = Path(tmpdir) / "aupimo.json" - aupimoresult.save(str(file_path)) - assert file_path.exists() - aupimoresult_from_load = AUPIMOResult.load(str(file_path)) - assert isinstance(aupimoresult_from_load, AUPIMOResult) - # values should be the same - assert aupimoresult_from_load.fpr_bounds == aupimoresult.fpr_bounds - assert aupimoresult_from_load.num_threshs == aupimoresult.num_threshs - assert aupimoresult_from_load.thresh_bounds == aupimoresult.thresh_bounds - assert torch.allclose(aupimoresult_from_load.aupimos, aupimoresult.aupimos, equal_nan=True) - - # statistics - stats = aupimoresult.stats() - assert len(stats) == 6 - - for statdic in stats: - assert_statsdict_stuff(statdic, 2) diff --git a/tests/unit/metrics/per_image/test_utils.py b/tests/unit/metrics/per_image/test_utils.py index 0c894bf56f..c8d5943921 100644 --- a/tests/unit/metrics/per_image/test_utils.py +++ b/tests/unit/metrics/per_image/test_utils.py @@ -118,9 +118,9 @@ def test_per_image_scores_stats() -> None: stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.BOTH) assert len(stats) == 6 - stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.LO) + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.LOW) assert len(stats) == 6 - stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.HI) + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.HIGH) assert len(stats) == 6 stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.NONE) assert len(stats) == 6