diff --git a/pyproject.toml b/pyproject.toml index 9709c1a112..c6c79f7b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,8 @@ test = [ "coverage[toml]", "tox", ] -full = ["anomalib[core,openvino,loggers,notebooks]"] +extra = ["numba>=0.58.1"] +full = ["anomalib[core,openvino,loggers,notebooks,extra]"] dev = ["anomalib[full,docs,test]"] [project.scripts] diff --git a/src/anomalib/data/utils/path.py b/src/anomalib/data/utils/path.py index 9c3f56273b..80c73a0f68 100644 --- a/src/anomalib/data/utils/path.py +++ b/src/anomalib/data/utils/path.py @@ -142,13 +142,20 @@ def contains_non_printable_characters(path: str | Path) -> bool: return not printable_pattern.match(str(path)) -def validate_path(path: str | Path, base_dir: str | Path | None = None, should_exist: bool = True) -> Path: +def validate_path( + path: str | Path, + base_dir: str | Path | None = None, + should_exist: bool = True, + accepted_extensions: tuple[str, ...] | None = None, +) -> Path: """Validate the path. Args: path (str | Path): Path to validate. base_dir (str | Path): Base directory to restrict file access. should_exist (bool): If True, do not raise an exception if the path does not exist. + accepted_extensions (tuple[str, ...] | None): Accepted extensions for the path. An exception is raised if the + path does not have one of the accepted extensions. If None, no check is performed. Defaults to None. Returns: Path: Validated path. @@ -213,6 +220,11 @@ def validate_path(path: str | Path, base_dir: str | Path | None = None, should_e msg = f"Read or execute permissions denied for the path: {path}" raise PermissionError(msg) + # Check if the path has one of the accepted extensions + if accepted_extensions is not None and path.suffix not in accepted_extensions: + msg = f"Path extension is not accepted. Accepted extensions: {accepted_extensions}. Path: {path}" + raise ValueError(msg) + return path diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 4c3eafa811..a47680f676 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -11,6 +11,7 @@ import torchmetrics from omegaconf import DictConfig, ListConfig +from . 
import per_image from .anomaly_score_distribution import AnomalyScoreDistribution from .aupr import AUPR from .aupro import AUPRO @@ -19,6 +20,7 @@ from .f1_max import F1Max from .f1_score import F1Score from .min_max import MinMax +from .per_image import AUPIMO, PIMO, aupimo_scores, pimo_curves from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO from .threshold import F1AdaptiveThreshold, ManualThreshold @@ -35,6 +37,11 @@ "ManualThreshold", "MinMax", "PRO", + "per_image", + "pimo_curves", + "aupimo_scores", + "PIMO", + "AUPIMO", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/metrics/per_image/__init__.py b/src/anomalib/metrics/per_image/__init__.py new file mode 100644 index 0000000000..2e34372ff7 --- /dev/null +++ b/src/anomalib/metrics/per_image/__init__.py @@ -0,0 +1,44 @@ +"""Per-Image Metrics.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .binclf_curve import per_image_binclf_curve, per_image_fpr, per_image_tpr +from .binclf_curve_numpy import BinclfAlgorithm, BinclfThreshsChoice +from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult, aupimo_scores, pimo_curves +from .utils import ( + compare_models_pairwise_ttest_rel, + compare_models_pairwise_wilcoxon, + format_pairwise_tests_results, + per_image_scores_stats, +) +from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy + +__all__ = [ + # constants + "BinclfAlgorithm", + "BinclfThreshsChoice", + "StatsOutliersPolicy", + "StatsRepeatedPolicy", + # result classes + "PIMOResult", + "AUPIMOResult", + # functional interfaces + "per_image_binclf_curve", + "per_image_fpr", + "per_image_tpr", + "pimo_curves", + "aupimo_scores", + # torchmetrics interfaces + "PIMO", + "AUPIMO", + # utils + "compare_models_pairwise_ttest_rel", + "compare_models_pairwise_wilcoxon", + "format_pairwise_tests_results", + "per_image_scores_stats", +] diff --git a/src/anomalib/metrics/per_image/_binclf_curve_numba.py b/src/anomalib/metrics/per_image/_binclf_curve_numba.py new file mode 100644 index 0000000000..3151a2faba --- /dev/null +++ b/src/anomalib/metrics/per_image/_binclf_curve_numba.py @@ -0,0 +1,115 @@ +"""Binary classification matrix curve (NUMBA implementation of low level functions). + +Details: `.binclf_curve`. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numba +import numpy as np +from numpy import ndarray + + +@numba.jit(nopython=True) +def binclf_one_curve_numba(scores: ndarray, gts: ndarray, threshs: ndarray) -> ndarray: + """One binary classification matrix at each threshold (NUMBA implementation). + + This does the same as `_binclf_one_curve_python` but with numba using just-in-time compilation. + + Note: VALIDATION IS NOT DONE HERE! Make sure to validate the arguments before calling this function. + + Args: + scores (ndarray): Anomaly scores (D,). + gts (ndarray): Binary (bool) ground truth of shape (D,). + threshs (ndarray): Sequence of thresholds in ascending order (K,). + + Returns: + ndarray: Binary classification matrix curve (K, 2, 2) + + Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. 
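# --- Annotation (not part of the patch) -----------------------------------------------------
# With the package layout introduced above, the per-image API is importable either from
# `anomalib.metrics` or from the `anomalib.metrics.per_image` subpackage, e.g.:
#
#     from anomalib.metrics import AUPIMO, PIMO, aupimo_scores, pimo_curves
#     from anomalib.metrics.per_image import per_image_binclf_curve, per_image_fpr, per_image_tpr
#
# The numba-accelerated code path is optional: it is pulled in by the new `extra` (and
# therefore `full`) dependency group added to pyproject.toml, e.g. `pip install anomalib[extra]`.
# ---------------------------------------------------------------------------------------------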
+ """ + num_th = len(threshs) + + # POSITIVES + scores_pos = scores[gts] + # the sorting is very important for the algorithm to work and the speedup + scores_pos = np.sort(scores_pos) + # start counting with lowest th, so everything is predicted as positive (this variable is updated in the loop) + num_pos = current_count_tp = len(scores_pos) + + tps = np.empty((num_th,), dtype=np.int64) + + # NEGATIVES + # same thing but for the negative samples + scores_neg = scores[~gts] + scores_neg = np.sort(scores_neg) + num_neg = current_count_fp = len(scores_neg) + + fps = np.empty((num_th,), dtype=np.int64) + + # it will progressively drop the scores that are below the current th + for thidx, th in enumerate(threshs): + num_drop = 0 + num_scores = len(scores_pos) + while num_drop < num_scores and scores_pos[num_drop] < th: # ! scores_pos ! + num_drop += 1 + # --- + scores_pos = scores_pos[num_drop:] + current_count_tp -= num_drop + tps[thidx] = current_count_tp + + # same with the negatives + num_drop = 0 + num_scores = len(scores_neg) + while num_drop < num_scores and scores_neg[num_drop] < th: # ! scores_neg ! + num_drop += 1 + # --- + scores_neg = scores_neg[num_drop:] + current_count_fp -= num_drop + fps[thidx] = current_count_fp + + fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps + tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps + + # sequence of dimensions is (threshs, true class, predicted class) (see docstring) + return np.stack( + ( + np.stack((tns, fps), axis=-1), + np.stack((fns, tps), axis=-1), + ), + axis=-1, + ).transpose(0, 2, 1) + + +@numba.jit(nopython=True, parallel=True) +def binclf_multiple_curves_numba(scores_batch: ndarray, gts_batch: ndarray, threshs: ndarray) -> ndarray: + """Multiple binary classification matrix at each threshold (NUMBA implementation). + + This does the same as `_binclf_multiple_curves_python` but with numba, + using parallelization and just-in-time compilation. + + Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function. + + Args: + scores_batch (ndarray): Anomaly scores (N, D,). + gts_batch (ndarray): Binary (bool) ground truth of shape (N, D,). + threshs (ndarray): Sequence of thresholds in ascending order (K,). + + Returns: + ndarray: Binary classification matrix curves (N, K, 2, 2) + + Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. + """ + num_imgs = scores_batch.shape[0] + num_th = len(threshs) + ret = np.empty((num_imgs, num_th, 2, 2), dtype=np.int64) + for imgidx in numba.prange(num_imgs): + scoremap = scores_batch[imgidx] + mask = gts_batch[imgidx] + ret[imgidx] = binclf_one_curve_numba(scoremap, mask, threshs) + return ret diff --git a/src/anomalib/metrics/per_image/_validate.py b/src/anomalib/metrics/per_image/_validate.py new file mode 100644 index 0000000000..72f107e21e --- /dev/null +++ b/src/anomalib/metrics/per_image/_validate.py @@ -0,0 +1,362 @@ +"""Utils for validating arguments and results. + +`torch` is imported in the functions that use it, so this module can be used in numpy-standalone mode. + +TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase. 
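# --- Annotation (not part of the patch): reference sketch -----------------------------------
# The numba kernels above are an optimized version of a simple computation. A tiny,
# unoptimized NumPy reference is sketched below to make the (K, 2, 2) layout concrete:
# predicted-positive is `score >= thresh`, and the per-threshold matrix is [[TN, FP], [FN, TP]]
# (i.e. [..., 1, 1] = TP, [..., 0, 1] = FP, [..., 1, 0] = FN, [..., 0, 0] = TN).
# Function and variable names here are illustrative only.
import numpy as np


def binclf_one_curve_reference(scores: np.ndarray, gts: np.ndarray, threshs: np.ndarray) -> np.ndarray:
    """Unoptimized reference: scores (D,), boolean gts (D,), ascending threshs (K,) -> (K, 2, 2)."""
    preds = scores[None, :] >= threshs[:, None]  # (K, D) predicted-positive at each threshold
    tp = (preds & gts[None, :]).sum(axis=1)
    fp = (preds & ~gts[None, :]).sum(axis=1)
    fn = (~preds & gts[None, :]).sum(axis=1)
    tn = (~preds & ~gts[None, :]).sum(axis=1)
    # rows are the true class, columns are the predicted class
    return np.stack([np.stack([tn, fp], axis=-1), np.stack([fn, tp], axis=-1)], axis=1).astype(np.int64)
# ---------------------------------------------------------------------------------------------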
+https://github.com/openvinotoolkit/anomalib/issues/2093 +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import numpy as np +from numpy import ndarray + + +def is_tensor(tensor: Any, argname: str | None = None) -> None: # noqa: ANN401 + """Validate that `tensor` is a `torch.Tensor`.""" + from torch import Tensor + + argname = f"'{argname}'" if argname is not None else "argument" + + if not isinstance(tensor, Tensor): + msg = f"Expected {argname} to be a tensor, but got {type(tensor)}" + raise TypeError(msg) + + +def is_num_threshs_gte2(num_threshs: int) -> None: + """Validate the number of thresholds is a positive integer >= 2.""" + if not isinstance(num_threshs, int): + msg = f"Expected the number of thresholds to be an integer, but got {type(num_threshs)}" + raise TypeError(msg) + + if num_threshs < 2: + msg = f"Expected the number of thresholds to be larger than 1, but got {num_threshs}" + raise ValueError(msg) + + +def is_same_shape(*args) -> None: + """Works both for tensors and ndarrays.""" + assert len(args) > 0 + shapes = sorted({tuple(arg.shape) for arg in args}) + if len(shapes) > 1: + msg = f"Expected arguments to have the same shape, but got {shapes}" + raise ValueError(msg) + + +def is_rate(rate: float | int, zero_ok: bool, one_ok: bool) -> None: + """Validates a rate parameter. + + Args: + rate (float | int): The rate to be validated. + zero_ok (bool): Flag indicating if rate can be 0. + one_ok (bool): Flag indicating if rate can be 1. + """ + if not isinstance(rate, float | int): + msg = f"Expected rate to be a float or int, but got {type(rate)}." + raise TypeError(msg) + + if rate < 0.0 or rate > 1.0: + msg = f"Expected rate to be in [0, 1], but got {rate}." + raise ValueError(msg) + + if not zero_ok and rate == 0.0: + msg = "Rate cannot be 0." + raise ValueError(msg) + + if not one_ok and rate == 1.0: + msg = "Rate cannot be 1." + raise ValueError(msg) + + +def is_rate_range(bounds: tuple[float, float]) -> None: + """Validates the range of rates within the bounds. + + Args: + bounds (tuple[float, float]): The lower and upper bounds of the rates. + """ + if not isinstance(bounds, tuple): + msg = f"Expected the bounds to be a tuple, but got {type(bounds)}" + raise TypeError(msg) + + if len(bounds) != 2: + msg = f"Expected the bounds to be a tuple of length 2, but got {len(bounds)}" + raise ValueError(msg) + + lower, upper = bounds + is_rate(lower, zero_ok=False, one_ok=False) + is_rate(upper, zero_ok=False, one_ok=True) + + if lower >= upper: + msg = f"Expected the upper bound to be larger than the lower bound, but got {upper=} <= {lower=}" + raise ValueError(msg) + + +def is_threshs(threshs: ndarray) -> None: + """Validate that the thresholds are valid and monotonically increasing.""" + if not isinstance(threshs, ndarray): + msg = f"Expected thresholds to be an ndarray, but got {type(threshs)}" + raise TypeError(msg) + + if threshs.ndim != 1: + msg = f"Expected thresholds to be 1D, but got {threshs.ndim}" + raise ValueError(msg) + + if threshs.dtype.kind != "f": + msg = f"Expected thresholds to be of float type, but got ndarray with dtype {threshs.dtype}" + raise TypeError(msg) + + # make sure they are strictly increasing + if not np.all(np.diff(threshs) > 0): + msg = "Expected thresholds to be strictly increasing, but it is not." 
+ raise ValueError(msg) + + +def is_thresh_bounds(thresh_bounds: tuple[float, float]) -> None: + if not isinstance(thresh_bounds, tuple): + msg = f"Expected threshold bounds to be a tuple, but got {type(thresh_bounds)}." + raise TypeError(msg) + + if len(thresh_bounds) != 2: + msg = f"Expected threshold bounds to be a tuple of length 2, but got {len(thresh_bounds)}." + raise ValueError(msg) + + lower, upper = thresh_bounds + + if not isinstance(lower, float): + msg = f"Expected lower threshold bound to be a float, but got {type(lower)}." + raise TypeError(msg) + + if not isinstance(upper, float): + msg = f"Expected upper threshold bound to be a float, but got {type(upper)}." + raise TypeError(msg) + + if upper <= lower: + msg = f"Expected the upper bound to be greater than the lower bound, but got {upper} <= {lower}." + raise ValueError(msg) + + +def is_anomaly_maps(anomaly_maps: ndarray) -> None: + if not isinstance(anomaly_maps, ndarray): + msg = f"Expected anomaly maps to be an ndarray, but got {type(anomaly_maps)}" + raise TypeError(msg) + + if anomaly_maps.ndim != 3: + msg = f"Expected anomaly maps have 3 dimensions (N, H, W), but got {anomaly_maps.ndim} dimensions" + raise ValueError(msg) + + if anomaly_maps.dtype.kind != "f": + msg = ( + "Expected anomaly maps to be an floating ndarray with anomaly scores," + f" but got ndarray with dtype {anomaly_maps.dtype}" + ) + raise TypeError(msg) + + +def is_masks(masks: ndarray) -> None: + if not isinstance(masks, ndarray): + msg = f"Expected masks to be an ndarray, but got {type(masks)}" + raise TypeError(msg) + + if masks.ndim != 3: + msg = f"Expected masks have 3 dimensions (N, H, W), but got {masks.ndim} dimensions" + raise ValueError(msg) + + if masks.dtype.kind == "b": + pass + + elif masks.dtype.kind in ("i", "u"): + masks_unique_vals = np.unique(masks) + if np.any((masks_unique_vals != 0) & (masks_unique_vals != 1)): + msg = ( + "Expected masks to be a *binary* ndarray with ground truth labels, " + f"but got ndarray with unique values {sorted(masks_unique_vals)}" + ) + raise ValueError(msg) + + else: + msg = ( + "Expected masks to be an integer or boolean ndarray with ground truth labels, " + f"but got ndarray with dtype {masks.dtype}" + ) + raise TypeError(msg) + + +def is_binclf_curves(binclf_curves: ndarray, valid_threshs: ndarray | None) -> None: + if not isinstance(binclf_curves, ndarray): + msg = f"Expected binclf curves to be an ndarray, but got {type(binclf_curves)}" + raise TypeError(msg) + + if binclf_curves.ndim != 4: + msg = f"Expected binclf curves to be 4D, but got {binclf_curves.ndim}D" + raise ValueError(msg) + + if binclf_curves.shape[-2:] != (2, 2): + msg = f"Expected binclf curves to have shape (..., 2, 2), but got {binclf_curves.shape}" + raise ValueError(msg) + + if binclf_curves.dtype != np.int64: + msg = f"Expected binclf curves to have dtype int64, but got {binclf_curves.dtype}." + raise TypeError(msg) + + if (binclf_curves < 0).any(): + msg = "Expected binclf curves to have non-negative values, but got negative values." + raise ValueError(msg) + + neg = binclf_curves[:, :, 0, :].sum(axis=-1) # (num_images, num_threshs) + + if (neg != neg[:, :1]).any(): + msg = "Expected binclf curves to have the same number of negatives per image for every thresh." + raise ValueError(msg) + + pos = binclf_curves[:, :, 1, :].sum(axis=-1) # (num_images, num_threshs) + + if (pos != pos[:, :1]).any(): + msg = "Expected binclf curves to have the same number of positives per image for every thresh." 
+ raise ValueError(msg) + + if valid_threshs is None: + return + + if binclf_curves.shape[1] != valid_threshs.shape[0]: + msg = ( + "Expected the binclf curves to have as many confusion matrices as the thresholds sequence, " + f"but got {binclf_curves.shape[1]} and {valid_threshs.shape[0]}" + ) + raise RuntimeError(msg) + + +def is_images_classes(images_classes: ndarray) -> None: + if not isinstance(images_classes, ndarray): + msg = f"Expected image classes to be an ndarray, but got {type(images_classes)}." + raise TypeError(msg) + + if images_classes.ndim != 1: + msg = f"Expected image classes to be 1D, but got {images_classes.ndim}D." + raise ValueError(msg) + + if images_classes.dtype.kind == "b": + pass + elif images_classes.dtype.kind in ("i", "u"): + unique_vals = np.unique(images_classes) + if np.any((unique_vals != 0) & (unique_vals != 1)): + msg = ( + "Expected image classes to be a *binary* ndarray with ground truth labels, " + f"but got ndarray with unique values {sorted(unique_vals)}" + ) + raise ValueError(msg) + else: + msg = ( + "Expected image classes to be an integer or boolean ndarray with ground truth labels, " + f"but got ndarray with dtype {images_classes.dtype}" + ) + raise TypeError(msg) + + +def is_rates(rates: ndarray, nan_allowed: bool) -> None: + if not isinstance(rates, ndarray): + msg = f"Expected rates to be an ndarray, but got {type(rates)}." + raise TypeError(msg) + + if rates.ndim != 1: + msg = f"Expected rates to be 1D, but got {rates.ndim}D." + raise ValueError(msg) + + if rates.dtype.kind != "f": + msg = f"Expected rates to have dtype of float type, but got {rates.dtype}." + raise ValueError(msg) + + isnan_mask = np.isnan(rates) + if nan_allowed: + # if they are all nan, then there is nothing to validate + if isnan_mask.all(): + return + valid_values = rates[~isnan_mask] + elif isnan_mask.any(): + msg = "Expected rates to not contain NaN values, but got NaN values." + raise ValueError(msg) + else: + valid_values = rates + + if (valid_values < 0).any(): + msg = "Expected rates to have values in the interval [0, 1], but got values < 0." + raise ValueError(msg) + + if (valid_values > 1).any(): + msg = "Expected rates to have values in the interval [0, 1], but got values > 1." + raise ValueError(msg) + + +def is_rate_curve(rate_curve: ndarray, nan_allowed: bool, decreasing: bool) -> None: + is_rates(rate_curve, nan_allowed=nan_allowed) + + diffs = np.diff(rate_curve) + diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs + + if decreasing and (diffs_valid > 0).any(): + msg = "Expected rate curve to be monotonically decreasing, but got non-monotonically decreasing values." + raise ValueError(msg) + + if not decreasing and (diffs_valid < 0).any(): + msg = "Expected rate curve to be monotonically increasing, but got non-monotonically increasing values." + raise ValueError(msg) + + +def is_per_image_rate_curves(rate_curves: ndarray, nan_allowed: bool, decreasing: bool | None) -> None: + if not isinstance(rate_curves, ndarray): + msg = f"Expected per-image rate curves to be an ndarray, but got {type(rate_curves)}." + raise TypeError(msg) + + if rate_curves.ndim != 2: + msg = f"Expected per-image rate curves to be 2D, but got {rate_curves.ndim}D." + raise ValueError(msg) + + if rate_curves.dtype.kind != "f": + msg = f"Expected per-image rate curves to have dtype of float type, but got {rate_curves.dtype}." 
+ raise ValueError(msg) + + isnan_mask = np.isnan(rate_curves) + if nan_allowed: + # if they are all nan, then there is nothing to validate + if isnan_mask.all(): + return + valid_values = rate_curves[~isnan_mask] + elif isnan_mask.any(): + msg = "Expected per-image rate curves to not contain NaN values, but got NaN values." + raise ValueError(msg) + else: + valid_values = rate_curves + + if (valid_values < 0).any(): + msg = "Expected per-image rate curves to have values in the interval [0, 1], but got values < 0." + raise ValueError(msg) + + if (valid_values > 1).any(): + msg = "Expected per-image rate curves to have values in the interval [0, 1], but got values > 1." + raise ValueError(msg) + + if decreasing is None: + return + + diffs = np.diff(rate_curves, axis=1) + diffs_valid = diffs[~np.isnan(diffs)] if nan_allowed else diffs + + if decreasing and (diffs_valid > 0).any(): + msg = ( + "Expected per-image rate curves to be monotonically decreasing, " + "but got non-monotonically decreasing values." + ) + raise ValueError(msg) + + if not decreasing and (diffs_valid < 0).any(): + msg = ( + "Expected per-image rate curves to be monotonically increasing, " + "but got non-monotonically increasing values." + ) + raise ValueError(msg) diff --git a/src/anomalib/metrics/per_image/binclf_curve.py b/src/anomalib/metrics/per_image/binclf_curve.py new file mode 100644 index 0000000000..1a1b614a68 --- /dev/null +++ b/src/anomalib/metrics/per_image/binclf_curve.py @@ -0,0 +1,173 @@ +"""Binary classification curve (torch interface). + +This module implements torch interfaces to access the numpy code in `binclf_curve_numpy.py`. + +Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. + +Tensors are build with `torch.from_numpy` and so the returned tensors will share the same memory as the numpy arrays. + +Validations will preferably happen in ndarray so the numpy code can be reused without torch, +so often times the Tensor arguments will be converted to ndarray and then validated. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import Tensor + +from . import _validate, binclf_curve_numpy +from .binclf_curve_numpy import BinclfAlgorithm, BinclfThreshsChoice + +# =========================================== ARGS VALIDATION =========================================== + + +def _validate_is_threshs(threshs: Tensor) -> None: + _validate.is_tensor(threshs, argname="threshs") + _validate.is_threshs(threshs.numpy()) + + +def _validate_is_binclf_curves(binclf_curves: Tensor, valid_threshs: Tensor | None = None) -> None: + _validate.is_tensor(binclf_curves, argname="binclf_curves") + if valid_threshs is not None: + _validate_is_threshs(valid_threshs) + _validate.is_binclf_curves( + binclf_curves.detach().cpu().numpy(), + valid_threshs=valid_threshs.numpy() if valid_threshs is not None else None, + ) + + +# =========================================== FUNCTIONAL =========================================== + + +def per_image_binclf_curve( + anomaly_maps: Tensor, + masks: Tensor, + algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA, + threshs_choice: BinclfThreshsChoice | str = BinclfThreshsChoice.MINMAX_LINSPACE, + threshs_given: Tensor | None = None, + num_threshs: int | None = None, +) -> tuple[Tensor, Tensor]: + """Compute the binary classification matrix of each image in the batch for multiple thresholds (shared). 
+ + Note: tensors are converted to numpy arrays and then converted back to tensors (same device as `anomaly_maps`). + + Args: + anomaly_maps (Tensor): Anomaly score maps of shape (N, H, W [, D, ...]) + masks (Tensor): Binary ground truth masks of shape (N, H, W [, D, ...]) + algorithm (str, optional): Algorithm to use. Defaults to ALGORITHM_NUMBA. + threshs_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. + return_result_object (bool, optional): Whether to return a `PerImageBinClfCurveResult` object. Defaults to True. + + *** `threshs_choice`-dependent arguments *** + + THRESH_SEQUENCE_GIVEN + threshs_given (Tensor, optional): Sequence of thresholds to use. + + THRESH_SEQUENCE_MINMAX_LINSPACE + num_threshs (int, optional): Number of thresholds between the min and max of the anomaly maps. + + Returns: + tuple[Tensor, Tensor]: + [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. + + [1] Binary classification matrices of shape (N, K, 2, 2) + + N: number of images/instances + K: number of thresholds + + The last two dimensions are the confusion matrix (ground truth, predictions) + So for each thresh it gives: + - `tp`: `[... , 1, 1]` + - `fp`: `[... , 0, 1]` + - `fn`: `[... , 1, 0]` + - `tn`: `[... , 0, 0]` + + `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: + - `tp` stands for `true positive` + - `fp` stands for `false positive` + - `fn` stands for `false negative` + - `tn` stands for `true negative` + + The numbers in each confusion matrix are the counts of pixels in the image (not the ratios). + + Thresholds are shared across all images, so all confusion matrices, for instance, + at position [:, 0, :, :] are relative to the 1st threshold in `threshs`. + + Thresholds are sorted in ascending order. + """ + _validate.is_tensor(anomaly_maps, argname="anomaly_maps") + anomaly_maps_array = anomaly_maps.detach().cpu().numpy() + + _validate.is_tensor(masks, argname="masks") + masks_array = masks.detach().cpu().numpy() + + if threshs_given is not None: + _validate.is_tensor(threshs_given, argname="threshs_given") + threshs_given_array = threshs_given.detach().cpu().numpy() + else: + threshs_given_array = None + + threshs_array, binclf_curves_array = binclf_curve_numpy.per_image_binclf_curve( + anomaly_maps=anomaly_maps_array, + masks=masks_array, + algorithm=algorithm, + threshs_choice=threshs_choice, + threshs_given=threshs_given_array, + num_threshs=num_threshs, + ) + threshs = torch.from_numpy(threshs_array).to(anomaly_maps.device) + binclf_curves = torch.from_numpy(binclf_curves_array).to(anomaly_maps.device).long() + + return threshs, binclf_curves + + +# =========================================== RATE METRICS =========================================== + + +def per_image_tpr(binclf_curves: Tensor) -> Tensor: + """Compute the true positive rates (TPR) for each image in the batch. + + Args: + binclf_curves (Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + + Returns: + Tensor: True positive rates (TPR) of shape (N, K) + + N: number of images/instances + K: number of thresholds + + The last dimension is the TPR for each threshold. + + Thresholds are sorted in ascending order, so TPR is in descending order. 
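# --- Annotation (not part of the patch): hypothetical usage of the torch interface ----------
# A minimal sketch of calling the functions defined above, assuming the import path added by
# this patch (`anomalib.metrics.per_image`). The pure-python algorithm is requested so the
# optional numba dependency is not needed.
import torch

from anomalib.metrics.per_image import per_image_binclf_curve, per_image_fpr, per_image_tpr

anomaly_maps = torch.rand(4, 64, 64)  # (N, H, W) anomaly scores
masks = (torch.rand(4, 64, 64) > 0.7).int()  # (N, H, W) binary ground truth

threshs, binclf_curves = per_image_binclf_curve(
    anomaly_maps,
    masks,
    algorithm="python",  # "numba" (the default) is faster but needs the `extra` dependency group
    threshs_choice="minmax-linspace",
    num_threshs=100,
)
tprs = per_image_tpr(binclf_curves)  # (N, K); rows of images without positive pixels are NaN
fprs = per_image_fpr(binclf_curves)  # (N, K)
# ---------------------------------------------------------------------------------------------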
+ """ + _validate_is_binclf_curves(binclf_curves) + binclf_curves_array = binclf_curves.detach().cpu().numpy() + tprs_array = binclf_curve_numpy.per_image_tpr(binclf_curves_array) + return torch.from_numpy(tprs_array).to(binclf_curves.device) + + +def per_image_fpr(binclf_curves: Tensor) -> Tensor: + """Compute the false positive rates (FPR) for each image in the batch. + + Args: + binclf_curves (Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + + Returns: + Tensor: False positive rates (FPR) of shape (N, K) + + N: number of images/instances + K: number of thresholds + + The last dimension is the FPR for each threshold. + + Thresholds are sorted in ascending order, so FPR is in descending order. + """ + _validate_is_binclf_curves(binclf_curves) + binclf_curves_array = binclf_curves.detach().cpu().numpy() + fprs_array = binclf_curve_numpy.per_image_fpr(binclf_curves_array) + return torch.from_numpy(fprs_array).to(binclf_curves.device) diff --git a/src/anomalib/metrics/per_image/binclf_curve_numpy.py b/src/anomalib/metrics/per_image/binclf_curve_numpy.py new file mode 100644 index 0000000000..621932baeb --- /dev/null +++ b/src/anomalib/metrics/per_image/binclf_curve_numpy.py @@ -0,0 +1,430 @@ +"""Binary classification curve (numpy-only implementation). + +A binary classification (binclf) matrix (TP, FP, FN, TN) is evaluated at multiple thresholds. + +The thresholds are shared by all instances/images, but their binclf are computed independently for each instance/image. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import logging +from enum import Enum +from functools import partial + +import numpy as np +from numpy import ndarray + +try: + import numba # noqa: F401 +except ImportError: + HAS_NUMBA = False +else: + HAS_NUMBA = True + + +if HAS_NUMBA: + from . import _binclf_curve_numba + +from . 
import _validate + +logger = logging.getLogger(__name__) + +# =========================================== CONSTANTS =========================================== + + +class BinclfAlgorithm(Enum): + """Algorithm to use (relates to the low-level implementation).""" + + PYTHON: str = "python" + NUMBA: str = "numba" + + +class BinclfThreshsChoice(Enum): + """Sequence of thresholds to use.""" + + GIVEN: str = "given" + MINMAX_LINSPACE: str = "minmax-linspace" + MEAN_FPR_OPTIMIZED: str = "mean-fpr-optimized" + + +# =========================================== ARGS VALIDATION =========================================== + + +def _validate_is_scores_batch(scores_batch: ndarray) -> None: + """scores_batch (ndarray): floating (N, D).""" + if not isinstance(scores_batch, ndarray): + msg = f"Expected `scores_batch` to be an ndarray, but got {type(scores_batch)}" + raise TypeError(msg) + + if scores_batch.dtype.kind != "f": + msg = ( + "Expected `scores_batch` to be an floating ndarray with anomaly scores_batch," + f" but got ndarray with dtype {scores_batch.dtype}" + ) + raise TypeError(msg) + + if scores_batch.ndim != 2: + msg = f"Expected `scores_batch` to be 2D, but got {scores_batch.ndim}" + raise ValueError(msg) + + +def _validate_is_gts_batch(gts_batch: ndarray) -> None: + """gts_batch (ndarray): boolean (N, D).""" + if not isinstance(gts_batch, ndarray): + msg = f"Expected `gts_batch` to be an ndarray, but got {type(gts_batch)}" + raise TypeError(msg) + + if gts_batch.dtype.kind != "b": + msg = ( + "Expected `gts_batch` to be an boolean ndarray with anomaly scores_batch," + f" but got ndarray with dtype {gts_batch.dtype}" + ) + raise TypeError(msg) + + if gts_batch.ndim != 2: + msg = f"Expected `gts_batch` to be 2D, but got {gts_batch.ndim}" + raise ValueError(msg) + + +# =========================================== PYTHON VERSION =========================================== + + +def _binclf_one_curve_python(scores: ndarray, gts: ndarray, threshs: ndarray) -> ndarray: + """One binary classification matrix at each threshold (PYTHON implementation). + + In the case where the thresholds are given (i.e. not considering all possible thresholds based on the scores), + this weird-looking function is faster than the two options in `torchmetrics` on the CPU: + - `_binary_precision_recall_curve_update_vectorized` + - `_binary_precision_recall_curve_update_loop` + + (both in module `torchmetrics.functional.classification.precision_recall_curve` in `torchmetrics==1.1.0`). + + Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function. + + Args: + scores (ndarray): Anomaly scores (D,). + gts (ndarray): Binary (bool) ground truth of shape (D,). + threshs (ndarray): Sequence of thresholds in ascending order (K,). + + Returns: + ndarray: Binary classification matrix curve (K, 2, 2) + + Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`. 
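# --- Annotation (not part of the patch) ------------------------------------------------------
# In the pure-python loop below, the number of (sorted) scores dropped at each threshold is
# simply the count of scores strictly below that threshold. On a sorted 1-D array this is
# equivalent to `np.searchsorted(..., side="left")`, which may help when reading the loop:
import numpy as np

scores_sorted = np.sort(np.array([0.2, 0.8, 0.5, 0.9]))
thresh = 0.6
num_below = int(np.searchsorted(scores_sorted, thresh, side="left"))  # 2 scores (0.2 and 0.5) are < 0.6
assert num_below == int((scores_sorted < thresh).sum())
# ---------------------------------------------------------------------------------------------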
+ """ + num_th = len(threshs) + + # POSITIVES + scores_positives = scores[gts] + # the sorting is very important for the algorithm to work and the speedup + scores_positives = np.sort(scores_positives) + # variable updated in the loop; start counting with lowest thresh ==> everything is predicted as positive + num_pos = current_count_tp = scores_positives.size + tps = np.empty((num_th,), dtype=np.int64) + + # NEGATIVES + # same thing but for the negative samples + scores_negatives = scores[~gts] + scores_negatives = np.sort(scores_negatives) + num_neg = current_count_fp = scores_negatives.size + fps = np.empty((num_th,), dtype=np.int64) + + def score_less_than_thresh(score: float, thresh: float) -> bool: + return score < thresh + + # it will progressively drop the scores that are below the current thresh + for thresh_idx, thresh in enumerate(threshs): + # UPDATE POSITIVES + # < becasue it is the same as ~(>=) + num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_positives)) + scores_positives = scores_positives[num_drop:] + current_count_tp -= num_drop + tps[thresh_idx] = current_count_tp + + # UPDATE NEGATIVES + # same with the negatives + num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_negatives)) + scores_negatives = scores_negatives[num_drop:] + current_count_fp -= num_drop + fps[thresh_idx] = current_count_fp + + # deduce the rest of the matrix counts + fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps + tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps + + # sequence of dimensions is (threshs, true class, predicted class) (see docstring) + return np.stack( + [ + np.stack([tns, fps], axis=-1), + np.stack([fns, tps], axis=-1), + ], + axis=-1, + ).transpose(0, 2, 1) + + +_binclf_multiple_curves_python = np.vectorize(_binclf_one_curve_python, signature="(n),(n),(k)->(k,2,2)") +_binclf_multiple_curves_python.__doc__ = """ +Multiple binary classification matrix at each threshold (PYTHON implementation). +vectorized version of `_binclf_one_curve_python` (see above) +""" + +# =========================================== INTERFACE =========================================== + + +def binclf_multiple_curves( + scores_batch: ndarray, + gts_batch: ndarray, + threshs: ndarray, + algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, +) -> ndarray: + """Multiple binary classification matrix (per-instance scope) at each threshold (shared). + + This is a wrapper around `_binclf_multiple_curves_python` and `_binclf_multiple_curves_numba`. + Validation of the arguments is done here (not in the actual implementation functions). + + Note: predicted as positive condition is `score >= thresh`. + + Args: + scores_batch (ndarray): Anomaly scores (N, D,). + gts_batch (ndarray): Binary (bool) ground truth of shape (N, D,). + threshs (ndarray): Sequence of thresholds in ascending order (K,). + algorithm (str, optional): Algorithm to use. Defaults to ALGORITHM_NUMBA. + + Returns: + ndarray: Binary classification matrix curves (N, K, 2, 2) + + The last two dimensions are the confusion matrix (ground truth, predictions) + So for each thresh it gives: + - `tp`: `[... , 1, 1]` + - `fp`: `[... , 0, 1]` + - `fn`: `[... , 1, 0]` + - `tn`: `[... 
, 0, 0]` + + `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: + - `tp` stands for `true positive` + - `fp` stands for `false positive` + - `fn` stands for `false negative` + - `tn` stands for `true negative` + + The numbers in each confusion matrix are the counts (not the ratios). + + Counts are relative to each instance (i.e. from 0 to D, e.g. the total is the number of pixels in the image). + + Thresholds are shared across all instances, so all confusion matrices, for instance, + at position [:, 0, :, :] are relative to the 1st threshold in `threshs`. + + Thresholds are sorted in ascending order. + """ + algorithm = BinclfAlgorithm(algorithm) + _validate_is_scores_batch(scores_batch) + _validate_is_gts_batch(gts_batch) + _validate.is_same_shape(scores_batch, gts_batch) + _validate.is_threshs(threshs) + + if algorithm == BinclfAlgorithm.NUMBA: + if HAS_NUMBA: + return _binclf_curve_numba.binclf_multiple_curves_numba(scores_batch, gts_batch, threshs) + + logger.warning( + f"Algorithm '{BinclfAlgorithm.NUMBA.value}' was selected, but Numba is not installed. " + f"Falling back to '{BinclfAlgorithm.PYTHON.value}' implementation.", + "Notice that the performance will be slower. Consider installing Numba for faster computation.", + ) + + return _binclf_multiple_curves_python(scores_batch, gts_batch, threshs) + + +# ========================================= PER-IMAGE BINCLF CURVE ========================================= + + +def _get_threshs_minmax_linspace(anomaly_maps: ndarray, num_threshs: int) -> ndarray: + """Get thresholds linearly spaced between the min and max of the anomaly maps.""" + _validate.is_num_threshs_gte2(num_threshs) + # this operation can be a bit expensive + thresh_low, thresh_high = thresh_bounds = (anomaly_maps.min().item(), anomaly_maps.max().item()) + try: + _validate.is_thresh_bounds(thresh_bounds) + except ValueError as ex: + msg = f"Invalid threshold bounds computed from the given anomaly maps. Cause: {ex}" + raise ValueError(msg) from ex + return np.linspace(thresh_low, thresh_high, num_threshs, dtype=anomaly_maps.dtype) + + +def per_image_binclf_curve( + anomaly_maps: ndarray, + masks: ndarray, + algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, + threshs_choice: BinclfThreshsChoice | str = BinclfThreshsChoice.MINMAX_LINSPACE.value, + threshs_given: ndarray | None = None, + num_threshs: int | None = None, +) -> tuple[ndarray, ndarray]: + """Compute the binary classification matrix of each image in the batch for multiple thresholds (shared). + + Args: + anomaly_maps (ndarray): Anomaly score maps of shape (N, H, W) + masks (ndarray): Binary ground truth masks of shape (N, H, W) + algorithm (str, optional): Algorithm to use. Defaults to ALGORITHM_NUMBA. + threshs_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. + # + # `threshs_choice`-dependent arguments + # + # THRESH_SEQUENCE_GIVEN + threshs_given (ndarray, optional): Sequence of thresholds to use. + # + # THRESH_SEQUENCE_MINMAX_LINSPACE + num_threshs (int, optional): Number of thresholds between the min and max of the anomaly maps. + + Returns: + tuple[ndarray, ndarray]: + [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. + + [1] Binary classification matrices of shape (N, K, 2, 2) + + N: number of images/instances + K: number of thresholds + + The last two dimensions are the confusion matrix (ground truth, predictions) + So for each thresh it gives: + - `tp`: `[... 
, 1, 1]` + - `fp`: `[... , 0, 1]` + - `fn`: `[... , 1, 0]` + - `tn`: `[... , 0, 0]` + + `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: + - `tp` stands for `true positive` + - `fp` stands for `false positive` + - `fn` stands for `false negative` + - `tn` stands for `true negative` + + The numbers in each confusion matrix are the counts of pixels in the image (not the ratios). + + Thresholds are shared across all images, so all confusion matrices, for instance, + at position [:, 0, :, :] are relative to the 1st threshold in `threshs`. + + Thresholds are sorted in ascending order. + """ + BinclfAlgorithm(algorithm) + threshs_choice = BinclfThreshsChoice(threshs_choice) + _validate.is_anomaly_maps(anomaly_maps) + _validate.is_masks(masks) + _validate.is_same_shape(anomaly_maps, masks) + + threshs: ndarray + + if threshs_choice == BinclfThreshsChoice.GIVEN: + assert threshs_given is not None + _validate.is_threshs(threshs_given) + if num_threshs is not None: + logger.warning( + "Argument `num_threshs` was given, " + f"but it is ignored because `threshs_choice` is '{threshs_choice.value}'.", + ) + threshs = threshs_given.astype(anomaly_maps.dtype) + + elif threshs_choice == BinclfThreshsChoice.MINMAX_LINSPACE: + assert num_threshs is not None + if threshs_given is not None: + logger.warning( + "Argument `threshs_given` was given, " + f"but it is ignored because `threshs_choice` is '{threshs_choice.value}'.", + ) + # `num_threshs` is validated in the function below + threshs = _get_threshs_minmax_linspace(anomaly_maps, num_threshs) + + elif threshs_choice == BinclfThreshsChoice.MEAN_FPR_OPTIMIZED: + raise NotImplementedError(f"TODO implement {threshs_choice.value}") # noqa: EM102 + + else: + msg = ( + f"Expected `threshs_choice` to be from {list(BinclfThreshsChoice.__members__)}," + f" but got '{threshs_choice.value}'" + ) + raise NotImplementedError(msg) + + # keep the batch dimension and flatten the rest + scores_batch = anomaly_maps.reshape(anomaly_maps.shape[0], -1) + gts_batch = masks.reshape(masks.shape[0], -1).astype(bool) # make sure it is boolean + + binclf_curves = binclf_multiple_curves(scores_batch, gts_batch, threshs, algorithm=algorithm) + + num_images = anomaly_maps.shape[0] + + try: + _validate.is_binclf_curves(binclf_curves, valid_threshs=threshs) + + # these two validations cannot be done in `_validate.binclf_curves` because it does not have access to the + # original shapes of `anomaly_maps` + if binclf_curves.shape[0] != num_images: + msg = ( + "Expected `binclf_curves` to have the same number of images as `anomaly_maps`, " + f"but got {binclf_curves.shape[0]} and {anomaly_maps.shape[0]}" + ) + raise RuntimeError(msg) + + except (TypeError, ValueError) as ex: + msg = f"Invalid `binclf_curves` was computed. Cause: {ex}" + raise RuntimeError(msg) from ex + + return threshs, binclf_curves + + +# =========================================== RATE METRICS =========================================== + + +def per_image_tpr(binclf_curves: ndarray) -> ndarray: + """True positive rates (TPR) for image for each thresh. + + TPR = TP / P = TP / (TP + FN) + + TP: true positives + FM: false negatives + P: positives (TP + FN) + + Args: + binclf_curves (ndarray): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + + Returns: + ndarray: shape (N, K), dtype float64 + N: number of images + K: number of thresholds + + Thresholds are sorted in ascending order, so TPR is in descending order. 
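# --- Annotation (not part of the patch): hypothetical usage of the numpy interface ----------
# A sketch of the functional layer above with an explicit threshold sequence ("given") instead
# of the min-max linspace. The import path assumes the module location added by this patch.
import numpy as np

from anomalib.metrics.per_image import binclf_curve_numpy

rng = np.random.default_rng(0)
anomaly_maps = rng.random((2, 32, 32)).astype(np.float32)
masks = np.zeros((2, 32, 32), dtype=np.int64)
masks[1, 8:16, 8:16] = 1  # image 0 is normal, image 1 is anomalous

threshs, binclf_curves = binclf_curve_numpy.per_image_binclf_curve(
    anomaly_maps,
    masks,
    algorithm="python",
    threshs_choice="given",
    threshs_given=np.linspace(0.0, 1.0, 11, dtype=np.float32),
)
tprs = binclf_curve_numpy.per_image_tpr(binclf_curves)  # row 0 is NaN: a normal image has no positives
fprs = binclf_curve_numpy.per_image_fpr(binclf_curves)
# ---------------------------------------------------------------------------------------------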
+ """ + # shape: (num images, num threshs) + tps = binclf_curves[..., 1, 1] + pos = binclf_curves[..., 1, :].sum(axis=2) # 2 was the 3 originally + + # tprs will be nan if pos == 0 (normal image), which is expected + return tps.astype(np.float64) / pos.astype(np.float64) + + +def per_image_fpr(binclf_curves: ndarray) -> ndarray: + """False positive rates (TPR) for image for each thresh. + + FPR = FP / N = FP / (FP + TN) + + FP: false positives + TN: true negatives + N: negatives (FP + TN) + + Args: + binclf_curves (ndarray): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + + Returns: + ndarray: shape (N, K), dtype float64 + N: number of images + K: number of thresholds + + Thresholds are sorted in ascending order, so FPR is in descending order. + """ + # shape: (num images, num threshs) + fps = binclf_curves[..., 0, 1] + neg = binclf_curves[..., 0, :].sum(axis=2) # 2 was the 3 originally + + # it can be `nan` if an anomalous image is fully covered by the mask + return fps.astype(np.float64) / neg.astype(np.float64) diff --git a/src/anomalib/metrics/per_image/pimo.py b/src/anomalib/metrics/per_image/pimo.py new file mode 100644 index 0000000000..f29056e973 --- /dev/null +++ b/src/anomalib/metrics/per_image/pimo.py @@ -0,0 +1,956 @@ +"""Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO). + +# PIMO + +PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds. +The anomaly score thresholds are indexed by a (shared) valued of False Positive Rate (FPR) measure on the normal images. + +Each *anomalous* image has its own curve such that the X-axis is shared by all of them. + +At a given threshold: + X-axis: Shared FPR (may vary) + 1. Log of the Average of per-image FPR on normal images. + SEE NOTE BELOW. + Y-axis: per-image TP Rate (TPR), or "Overlap" between the ground truth and the predicted masks. + +*** Note about other shared FPR alternatives *** +The shared FPR metric can be made harder by using the cross-image max (or high-percentile) FPRs instead of the mean. +Rationale: this will further punish models that have exceptional FPs in normal images. +So far there is only one shared FPR metric implemented but others will be added in the future. + +# AUPIMO + +`AUPIMO` is the area under each `PIMO` curve with bounded integration range in terms of shared FPR. + +# Disclaimer + +This module implements torch interfaces to access the numpy code in `pimo_numpy.py`. +Tensors are converted to numpy arrays and then passed and validated in the numpy code. +The results are converted back to tensors and eventually wrapped in an dataclass object. + +Validations will preferably happen in ndarray so the numpy code can be reused without torch, +so often times the Tensor arguments will be converted to ndarray and then validated. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path + +import torch +from torch import Tensor +from torchmetrics import Metric + +from anomalib.data.utils.image import duplicate_filename +from anomalib.data.utils.path import validate_path + +from . 
import _validate, pimo_numpy, utils
+from .binclf_curve_numpy import BinclfAlgorithm
+from .utils import StatsOutliersPolicy, StatsRepeatedPolicy
+
+logger = logging.getLogger(__name__)
+
+# =========================================== AUX ===========================================
+
+
+def _images_classes_from_masks(masks: Tensor) -> Tensor:
+    masks = torch.concat(masks, dim=0)
+    device = masks.device
+    image_classes = pimo_numpy._images_classes_from_masks(masks.numpy())  # noqa: SLF001
+    return torch.from_numpy(image_classes).to(device)
+
+
+# =========================================== ARGS VALIDATION ===========================================
+
+
+def _validate_is_anomaly_maps(anomaly_maps: Tensor) -> None:
+    _validate.is_tensor(anomaly_maps, argname="anomaly_maps")
+    _validate.is_anomaly_maps(anomaly_maps.numpy())
+
+
+def _validate_is_masks(masks: Tensor) -> None:
+    _validate.is_tensor(masks, argname="masks")
+    _validate.is_masks(masks.numpy())
+
+
+def _validate_is_threshs(threshs: Tensor) -> None:
+    _validate.is_tensor(threshs, argname="threshs")
+    _validate.is_threshs(threshs.numpy())
+
+
+def _validate_is_shared_fpr(shared_fpr: Tensor, nan_allowed: bool = False, decreasing: bool = True) -> None:
+    _validate.is_tensor(shared_fpr, argname="shared_fpr")
+    _validate.is_rate_curve(shared_fpr.numpy(), nan_allowed=nan_allowed, decreasing=decreasing)
+
+
+def _validate_is_image_classes(image_classes: Tensor) -> None:
+    _validate.is_tensor(image_classes, argname="image_classes")
+    _validate.is_images_classes(image_classes.numpy())
+
+
+def _validate_is_per_image_tprs(per_image_tprs: Tensor, image_classes: Tensor) -> None:
+    _validate_is_image_classes(image_classes)
+    _validate.is_tensor(per_image_tprs, argname="per_image_tprs")
+
+    # general validations
+    _validate.is_per_image_rate_curves(
+        per_image_tprs.numpy(),
+        nan_allowed=True,  # normal images have NaN TPRs
+        decreasing=None,  # not checked here
+    )
+
+    # specific to anomalous images
+    _validate.is_per_image_rate_curves(
+        per_image_tprs[image_classes == 1].numpy(),
+        nan_allowed=False,
+        decreasing=True,
+    )
+
+    # specific to normal images
+    normal_images_tprs = per_image_tprs[image_classes == 0]
+    if not normal_images_tprs.isnan().all():
+        msg = "Expected all normal images to have NaN TPRs, but some have non-NaN values."
+        raise ValueError(msg)
+
+
+def _validate_is_aupimos(aupimos: Tensor) -> None:
+    _validate.is_tensor(aupimos, argname="aupimos")
+    _validate.is_rates(aupimos.numpy(), nan_allowed=True)
+
+
+def _validate_is_source_images_paths(paths: Sequence[str], expected_num_paths: int | None) -> None:
+    if not isinstance(paths, list):
+        msg = f"Expected paths to be a list, but got {type(paths)}."
+        raise TypeError(msg)
+
+    for idx, path in enumerate(paths):
+        try:
+            msg = f"Invalid path at index {idx}: {path}"
+            validate_path(
+                path,
+                # not necessary to exist because the metric can be computed
+                # directly from the anomaly maps and masks, without the images
+                should_exist=False,
+            )
+
+        except TypeError as ex:
+            raise TypeError(msg) from ex
+
+        except ValueError as ex:
+            raise ValueError(msg) from ex
+
+        if not isinstance(path, str):
+            # this will eventually be serialized to a file, so we don't want pathlib objects; keep it simple
+            msg = f"Expected path to be a string, but got {type(path)}."
+            raise TypeError(msg)
+
+    if expected_num_paths is None:
+        return
+
+    if len(paths) != expected_num_paths:
+        msg = f"Invalid `paths` argument. Expected {expected_num_paths} paths, but got {len(paths)} instead."
+ raise ValueError(msg) + + +# =========================================== RESULT OBJECT =========================================== + + +@dataclass +class PIMOResult: + """Per-Image Overlap (PIMO, pronounced pee-mo) curve. + + This interface gathers the PIMO curve data and metadata and provides several utility methods. + + Notation: + - N: number of images + - K: number of thresholds + - FPR: False Positive Rate + - TPR: True Positive Rate + + Attributes: + threshs (Tensor): sequence of K (monotonically increasing) thresholds used to compute the PIMO curve + shared_fpr (Tensor): K values of the shared FPR metric at the corresponding thresholds + per_image_tprs (Tensor): for each of the N images, the K values of in-image TPR at the corresponding thresholds + paths (list[str]) (optional): [metadata] paths to the source images to which the PIMO curves correspond + """ + + # data + threshs: Tensor = field(repr=False) # shape => (K,) + shared_fpr: Tensor = field(repr=False) # shape => (K,) + per_image_tprs: Tensor = field(repr=False) # shape => (N, K) + + # optional metadata + paths: list[str] | None = field(repr=False, default=None) + + @property + def num_threshs(self) -> int: + """Number of thresholds.""" + return self.threshs.shape[0] + + @property + def num_images(self) -> int: + """Number of images.""" + return self.per_image_tprs.shape[0] + + @property + def image_classes(self) -> Tensor: + """Image classes (0: normal, 1: anomalous). + + Deduced from the per-image TPRs. + If any TPR value is not NaN, the image is considered anomalous. + """ + return (~torch.isnan(self.per_image_tprs)).any(dim=1).to(torch.int32) + + def __post_init__(self) -> None: + """Validate the inputs for the result object are consistent.""" + try: + _validate_is_threshs(self.threshs) + _validate_is_shared_fpr(self.shared_fpr, nan_allowed=False) + _validate_is_per_image_tprs(self.per_image_tprs, self.image_classes) + + if self.paths is not None: + _validate_is_source_images_paths(self.paths, expected_num_paths=self.per_image_tprs.shape[0]) + + except (TypeError, ValueError) as ex: + msg = f"Invalid inputs for {self.__class__.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + if self.threshs.shape != self.shared_fpr.shape: + msg = ( + f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: " + f"{self.threshs.shape=} != {self.shared_fpr.shape=}." + ) + raise TypeError(msg) + + if self.threshs.shape[0] != self.per_image_tprs.shape[1]: + msg = ( + f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: " + f"{self.threshs.shape[0]=} != {self.per_image_tprs.shape[1]=}." + ) + raise TypeError(msg) + + def thresh_at(self, fpr_level: float) -> tuple[int, float, float]: + """Return the threshold at the given shared FPR. + + See `anomalib.metrics.per_image.pimo_numpy.thresh_at_shared_fpr_level` for details. 
+ + Args: + fpr_level (float): shared FPR level + + Returns: + tuple[int, float, float]: + [0] index of the threshold + [1] threshold + [2] the actual shared FPR value at the returned threshold + """ + return pimo_numpy.thresh_at_shared_fpr_level( + self.threshs.numpy(), + self.shared_fpr.numpy(), + fpr_level, + ) + + def to_dict(self) -> dict[str, Tensor | str]: + """Return a dictionary with the result object's attributes.""" + dic = { + "threshs": self.threshs, + "shared_fpr": self.shared_fpr, + "per_image_tprs": self.per_image_tprs, + } + if self.paths is not None: + dic["paths"] = self.paths + return dic + + @classmethod + def from_dict(cls: type["PIMOResult"], dic: dict[str, Tensor | str | list[str]]) -> "PIMOResult": + """Return a result object from a dictionary.""" + try: + return cls(**dic) # type: ignore[arg-type] + + except TypeError as ex: + msg = f"Invalid input dictionary for {cls.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + def save(self, file_path: str | Path) -> None: + """Save to a `.pt` file. + + Args: + file_path: path to the `.pt` file where to save the PIMO result. + If the file already exists, a numerical suffix is added to the filename. + """ + validate_path(file_path, should_exist=False, accepted_extensions=(".pt",)) + file_path = duplicate_filename(file_path) + payload = self.to_dict() + torch.save(payload, file_path) + + @classmethod + def load(cls: type["PIMOResult"], file_path: str | Path) -> "PIMOResult": + """Load from a `.pt` file. + + Args: + file_path: path to the `.pt` file where to load the PIMO result. + """ + validate_path(file_path, accepted_extensions=(".pt",)) + payload = torch.load(file_path) + if not isinstance(payload, dict): + msg = f"Invalid content in file {file_path}. Must be a dictionary." + raise TypeError(msg) + # for compatibility with the original code + if "shared_fpr_metric" in payload: + del payload["shared_fpr_metric"] + try: + return cls.from_dict(payload) + except TypeError as ex: + msg = f"Invalid content in file {file_path}. Cause: {ex}." + raise TypeError(msg) from ex + + +@dataclass +class AUPIMOResult: + """Area Under the Per-Image Overlap (AUPIMO, pronounced a-u-pee-mo) curve. + + This interface gathers the AUPIMO data and metadata and provides several utility methods. 
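# --- Annotation (not part of the patch): constructing a PIMOResult by hand ------------------
# A minimal synthetic sketch of the consistency rules enforced in `__post_init__` above:
# strictly increasing thresholds, a monotonically decreasing shared FPR of the same length,
# and per-image TPR curves where normal images are all-NaN. All values are made up.
import torch

from anomalib.metrics.per_image import PIMOResult

pimo_result = PIMOResult(
    threshs=torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),  # (K,) strictly increasing
    shared_fpr=torch.tensor([1.0, 0.6, 0.3, 0.1, 0.01]),  # (K,) decreasing, in [0, 1]
    per_image_tprs=torch.tensor(
        [
            [1.0, 0.9, 0.7, 0.4, 0.1],  # anomalous image: decreasing TPR curve
            [float("nan")] * 5,  # normal image: TPR is undefined
        ],
    ),
)
print(pimo_result.image_classes)  # tensor([1, 0], dtype=torch.int32), deduced from the NaN rows
# `save()`/`load()` above round-trip through a `.pt` file; the suffix is enforced via the new
# `accepted_extensions` argument of `validate_path`.
# ---------------------------------------------------------------------------------------------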
+ + Attributes: + fpr_lower_bound (float): [metadata] LOWER bound of the FPR integration range + fpr_upper_bound (float): [metadata] UPPER bound of the FPR integration range + num_threshs (int): [metadata] number of thresholds used to effectively compute AUPIMO; + should not be confused with the number of thresholds used to compute the PIMO curve + thresh_lower_bound (float): LOWER threshold bound --> corresponds to the UPPER FPR bound + thresh_upper_bound (float): UPPER threshold bound --> corresponds to the LOWER FPR bound + aupimos (Tensor): values of AUPIMO scores (1 per image) + """ + + # metadata + fpr_lower_bound: float + fpr_upper_bound: float + num_threshs: int + + # data + thresh_lower_bound: float = field(repr=False) + thresh_upper_bound: float = field(repr=False) + aupimos: Tensor = field(repr=False) # shape => (N,) + + # optional metadata + paths: list[str] | None = field(repr=False, default=None) + + @property + def num_images(self) -> int: + """Number of images.""" + return self.aupimos.shape[0] + + @property + def num_normal_images(self) -> int: + """Number of normal images.""" + return int((self.image_classes == 0).sum()) + + @property + def num_anomalous_images(self) -> int: + """Number of anomalous images.""" + return int((self.image_classes == 1).sum()) + + @property + def image_classes(self) -> Tensor: + """Image classes (0: normal, 1: anomalous).""" + # if an instance has `nan` aupimo it's because it's a normal image + return self.aupimos.isnan().to(torch.int32) + + @property + def fpr_bounds(self) -> tuple[float, float]: + """Lower and upper bounds of the FPR integration range.""" + return self.fpr_lower_bound, self.fpr_upper_bound + + @property + def thresh_bounds(self) -> tuple[float, float]: + """Lower and upper bounds of the threshold integration range. + + Recall: they correspond to the FPR bounds in reverse order. + I.e.: + fpr_lower_bound --> thresh_upper_bound + fpr_upper_bound --> thresh_lower_bound + """ + return self.thresh_lower_bound, self.thresh_upper_bound + + def __post_init__(self) -> None: + """Validate the inputs for the result object are consistent.""" + try: + _validate.is_rate_range((self.fpr_lower_bound, self.fpr_upper_bound)) + # TODO(jpcbertoldo): warn when it's too low (use parameters from the numpy code) # noqa: TD003 + _validate.is_num_threshs_gte2(self.num_threshs) + _validate_is_aupimos(self.aupimos) + _validate.is_thresh_bounds((self.thresh_lower_bound, self.thresh_upper_bound)) + + if self.paths is not None: + _validate_is_source_images_paths(self.paths, expected_num_paths=self.aupimos.shape[0]) + + except (TypeError, ValueError) as ex: + msg = f"Invalid inputs for {self.__class__.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + @classmethod + def from_pimoresult( + cls: type["AUPIMOResult"], + pimoresult: PIMOResult, + fpr_bounds: tuple[float, float], + num_threshs_auc: int, + aupimos: Tensor, + paths: list[str] | None = None, + ) -> "AUPIMOResult": + """Return an AUPIMO result object from a PIMO result object. + + Args: + pimoresult: PIMO result object + fpr_bounds: lower and upper bounds of the FPR integration range + num_threshs_auc: number of thresholds used to effectively compute AUPIMO; + NOT the number of thresholds used to compute the PIMO curve! + aupimos: AUPIMO scores + paths: paths to the source images to which the AUPIMO scores correspond. + """ + if pimoresult.per_image_tprs.shape[0] != aupimos.shape[0]: + msg = ( + f"Invalid {cls.__name__} object. 
Attributes have inconsistent shapes: " + f"there are {pimoresult.per_image_tprs.shape[0]} PIMO curves but {aupimos.shape[0]} AUPIMO scores." + ) + raise TypeError(msg) + + if not torch.isnan(aupimos[pimoresult.image_classes == 0]).all(): + msg = "Expected all normal images to have NaN AUPIMOs, but some have non-NaN values." + raise TypeError(msg) + + if torch.isnan(aupimos[pimoresult.image_classes == 1]).any(): + msg = "Expected all anomalous images to have valid AUPIMOs (not nan), but some have NaN values." + raise TypeError(msg) + + if pimoresult.paths is not None: + paths = pimoresult.paths + + elif paths is not None: + _validate_is_source_images_paths(paths, expected_num_paths=pimoresult.num_images) + + fpr_lower_bound, fpr_upper_bound = fpr_bounds + # recall: fpr upper/lower bounds are the same as the thresh lower/upper bounds + _, thresh_lower_bound, __ = pimoresult.thresh_at(fpr_upper_bound) + _, thresh_upper_bound, __ = pimoresult.thresh_at(fpr_lower_bound) + # `_` is the threshold's index, `__` is the actual fpr value + return cls( + fpr_lower_bound=fpr_lower_bound, + fpr_upper_bound=fpr_upper_bound, + num_threshs=num_threshs_auc, + thresh_lower_bound=float(thresh_lower_bound), + thresh_upper_bound=float(thresh_upper_bound), + aupimos=aupimos, + paths=paths, + ) + + def to_dict(self) -> dict[str, Tensor | str | float | int]: + """Return a dictionary with the result object's attributes.""" + dic = { + "fpr_lower_bound": self.fpr_lower_bound, + "fpr_upper_bound": self.fpr_upper_bound, + "num_threshs": self.num_threshs, + "thresh_lower_bound": self.thresh_lower_bound, + "thresh_upper_bound": self.thresh_upper_bound, + "aupimos": self.aupimos, + } + if self.paths is not None: + dic["paths"] = self.paths + return dic + + @classmethod + def from_dict(cls: type["AUPIMOResult"], dic: dict[str, Tensor | str | float | int | list[str]]) -> "AUPIMOResult": + """Return a result object from a dictionary.""" + try: + return cls(**dic) # type: ignore[arg-type] + + except TypeError as ex: + msg = f"Invalid input dictionary for {cls.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + def save(self, file_path: str | Path) -> None: + """Save to a `.json` file. + + Args: + file_path: path to the `.json` file where to save the AUPIMO result. + If the file already exists, a numerical suffix is added to the filename. + """ + validate_path(file_path, should_exist=False, accepted_extensions=(".json",)) + file_path = duplicate_filename(file_path) + file_path = Path(file_path) + payload = self.to_dict() + aupimos: Tensor = payload["aupimos"] + payload["aupimos"] = aupimos.numpy().tolist() + with file_path.open("w") as f: + json.dump(payload, f, indent=4) + + @classmethod + def load(cls: type["AUPIMOResult"], file_path: str | Path) -> "AUPIMOResult": + """Load from a `.json` file. + + Args: + file_path: path to the `.json` file where to load the AUPIMO result. + """ + validate_path(file_path, accepted_extensions=(".json",)) + file_path = Path(file_path) + with file_path.open("r") as f: + payload = json.load(f) + if not isinstance(payload, dict): + file_path = str(file_path) + msg = f"Invalid payload in file {file_path}. Must be a dictionary." + raise TypeError(msg) + payload["aupimos"] = torch.tensor(payload["aupimos"], dtype=torch.float64) + # for compatibility with the original code + if "shared_fpr_metric" in payload: + del payload["shared_fpr_metric"] + try: + return cls.from_dict(payload) + except (TypeError, ValueError) as ex: + msg = f"Invalid payload in file {file_path}. Cause: {ex}." 
+ raise TypeError(msg) from ex + + def stats( + self, + outliers_policy: str | StatsOutliersPolicy = StatsOutliersPolicy.NONE.value, + repeated_policy: str | StatsRepeatedPolicy = StatsRepeatedPolicy.AVOID.value, + repeated_replacement_atol: float = 1e-2, + ) -> list[dict[str, str | int | float]]: + """Return the AUPIMO statistics. + + See `anomalib.metrics.per_image.utils.per_image_scores_stats` for details. + + Returns: + list[dict[str, str | int | float]]: AUPIMO statistics + """ + return utils.per_image_scores_stats( + self.aupimos, + self.image_classes, + only_class=1, + outliers_policy=outliers_policy, + repeated_policy=repeated_policy, + repeated_replacement_atol=repeated_replacement_atol, + ) + + +# =========================================== FUNCTIONAL =========================================== + + +def pimo_curves( + anomaly_maps: Tensor, + masks: Tensor, + num_threshs: int, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, + paths: list[str] | None = None, +) -> PIMOResult: + """Compute the Per-IMage Overlap (PIMO, pronounced pee-mo) curves. + + This torch interface is a wrapper around the numpy code. + The tensors are converted to numpy arrays and then passed and validated in the numpy code. + The results are converted back to tensors and wrapped in an dataclass object. + + PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds. + The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on + the normal images. + + Details: `anomalib.metrics.per_image.pimo`. + + Args' notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Args: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + paths: paths to the source images to which the PIMO curves correspond. Default: None. + + Returns: + PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details. 
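+
+    Example:
+        A minimal illustrative sketch with synthetic tensors (shapes and values are
+        invented for demonstration; the default `binclf_algorithm` assumes the numba
+        extra is installed):
+
+        >>> import torch
+        >>> from anomalib.metrics.per_image import pimo_curves
+        >>> anomaly_maps = torch.rand(8, 64, 64)  # (N, H, W) anomaly scores
+        >>> masks = torch.zeros(8, 64, 64, dtype=torch.int32)
+        >>> masks[4:, 16:48, 16:48] = 1  # last 4 images are anomalous
+        >>> result = pimo_curves(anomaly_maps, masks, num_threshs=1_000)
+        >>> curves_shape = result.per_image_tprs.shape  # (N, K) == (8, 1000)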
+ """ + _validate_is_anomaly_maps(anomaly_maps) + anomaly_maps_array = anomaly_maps.detach().cpu().numpy() + + _validate_is_masks(masks) + masks_array = masks.detach().cpu().numpy() + + if paths is not None: + _validate_is_source_images_paths(paths, expected_num_paths=anomaly_maps.shape[0]) + + # other validations are done in the numpy code + threshs_array, shared_fpr_array, per_image_tprs_array, _ = pimo_numpy.pimo_curves( + anomaly_maps_array, + masks_array, + num_threshs, + binclf_algorithm=binclf_algorithm, + ) + # _ is `image_classes` -- not needed here because it's a property in the result object + + # tensors are build with `torch.from_numpy` and so the returned tensors + # will share the same memory as the numpy arrays + device = anomaly_maps.device + # N: number of images, K: number of thresholds + # shape => (K,) + threshs = torch.from_numpy(threshs_array).to(device) + # shape => (K,) + shared_fpr = torch.from_numpy(shared_fpr_array).to(device) + # shape => (N, K) + per_image_tprs = torch.from_numpy(per_image_tprs_array).to(device) + + return PIMOResult( + threshs=threshs, + shared_fpr=shared_fpr, + per_image_tprs=per_image_tprs, + paths=paths, + ) + + +def aupimo_scores( + anomaly_maps: Tensor, + masks: Tensor, + num_threshs: int = 300_000, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, + fpr_bounds: tuple[float, float] = (1e-5, 1e-4), + force: bool = False, + paths: list[str] | None = None, +) -> tuple[PIMOResult, AUPIMOResult]: + """Compute the PIMO curves and their Area Under the Curve (i.e. AUPIMO) scores. + + This torch interface is a wrapper around the numpy code. + The tensors are converted to numpy arrays and then passed and validated in the numpy code. + The results are converted back to tensors and wrapped in an dataclass object. + + Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1]. + It can be thought of as the average TPR of the PIMO curves within the given FPR bounds. + + Details: `anomalib.metrics.per_image.pimo`. + + Args' notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Args: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + fpr_bounds: lower and upper bounds of the FPR integration range + force: whether to force the computation despite bad conditions + paths: paths to the source images to which the AUPIMO scores correspond. + + Returns: + tuple[PIMOResult, AUPIMOResult]: PIMO and AUPIMO results dataclass objects. See `PIMOResult` and `AUPIMOResult`. 
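+
+    Example:
+        A minimal illustrative sketch with synthetic tensors (shapes and values are
+        invented for demonstration; the FPR bounds are widened only so that the toy
+        data has enough resolution, real evaluations typically keep the defaults):
+
+        >>> import torch
+        >>> from anomalib.metrics.per_image import aupimo_scores
+        >>> anomaly_maps = torch.rand(8, 128, 128)  # (N, H, W) anomaly scores
+        >>> masks = torch.zeros(8, 128, 128, dtype=torch.int32)
+        >>> masks[4:, 32:96, 32:96] = 1  # last 4 images are anomalous
+        >>> pimo_result, aupimo_result = aupimo_scores(
+        ...     anomaly_maps,
+        ...     masks,
+        ...     fpr_bounds=(1e-3, 1e-2),
+        ... )
+        >>> per_image_scores = aupimo_result.aupimos  # (N,), NaN for normal images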
+ """ + anomaly_maps_array = anomaly_maps.detach().cpu().numpy() + masks_array = masks.detach().cpu().numpy() + + if paths is not None: + _validate_is_source_images_paths(paths, expected_num_paths=anomaly_maps.shape[0]) + + # other validations are done in the numpy code + + threshs_array, shared_fpr_array, per_image_tprs_array, _, aupimos_array, num_threshs_auc = pimo_numpy.aupimo_scores( + anomaly_maps_array, + masks_array, + num_threshs, + binclf_algorithm=binclf_algorithm, + fpr_bounds=fpr_bounds, + force=force, + ) + + # tensors are build with `torch.from_numpy` and so the returned tensors + # will share the same memory as the numpy arrays + device = anomaly_maps.device + # N: number of images, K: number of thresholds + # shape => (K,) + threshs = torch.from_numpy(threshs_array).to(device) + # shape => (K,) + shared_fpr = torch.from_numpy(shared_fpr_array).to(device) + # shape => (N, K) + per_image_tprs = torch.from_numpy(per_image_tprs_array).to(device) + # shape => (N,) + aupimos = torch.from_numpy(aupimos_array).to(device) + + pimoresult = PIMOResult( + threshs=threshs, + shared_fpr=shared_fpr, + per_image_tprs=per_image_tprs, + paths=paths, + ) + aupimoresult = AUPIMOResult.from_pimoresult( + pimoresult, + fpr_bounds=fpr_bounds, + # not `num_threshs`! + # `num_threshs` is the number of thresholds used to compute the PIMO curve + # this is the number of thresholds used to compute the AUPIMO integral + num_threshs_auc=num_threshs_auc, + aupimos=aupimos, + ) + return pimoresult, aupimoresult + + +# =========================================== TORCHMETRICS =========================================== + + +class PIMO(Metric): + """Per-IMage Overlap (PIMO, pronounced pee-mo) curves. + + This torchmetrics interface is a wrapper around the functional interface, which is a wrapper around the numpy code. + The tensors are converted to numpy arrays and then passed and validated in the numpy code. + The results are converted back to tensors and wrapped in an dataclass object. + + PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds. + The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on + the normal images. + + Details: `anomalib.metrics.per_image.pimo`. + + Notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Attributes: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + + Args: + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + + Returns: + PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details. 
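+
+    Example:
+        A minimal illustrative sketch with synthetic tensors (shapes and values are
+        invented for demonstration):
+
+        >>> import torch
+        >>> from anomalib.metrics.per_image import PIMO
+        >>> metric = PIMO(num_threshs=1_000)
+        >>> anomaly_maps = torch.rand(8, 64, 64)  # one batch of (N, H, W) score maps
+        >>> masks = torch.zeros(8, 64, 64, dtype=torch.int32)
+        >>> masks[4:, 16:48, 16:48] = 1  # last 4 images are anomalous
+        >>> metric.update(anomaly_maps, masks)
+        >>> result = metric.compute()  # PIMOResult over all updated batches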
+ """ + + is_differentiable: bool = False + higher_is_better: bool | None = None + full_state_update: bool = False + + num_threshs: int + binclf_algorithm: str + + anomaly_maps: list[Tensor] + masks: list[Tensor] + + @property + def _is_empty(self) -> bool: + """Return True if the metric has not been updated yet.""" + return len(self.anomaly_maps) == 0 + + @property + def num_images(self) -> int: + """Number of images.""" + return sum([am.shape[0] for am in self.anomaly_maps]) + + @property + def image_classes(self) -> Tensor: + """Image classes (0: normal, 1: anomalous).""" + return _images_classes_from_masks(self.masks) + + def __init__( + self, + num_threshs: int, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, + ) -> None: + """Per-Image Overlap (PIMO) curve. + + Args: + num_threshs: number of thresholds used to compute the PIMO curve (K) + binclf_algorithm: algorithm to compute the binary classification curve (see `binclf_curve_numpy.Algorithm`) + """ + super().__init__() + + logger.warning( + f"Metric `{self.__class__.__name__}` will save all targets and predictions in buffer." + " For large datasets this may lead to large memory footprint.", + ) + + # the options below are, redundantly, validated here to avoid reaching + # an error later in the execution + + _validate.is_num_threshs_gte2(num_threshs) + self.num_threshs = num_threshs + + # validate binclf_algorithm and get string + self.binclf_algorithm = BinclfAlgorithm(binclf_algorithm).value + + self.add_state("anomaly_maps", default=[], dist_reduce_fx="cat") + self.add_state("masks", default=[], dist_reduce_fx="cat") + + def update(self, anomaly_maps: Tensor, masks: Tensor) -> None: + """Update lists of anomaly maps and masks. + + Args: + anomaly_maps (Tensor): predictions of the model (ndim == 2, float) + masks (Tensor): ground truth masks (ndim == 2, binary) + """ + _validate_is_anomaly_maps(anomaly_maps) + _validate_is_masks(masks) + _validate.is_same_shape(anomaly_maps, masks) + self.anomaly_maps.append(anomaly_maps) + self.masks.append(masks) + + def compute(self) -> PIMOResult: + """Compute the PIMO curves. + + Call the functional interface `pimo_curves()`, which is a wrapper around the numpy code. + + Returns: + PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details. + """ + if self._is_empty: + msg = "No anomaly maps and masks have been added yet. Please call `update()` first." + raise RuntimeError(msg) + anomaly_maps = torch.concat(self.anomaly_maps, dim=0) + masks = torch.concat(self.masks, dim=0) + return pimo_curves( + anomaly_maps, + masks, + self.num_threshs, + binclf_algorithm=self.binclf_algorithm, + ) + + +class AUPIMO(PIMO): + """Area Under the Per-Image Overlap (PIMO) curve. + + This torchmetrics interface is a wrapper around the functional interface, which is a wrapper around the numpy code. + The tensors are converted to numpy arrays and then passed and validated in the numpy code. + The results are converted back to tensors and wrapped in an dataclass object. + + Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1]. + It can be thought of as the average TPR of the PIMO curves within the given FPR bounds. + + Details: `anomalib.metrics.per_image.pimo`. 
+ + Notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Attributes: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + + Args: + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + fpr_bounds: lower and upper bounds of the FPR integration range + force: whether to force the computation despite bad conditions + + Returns: + tuple[PIMOResult, AUPIMOResult]: PIMO and AUPIMO results dataclass objects. See `PIMOResult` and `AUPIMOResult`. + """ + + fpr_bounds: tuple[float, float] + return_average: bool + force: bool + + @staticmethod + def normalizing_factor(fpr_bounds: tuple[float, float]) -> float: + """Constant that normalizes the AUPIMO integral to 0-1 range. + + It is the maximum possible value from the integral in AUPIMO's definition. + It corresponds to assuming a constant function T_i: thresh --> 1. + + Args: + fpr_bounds: lower and upper bounds of the FPR integration range. + + Returns: + float: the normalization factor (>0). + """ + return pimo_numpy.aupimo_normalizing_factor(fpr_bounds) + + @staticmethod + def random_model_score(fpr_bounds: tuple[float, float]) -> float: + """AUPIMO of a theoretical random model. + + "Random model" means that there is no discrimination between normal and anomalous pixels/patches/images. + It corresponds to assuming the functions T = F. + + For the FPR bounds (1e-5, 1e-4), the random model AUPIMO is ~4e-5. + + Args: + fpr_bounds: lower and upper bounds of the FPR integration range. + + Returns: + float: the AUPIMO score. + """ + return pimo_numpy.aupimo_random_model_score(fpr_bounds) + + def __repr__(self) -> str: + """Show the metric name and its integration bounds.""" + lower, upper = self.fpr_bounds + return f"{self.__class__.__name__}([{lower:.2g}, {upper:.2g}])" + + def __init__( + self, + num_threshs: int = 300_000, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, + fpr_bounds: tuple[float, float] = (1e-5, 1e-4), + return_average: bool = True, + force: bool = False, + ) -> None: + """Area Under the Per-Image Overlap (PIMO) curve. + + Args: + num_threshs: [passed to parent `PIMO`] number of thresholds used to compute the PIMO curve + binclf_algorithm: [passed to parent `PIMO`] algorithm to compute the binary classification curve + fpr_bounds: lower and upper bounds of the FPR integration range + return_average: if True, return the average AUPIMO score; if False, return all the individual AUPIMO scores + force: if True, force the computation of the AUPIMO scores even in bad conditions (e.g. few points) + """ + super().__init__( + num_threshs=num_threshs, + binclf_algorithm=binclf_algorithm, + ) + + # other validations are done in PIMO.__init__() + + _validate.is_rate_range(fpr_bounds) + self.fpr_bounds = fpr_bounds + self.return_average = return_average + self.force = force + + def compute(self, force: bool | None = None) -> tuple[PIMOResult, AUPIMOResult]: # type: ignore[override] + """Compute the PIMO curves and their Area Under the curve (AUPIMO) scores. + + Call the functional interface `aupimo_scores()`, which is a wrapper around the numpy code. + + Args: + force: if given (not None), override the `force` attribute. + + Returns: + tuple[PIMOResult, AUPIMOResult]: PIMO curves and AUPIMO scores dataclass objects. + See `PIMOResult` and `AUPIMOResult` for details. 
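+
+        Example:
+            A minimal illustrative sketch with synthetic tensors (shapes, values and
+            the FPR bounds are invented for demonstration):
+
+            >>> import torch
+            >>> from anomalib.metrics.per_image import AUPIMO
+            >>> metric = AUPIMO(fpr_bounds=(1e-3, 1e-2), return_average=False)
+            >>> anomaly_maps = torch.rand(8, 128, 128)  # (N, H, W) anomaly scores
+            >>> masks = torch.zeros(8, 128, 128, dtype=torch.int32)
+            >>> masks[4:, 32:96, 32:96] = 1  # last 4 images are anomalous
+            >>> metric.update(anomaly_maps, masks)
+            >>> pimo_result, aupimo_result = metric.compute()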
+ """ + if self._is_empty: + msg = "No anomaly maps and masks have been added yet. Please call `update()` first." + raise RuntimeError(msg) + anomaly_maps = torch.concat(self.anomaly_maps, dim=0) + masks = torch.concat(self.masks, dim=0) + force = force if force is not None else self.force + pimoresult, aupimoresult = aupimo_scores( + anomaly_maps, + masks, + self.num_threshs, + binclf_algorithm=self.binclf_algorithm, + fpr_bounds=self.fpr_bounds, + force=force, + ) + if self.return_average: + # normal images have NaN AUPIMO scores + is_nan = torch.isnan(aupimoresult.aupimos) + return aupimoresult.aupimos[~is_nan].mean() + return pimoresult, aupimoresult diff --git a/src/anomalib/metrics/per_image/pimo_numpy.py b/src/anomalib/metrics/per_image/pimo_numpy.py new file mode 100644 index 0000000000..8b1f56f7ff --- /dev/null +++ b/src/anomalib/metrics/per_image/pimo_numpy.py @@ -0,0 +1,410 @@ +"""Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO). + +Details: `anomalib.metrics.per_image.pimo`. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from enum import Enum + +import numpy as np +from numpy import ndarray + +from . import _validate, binclf_curve_numpy +from .binclf_curve_numpy import BinclfAlgorithm, BinclfThreshsChoice + +logger = logging.getLogger(__name__) + +# =========================================== CONSTANTS =========================================== + + +class PIMOSharedFPRMetric(Enum): + """Shared FPR metric (x-axis of the PIMO curve).""" + + MEAN_PERIMAGE_FPR: str = "mean-per-image-fpr" + + +# =========================================== AUX =========================================== + + +def _images_classes_from_masks(masks: ndarray) -> ndarray: + """Deduce the image classes from the masks.""" + _validate.is_masks(masks) + return (masks == 1).any(axis=(1, 2)).astype(np.int32) + + +# =========================================== ARGS VALIDATION =========================================== + + +def _validate_has_at_least_one_anomalous_image(masks: ndarray) -> None: + image_classes = _images_classes_from_masks(masks) + if (image_classes == 1).sum() == 0: + msg = "Expected at least one ANOMALOUS image, but found none." + raise ValueError(msg) + + +def _validate_has_at_least_one_normal_image(masks: ndarray) -> None: + image_classes = _images_classes_from_masks(masks) + if (image_classes == 0).sum() == 0: + msg = "Expected at least one NORMAL image, but found none." + raise ValueError(msg) + + +def _joint_validate_threshs_shared_fpr(threshs: ndarray, shared_fpr: ndarray) -> None: + if threshs.shape[0] != shared_fpr.shape[0]: + msg = ( + "Expected `threshs` and `shared_fpr` to have the same number of elements, " + f"but got {threshs.shape[0]} != {shared_fpr.shape[0]}" + ) + raise ValueError(msg) + + +# =========================================== PIMO =========================================== + + +def pimo_curves( + anomaly_maps: ndarray, + masks: ndarray, + num_threshs: int, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA.value, +) -> tuple[ndarray, ndarray, ndarray, ndarray]: + """Compute the Per-IMage Overlap (PIMO, pronounced pee-mo) curves. + + PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds. 
+ The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on + the normal images. + + Details: `anomalib.metrics.per_image.pimo`. + + Args' notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Args: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + + Returns: + tuple[ndarray, ndarray, ndarray, ndarray]: + [0] thresholds of shape (K,) in ascending order + [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds) + [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds) + [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous) + """ + # validate the strings are valid + BinclfAlgorithm(binclf_algorithm) + _validate.is_num_threshs_gte2(num_threshs) + _validate.is_anomaly_maps(anomaly_maps) + _validate.is_masks(masks) + _validate.is_same_shape(anomaly_maps, masks) + _validate_has_at_least_one_anomalous_image(masks) + _validate_has_at_least_one_normal_image(masks) + + image_classes = _images_classes_from_masks(masks) + + # the thresholds are computed here so that they can be restrained to the normal images + # therefore getting a better resolution in terms of FPR quantization + # otherwise the function `binclf_curve_numpy.per_image_binclf_curve` would have the range of thresholds + # computed from all the images (normal + anomalous) + threshs = binclf_curve_numpy._get_threshs_minmax_linspace( # noqa: SLF001 + anomaly_maps[image_classes == 0], + num_threshs, + ) + + # N: number of images, K: number of thresholds + # shapes are (K,) and (N, K, 2, 2) + threshs, binclf_curves = binclf_curve_numpy.per_image_binclf_curve( + anomaly_maps=anomaly_maps, + masks=masks, + algorithm=binclf_algorithm, + threshs_choice=BinclfThreshsChoice.GIVEN.value, + threshs_given=threshs, + num_threshs=None, + ) + + shared_fpr: ndarray + # mean-per-image-fpr on normal images + # shape -> (N, K) + per_image_fprs_normals = binclf_curve_numpy.per_image_fpr(binclf_curves[image_classes == 0]) + try: + _validate.is_per_image_rate_curves(per_image_fprs_normals, nan_allowed=False, decreasing=True) + except ValueError as ex: + msg = f"Cannot compute PIMO because the per-image FPR curves from normal images are invalid. Cause: {ex}" + raise RuntimeError(msg) from ex + + # shape -> (K,) + # this is the only shared FPR metric implemented so far, + # see note about shared FPR in Details: `anomalib.metrics.per_image.pimo`. + shared_fpr = per_image_fprs_normals.mean(axis=0) + + # shape -> (N, K) + per_image_tprs = binclf_curve_numpy.per_image_tpr(binclf_curves) + + return threshs, shared_fpr, per_image_tprs, image_classes + + +# =========================================== AUPIMO =========================================== + + +def aupimo_scores( + anomaly_maps: ndarray, + masks: ndarray, + num_threshs: int = 300_000, + binclf_algorithm: BinclfAlgorithm | str = BinclfAlgorithm.NUMBA, + fpr_bounds: tuple[float, float] = (1e-5, 1e-4), + force: bool = False, +) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]: + """Compute the PIMO curves and their Area Under the Curve (i.e. AUPIMO) scores. 
+ + Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1]. + It can be thought of as the average TPR of the PIMO curves within the given FPR bounds. + + Details: `anomalib.metrics.per_image.pimo`. + + Args' notation: + N: number of images + H: image height + W: image width + K: number of thresholds + + Args: + anomaly_maps: floating point anomaly score maps of shape (N, H, W) + masks: binary (bool or int) ground truth masks of shape (N, H, W) + num_threshs: number of thresholds to compute (K) + binclf_algorithm: algorithm to compute the binary classifier curve (see `binclf_curve_numpy.Algorithm`) + fpr_bounds: lower and upper bounds of the FPR integration range + force: whether to force the computation despite bad conditions + + Returns: + tuple[ndarray, ndarray, ndarray, ndarray, ndarray]: + [0] thresholds of shape (K,) in ascending order + [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds) + [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds) + [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous) + [4] AUPIMO scores of shape (N,) in [0, 1] + [5] number of points used in the AUC integration + """ + _validate.is_rate_range(fpr_bounds) + + # other validations are done in the `pimo` function + threshs, shared_fpr, per_image_tprs, image_classes = pimo_curves( + anomaly_maps=anomaly_maps, + masks=masks, + num_threshs=num_threshs, + binclf_algorithm=binclf_algorithm, + ) + try: + _validate.is_threshs(threshs) + _validate.is_rate_curve(shared_fpr, nan_allowed=False, decreasing=True) + _validate.is_images_classes(image_classes) + _validate.is_per_image_rate_curves(per_image_tprs[image_classes == 1], nan_allowed=False, decreasing=True) + + except ValueError as ex: + msg = f"Cannot compute AUPIMO because the PIMO curves are invalid. Cause: {ex}" + raise RuntimeError(msg) from ex + + fpr_lower_bound, fpr_upper_bound = fpr_bounds + + # get the threshold indices where the fpr bounds are achieved + fpr_lower_bound_thresh_idx, _, fpr_lower_bound_defacto = thresh_at_shared_fpr_level( + threshs, + shared_fpr, + fpr_lower_bound, + ) + fpr_upper_bound_thresh_idx, _, fpr_upper_bound_defacto = thresh_at_shared_fpr_level( + threshs, + shared_fpr, + fpr_upper_bound, + ) + + if not np.isclose(fpr_lower_bound_defacto, fpr_lower_bound, rtol=(rtol := 1e-2)): + logger.warning( + "The lower bound of the shared FPR integration range is not exactly achieved. " + f"Expected {fpr_lower_bound} but got {fpr_lower_bound_defacto}, which is not within {rtol=}.", + ) + + if not np.isclose(fpr_upper_bound_defacto, fpr_upper_bound, rtol=rtol): + logger.warning( + "The upper bound of the shared FPR integration range is not exactly achieved. " + f"Expected {fpr_upper_bound} but got {fpr_upper_bound_defacto}, which is not within {rtol=}.", + ) + + # reminder: fpr lower/upper bound is threshold upper/lower bound (reversed) + thresh_lower_bound_idx = fpr_upper_bound_thresh_idx + thresh_upper_bound_idx = fpr_lower_bound_thresh_idx + + # deal with edge cases + if thresh_lower_bound_idx >= thresh_upper_bound_idx: + msg = ( + "The thresholds corresponding to the given `fpr_bounds` are not valid because " + "they matched the same threshold or the are in the wrong order. " + f"FPR upper/lower = threshold lower/upper = {thresh_lower_bound_idx} and {thresh_upper_bound_idx}." 
+        )
+        raise RuntimeError(msg)
+
+    # limit the curves to the integration range [lbound, ubound]
+    shared_fpr_bounded: ndarray = shared_fpr[thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]
+    per_image_tprs_bounded: ndarray = per_image_tprs[:, thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]
+
+    # `shared_fpr` and `tprs` are in descending order; `flip()` reverts to ascending order
+    shared_fpr_bounded = np.flip(shared_fpr_bounded)
+    per_image_tprs_bounded = np.flip(per_image_tprs_bounded, axis=1)
+
+    # the log's base does not matter because it's a constant factor canceled by normalization factor
+    shared_fpr_bounded_log = np.log(shared_fpr_bounded)
+
+    # deal with edge cases
+    invalid_shared_fpr = ~np.isfinite(shared_fpr_bounded_log)
+
+    if invalid_shared_fpr.all():
+        msg = (
+            "Cannot compute AUPIMO because the shared fpr integration range is invalid. "
+            "Try increasing the number of thresholds."
+        )
+        raise RuntimeError(msg)
+
+    if invalid_shared_fpr.any():
+        logger.warning(
+            "Some values in the shared fpr integration range are nan. "
+            "The AUPIMO will be computed without these values.",
+        )
+
+        # get rid of nan values by removing them from the integration range
+        shared_fpr_bounded_log = shared_fpr_bounded_log[~invalid_shared_fpr]
+        per_image_tprs_bounded = per_image_tprs_bounded[:, ~invalid_shared_fpr]
+
+    num_points_integral = int(shared_fpr_bounded_log.shape[0])
+
+    if num_points_integral <= 30:
+        msg = (
+            "Cannot compute AUPIMO because the shared fpr integration range doesn't have enough points. "
+            f"Found {num_points_integral} points in the integration range. "
+            "Try increasing `num_threshs`."
+        )
+        if not force:
+            raise RuntimeError(msg)
+        msg += " Computation was forced!"
+        logger.warning(msg)
+
+    if num_points_integral < 300:
+        logger.warning(
+            "The AUPIMO may be inaccurate because the shared fpr integration range doesn't have enough points. "
+            f"Found {num_points_integral} points in the integration range. "
+            "Try increasing `num_threshs`.",
+        )
+
+    aucs: ndarray = np.trapezoid(per_image_tprs_bounded, x=shared_fpr_bounded_log, axis=1)
+
+    # normalize, then clip(0, 1) makes sure that the values are in [0, 1] in case of numerical errors
+    normalization_factor = aupimo_normalizing_factor(fpr_bounds)
+    aucs = (aucs / normalization_factor).clip(0, 1)
+
+    return threshs, shared_fpr, per_image_tprs, image_classes, aucs, num_points_integral
+
+
+# =========================================== AUX ===========================================
+
+
+def thresh_at_shared_fpr_level(threshs: ndarray, shared_fpr: ndarray, fpr_level: float) -> tuple[int, float, float]:
+    """Return the threshold and its index at the given shared FPR level.
+
+    Three cases are possible:
+    - fpr_level == 0: the lowest threshold that achieves 0 FPR is returned
+    - fpr_level == 1: the highest threshold that achieves 1 FPR is returned
+    - 0 < fpr_level < 1: the threshold that achieves the closest (higher or lower) FPR to `fpr_level` is returned
+
+    Args:
+        threshs: thresholds at which the shared FPR was computed.
+        shared_fpr: shared FPR values.
+        fpr_level: shared FPR value at which to get the threshold.
+ + Returns: + tuple[int, float, float]: + [0] index of the threshold + [1] threshold + [2] the actual shared FPR value at the returned threshold + """ + _validate.is_threshs(threshs) + _validate.is_rate_curve(shared_fpr, nan_allowed=False, decreasing=True) + _joint_validate_threshs_shared_fpr(threshs, shared_fpr) + _validate.is_rate(fpr_level, zero_ok=True, one_ok=True) + + shared_fpr_min, shared_fpr_max = shared_fpr.min(), shared_fpr.max() + + if fpr_level < shared_fpr_min: + msg = ( + "Invalid `fpr_level` because it's out of the range of `shared_fpr` = " + f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}." + ) + raise ValueError(msg) + + if fpr_level > shared_fpr_max: + msg = ( + "Invalid `fpr_level` because it's out of the range of `shared_fpr` = " + f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}." + ) + raise ValueError(msg) + + # fpr_level == 0 or 1 are special case + # because there may be multiple solutions, and the chosen should their MINIMUM/MAXIMUM respectively + if fpr_level == 0.0: + index = np.min(np.where(shared_fpr == fpr_level)) + + elif fpr_level == 1.0: + index = np.max(np.where(shared_fpr == fpr_level)) + + else: + index = np.argmin(np.abs(shared_fpr - fpr_level)) + + index = int(index) + fpr_level_defacto = shared_fpr[index] + thresh = threshs[index] + return index, thresh, fpr_level_defacto + + +def aupimo_normalizing_factor(fpr_bounds: tuple[float, float]) -> float: + """Constant that normalizes the AUPIMO integral to 0-1 range. + + It is the maximum possible value from the integral in AUPIMO's definition. + It corresponds to assuming a constant function T_i: thresh --> 1. + + Args: + fpr_bounds: lower and upper bounds of the FPR integration range. + + Returns: + float: the normalization factor (>0). + """ + _validate.is_rate_range(fpr_bounds) + fpr_lower_bound, fpr_upper_bound = fpr_bounds + # the log's base must be the same as the one used in the integration! + return float(np.log(fpr_upper_bound / fpr_lower_bound)) + + +def aupimo_random_model_score(fpr_bounds: tuple[float, float]) -> float: + """AUPIMO of a theoretical random model. + + "Random model" means that there is no discrimination between normal and anomalous pixels/patches/images. + It corresponds to assuming the functions T = F. + + For the FPR bounds (1e-5, 1e-4), the random model AUPIMO is ~4e-5. + + Args: + fpr_bounds: lower and upper bounds of the FPR integration range. + + Returns: + float: the AUPIMO score. + """ + _validate.is_rate_range(fpr_bounds) + fpr_lower_bound, fpr_upper_bound = fpr_bounds + integral_value = fpr_upper_bound - fpr_lower_bound + return float(integral_value / aupimo_normalizing_factor(fpr_bounds)) diff --git a/src/anomalib/metrics/per_image/utils.py b/src/anomalib/metrics/per_image/utils.py new file mode 100644 index 0000000000..1d47674e2c --- /dev/null +++ b/src/anomalib/metrics/per_image/utils.py @@ -0,0 +1,521 @@ +"""Torch-oriented interfaces for `utils.py`.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import OrderedDict +from copy import deepcopy +from typing import TYPE_CHECKING + +import pandas as pd +import torch +from pandas import DataFrame +from torch import Tensor + +from . 
import _validate, utils_numpy +from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy + +if TYPE_CHECKING: + from .pimo import AUPIMOResult + + +logger = logging.getLogger(__name__) + +# =========================================== ARGS VALIDATION =========================================== + + +def _validate_is_models_ordered(models_ordered: tuple[str, ...]) -> None: + if not isinstance(models_ordered, tuple): + msg = f"Expected models ordered to be a tuple, but got {type(models_ordered)}." + raise TypeError(msg) + + if len(models_ordered) < 2: + msg = f"Expected models ordered to have at least 2 models, but got {len(models_ordered)}." + raise ValueError(msg) + + for model_name in models_ordered: + if not isinstance(model_name, str): + msg = f"Expected model name to be a string, but got {type(model_name)} for model {model_name}." + raise TypeError(msg) + + if model_name == "": + msg = "Expected model name to be non-empty, but got empty string." + raise ValueError(msg) + + num_redundant_models = len(models_ordered) - len(set(models_ordered)) + if num_redundant_models > 0: + msg = f"Expected models ordered to have unique models, but got {num_redundant_models} redundant models." + raise ValueError(msg) + + +def _validate_is_confidences(confidences: dict[tuple[str, str], float]) -> None: + if not isinstance(confidences, dict): + msg = f"Expected confidences to be a dict, but got {type(confidences)}." + raise TypeError(msg) + + for (model1, model2), confidence in confidences.items(): + if not isinstance(model1, str): + msg = f"Expected model name to be a string, but got {type(model1)} for model {model1}." + raise TypeError(msg) + + if not isinstance(model2, str): + msg = f"Expected model name to be a string, but got {type(model2)} for model {model2}." + raise TypeError(msg) + + if not isinstance(confidence, float): + msg = f"Expected confidence to be a float, but got {type(confidence)} for models {model1} and {model2}." + raise TypeError(msg) + + if not (0 <= confidence <= 1): + msg = f"Expected confidence to be between 0 and 1, but got {confidence} for models {model1} and {model2}." + raise ValueError(msg) + + +def _joint_validate_models_ordered_and_confidences( + models_ordered: tuple[str, ...], + confidences: dict[tuple[str, str], float], +) -> None: + num_models = len(models_ordered) + expected_num_pairs = num_models * (num_models - 1) + + if len(confidences) != expected_num_pairs: + msg = f"Expected {expected_num_pairs} pairs of models, but got {len(confidences)} pairs of models." + raise ValueError(msg) + + models_in_confidences = {model for pair_models in confidences for model in pair_models} + + diff = set(models_ordered).symmetric_difference(models_in_confidences) + if len(diff) > 0: + msg = ( + "Expected models in confidences to be the same as models ordered, but got models missing in one" + f"of them: {diff}." + ) + raise ValueError(msg) + + +def _validate_is_scores_per_model_tensor(scores_per_model: dict[str, Tensor] | OrderedDict[str, Tensor]) -> None: + first_key_value = None + + for model_name, scores in scores_per_model.items(): + if scores.ndim != 1: + msg = f"Expected scores to be 1D, but got {scores.ndim}D for model {model_name}." + raise ValueError(msg) + + num_valid_scores = scores[~torch.isnan(scores)].numel() + + if num_valid_scores < 1: + msg = f"Expected at least 1 non-nan score, but got {num_valid_scores} for model {model_name}." 
+ raise ValueError(msg) + + if first_key_value is None: + first_key_value = (model_name, scores) + continue + + first_model_name, first_scores = first_key_value + + # same shape + if scores.shape[0] != first_scores.shape[0]: + msg = ( + "Expected scores to have the same number of scores, " + f"but got ({model_name}) {scores.shape[0]} != {first_scores.shape[0]} ({first_model_name})." + ) + raise ValueError(msg) + + # `nan` at the same indices + if (torch.isnan(scores) != torch.isnan(first_scores)).any(): + msg = ( + "Expected `nan` values, if any, to be at the same indices, " + f"but there are differences between models {model_name} and {first_model_name}." + ) + raise ValueError(msg) + + +def _validate_is_scores_per_model_aupimoresult( + scores_per_model: dict[str, "AUPIMOResult"] | OrderedDict[str, "AUPIMOResult"], + missing_paths_ok: bool, +) -> None: + first_key_value = None + + for model_name, aupimoresult in scores_per_model.items(): + if first_key_value is None: + first_key_value = (model_name, aupimoresult) + continue + + first_model_name, first_aupimoresult = first_key_value + + if aupimoresult.fpr_bounds != first_aupimoresult.fpr_bounds: + msg = ( + "Expected AUPIMOResult objects in scores per model to have the same FPR bounds, " + f"but got ({model_name}) {aupimoresult.fpr_bounds} != " + f"{first_aupimoresult.fpr_bounds} ({first_model_name})." + ) + raise ValueError(msg) + + available_paths = [tuple(scores.paths) for scores in scores_per_model.values() if scores.paths is not None] + + if len(set(available_paths)) > 1: + msg = ( + "Expected AUPIMOResult objects in scores per model to have the same paths, " + "but got different paths for different models." + ) + raise ValueError(msg) + + if len(available_paths) != len(scores_per_model): + msg = "Some models have paths, while others are missing them." + if not missing_paths_ok: + raise ValueError(msg) + logger.warning(msg) + + +def _validate_is_scores_per_model( + scores_per_model: dict[str, Tensor] + | OrderedDict[str, Tensor] + | dict[str, "AUPIMOResult"] + | OrderedDict[str, "AUPIMOResult"], +) -> None: + # it has to be imported here to avoid circular imports + from .pimo import AUPIMOResult + + if not isinstance(scores_per_model, dict | OrderedDict): + msg = f"Expected scores per model to be a dictionary or ordered dictionary, but got {type(scores_per_model)}." + raise TypeError(msg) + + if len(scores_per_model) < 2: + msg = f"Expected scores per model to have at least 2 models, but got {len(scores_per_model)}." + raise ValueError(msg) + + if not all(isinstance(model_name, str) for model_name in scores_per_model): + msg = "Expected scores per model to have model names (strings) as keys." + raise TypeError(msg) + + first_instance = next(iter(scores_per_model.values())) + + if ( + isinstance(first_instance, Tensor) + and any(not isinstance(scores, Tensor) for scores in scores_per_model.values()) + ) or ( + isinstance(first_instance, AUPIMOResult) + and any(not isinstance(scores, AUPIMOResult) for scores in scores_per_model.values()) + ): + msg = ( + "Values in the scores per model dict must have the same type for values (Tensor or AUPIMOResult), " + "but more than one type was found." 
+ ) + raise TypeError(msg) + + if isinstance(first_instance, Tensor): + _validate_is_scores_per_model_tensor(scores_per_model) + return + + _validate_is_scores_per_model_tensor( + {model_name: scores.aupimos for model_name, scores in scores_per_model.items()}, + ) + + _validate_is_scores_per_model_aupimoresult(scores_per_model, missing_paths_ok=True) + + +# =========================================== FUNCTIONS =========================================== + + +def per_image_scores_stats( + per_image_scores: Tensor, + images_classes: Tensor | None = None, + only_class: int | None = None, + outliers_policy: str | StatsOutliersPolicy = StatsOutliersPolicy.NONE.value, + repeated_policy: str | StatsRepeatedPolicy = StatsRepeatedPolicy.AVOID.value, + repeated_replacement_atol: float = 1e-2, +) -> list[dict[str, str | int | float]]: + """Compute statistics of per-image scores (based on a boxplot's statistics). + + ***Torch-oriented interface for `.utils_numpy.per_image_scores_stats`*** + + For a single per-image metric collection (1 model, 1 dataset), compute statistics (based on a boxplot) + and find the closest image to each statistic. + + This function uses `matplotlib.cbook.boxplot_stats`, which is the same function used by `matplotlib.pyplot.boxplot`. + + ** OUTLIERS ** + Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away + from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. + + Outliers are handled according to `outliers_policy`: + - None | "none": do not include outliers. + - "hi": only include high outliers. + - "lo": only include low outliers. + - "both": include both high and low outliers. + + ** IMAGE INDEX ** + Each statistic is associated with the image whose score is the closest to the statistic's value. + + ** REPEATED VALUES ** + It is possible that two stats have the same value (e.g. the median and the 25th percentile can be the same). + Such cases are handled according to `repeated_policy`: + - None | "none": do not address the issue, so several stats can have the same value and image index. + - "avoid": avoid repeated values by iterativealy looking for other images with similar score, whose score + must be within `repeated_replacement_atol` (absolute tolerance) of the repeated value. + + Args: + per_image_scores (Tensor): 1D Tensor of per-image scores. + images_classes (Tensor | None): + Used to filter statistics to only one class. If None, all images are considered. + If given, 1D Tensor of binary image classes (0 for 'normal', 1 for 'anomalous'). Defaults to None. + only_class (int | None): + Only used if `images_classes` is not None. + If not None, only compute statistics for images of the given class. + `None` means both image classes are used. + Defaults to None. + outliers_policy (str | None): How to handle outliers stats (use them?). See `OutliersPolicy`. Defaults to None. + repeated_policy (str | None): How to handle repeated values in boxplot stats (two stats with same value). + See `RepeatedPolicy`. Defaults to None. + repeated_replacement_atol (float): Absolute tolerance used to replace repeated values. Only used if + `repeated_policy` is not None (or 'none'). Defaults to 1e-2 (1%). + + Returns: + list[dict[str, str | int | float]]: List of boxplot statistics. + + Each dictionary has the following keys: + - 'stat_name': Name of the statistic. Possible values: + - 'mean': Mean of the scores. + - 'med': Median of the scores. 
+ - 'q1': 25th percentile of the scores. + - 'q3': 75th percentile of the scores. + - 'whishi': Upper whisker value. + - 'whislo': Lower whisker value. + - 'outlo_i': low outlier value; `i` is a unique index for each low outlier. + - 'outhi_j': high outlier value; `j` is a unique index for each high outlier. + - 'stat_value': Value of the statistic (same units as `values`). + - 'image_idx': Index of the image in `per_image_scores` whose score is the closest to the statistic's value. + - 'score': The score of the image at index `image_idx` (not necessarily the same as `stat_value`). + + The list is sorted by increasing `stat_value`. + """ + _validate.is_tensor(per_image_scores, "per_image_scores") + per_image_scores_array = per_image_scores.detach().cpu().numpy() + + if images_classes is not None: + _validate.is_tensor(images_classes, "images_classes") + images_classes_array = images_classes.detach().cpu().numpy() + + else: + images_classes_array = None + + # other validations happen inside `utils_numpy.per_image_scores_stats` + + return utils_numpy.per_image_scores_stats( + per_image_scores_array, + images_classes_array, + only_class=only_class, + outliers_policy=outliers_policy, + repeated_policy=repeated_policy, + repeated_replacement_atol=repeated_replacement_atol, + ) + + +def compare_models_pairwise_ttest_rel( + scores_per_model: dict[str, Tensor] + | OrderedDict[str, Tensor] + | dict[str, "AUPIMOResult"] + | OrderedDict[str, "AUPIMOResult"], + alternative: str, + higher_is_better: bool, +) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]: + """Compare all pairs of models using the paired t-test on two related samples (parametric). + + ***Torch-oriented interface for `.numpy_utils.compare_models_pairwise_ttest_rel`*** + + This is a test for the null hypothesis that two repeated samples have identical average (expected) values. + In fact, it tests whether the average of the differences between the two samples is significantly different from 0. + + Refs: + - `scipy.stats.ttest_rel`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_rel.html + - Wikipedia page: https://en.wikipedia.org/wiki/Student's_t-test#Dependent_t-test_for_paired_samples + + === + + If an ordered dictionary is given, the models are sorted by the order of the dictionary. + Otherwise, the models are sorted by average SCORE. + + Args: + scores_per_model: Dictionary of `n` models and their per-image scores. + key: model name + value: tensor of shape (num_images,). All `nan` values must be at the same positions. + higher_is_better: Whether higher values of score are better or worse. Defaults to True. + alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section. + Valid values are `StatsAlternativeHypothesis.ALTERNATIVES`. + + Returns: + (models_ordered, test_results): + - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input). + + Automatic sorting is by average score from best to worst model. + Depending on `higher_is_better`, this corresponds to: + - `higher_is_better=True` ==> descending score order + - `higher_is_better=False` ==> ascending score order + along the indices from 0 to `n-1`. + + - confidences: Dictionary of confidence values for each pair of models. + + For all pairs of indices i and j from 0 to `n-1` such that i != j: + - key: (models_ordered[i], models_ordered[j]) + - value: confidence on the alternative hypothesis. 
+
+                For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is:
+                    - if `less`: model[i] < model[j]
+                    - if `greater`: model[i] > model[j]
+                    - if `two-sided`: model[i] != model[j]
+                in terms of average score.
+    """
+    _validate_is_scores_per_model(scores_per_model)
+    scores_per_model_items = [
+        (
+            model_name,
+            (scores if isinstance(scores, Tensor) else scores.aupimos).detach().cpu().numpy(),
+        )
+        for model_name, scores in scores_per_model.items()
+    ]
+    cls = OrderedDict if isinstance(scores_per_model, OrderedDict) else dict
+    scores_per_model_with_arrays = cls(scores_per_model_items)
+
+    return utils_numpy.compare_models_pairwise_ttest_rel(scores_per_model_with_arrays, alternative, higher_is_better)
+
+
+def compare_models_pairwise_wilcoxon(
+    scores_per_model: dict[str, Tensor]
+    | OrderedDict[str, Tensor]
+    | dict[str, "AUPIMOResult"]
+    | OrderedDict[str, "AUPIMOResult"],
+    alternative: str,
+    higher_is_better: bool,
+) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]:
+    """Compare all pairs of models using the Wilcoxon signed-rank test (non-parametric).
+
+    ***Torch-oriented interface for `.utils_numpy.compare_models_pairwise_wilcoxon`***
+
+    Each comparison of two models is a Wilcoxon signed-rank test (null hypothesis is that they are equal).
+
+    It tests whether the distribution of the differences of scores is symmetric about zero in a non-parametric way.
+    This is like the non-parametric version of the paired t-test.
+
+    Refs:
+        - `scipy.stats.wilcoxon`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html#scipy.stats.wilcoxon
+        - Wikipedia page: https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test
+
+    ===
+
+    If an ordered dictionary is given, the models are sorted by the order of the dictionary.
+    Otherwise, the models are sorted by average RANK.
+
+    Args:
+        scores_per_model: Dictionary of `n` models and their per-image scores.
+            key: model name
+            value: tensor of shape (num_images,). All `nan` values must be at the same positions.
+        higher_is_better: Whether higher values of score are better or worse. Defaults to True.
+        alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section.
+            Valid values are `StatsAlternativeHypothesis.ALTERNATIVES`.
+        atol: Absolute tolerance used to consider two scores as equal. Defaults to 1e-3 (0.1%).
+            When doing a paired test, if the difference between two scores is below `atol`, the difference is
+            truncated to 0. If `atol` is None, no truncation is done.
+
+    Returns:
+        (models_ordered, test_results):
+            - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input).
+
+                Automatic sorting is from "best to worst" model, which corresponds to ascending average rank
+                along the indices from 0 to `n-1`.
+
+            - confidences: Dictionary of confidence values for each pair of models.
+
+                For all pairs of indices i and j from 0 to `n-1` such that i != j:
+                    - key: (models_ordered[i], models_ordered[j])
+                    - value: confidence on the alternative hypothesis.
+
+                For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is:
+                    - if `less`: model[i] < model[j]
+                    - if `greater`: model[i] > model[j]
+                    - if `two-sided`: model[i] != model[j]
+                in terms of average ranks (not scores!).
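+
+    Example:
+        A minimal illustrative sketch with made-up per-image scores (model names and
+        values are invented for demonstration):
+
+        >>> import torch
+        >>> from anomalib.metrics.per_image import compare_models_pairwise_wilcoxon
+        >>> scores_per_model = {
+        ...     "model_a": torch.tensor([0.37, 0.60, 0.71, 0.85, 0.43]),
+        ...     "model_b": torch.tensor([0.30, 0.55, 0.68, 0.84, 0.45]),
+        ... }
+        >>> models_ordered, confidences = compare_models_pairwise_wilcoxon(
+        ...     scores_per_model,
+        ...     alternative="greater",
+        ...     higher_is_better=True,
+        ... )
+        >>> best_vs_second = confidences[(models_ordered[0], models_ordered[1])]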
+    """
+    _validate_is_scores_per_model(scores_per_model)
+    scores_per_model_items = [
+        (
+            model_name,
+            (scores if isinstance(scores, Tensor) else scores.aupimos).detach().cpu().numpy(),
+        )
+        for model_name, scores in scores_per_model.items()
+    ]
+    cls = OrderedDict if isinstance(scores_per_model, OrderedDict) else dict
+    scores_per_model_with_arrays = cls(scores_per_model_items)
+
+    return utils_numpy.compare_models_pairwise_wilcoxon(scores_per_model_with_arrays, alternative, higher_is_better)
+
+
+def format_pairwise_tests_results(
+    models_ordered: tuple[str, ...],
+    confidences: dict[tuple[str, str], float],
+    model1_as_column: bool = True,
+    left_to_right: bool = False,
+    top_to_bottom: bool = False,
+) -> DataFrame:
+    """Format the results of pairwise tests into a square dataframe.
+
+    The confidence values refer to the confidence level (in [0, 1]) on the alternative hypothesis,
+    which is formulated as "`model1` <operator> `model2`", where `<operator>` can be '<', '>', or '!='.
+
+    HOW TO READ THE DATAFRAME
+    =========================
+    There are 6 possible ways to read the dataframe, depending on the values of `model1_as_column` and `alternative`
+    (from the pairwise test function that generated `confidences`).
+
+    *column* and *row* below refer to a generic column and row value (model names) in the dataframe.
+
+    if (
+        model1_as_column == True and alternative == 'less'
+        or model1_as_column == False and alternative == 'greater'
+    )
+        read: "column < row"
+        equivalently: "row > column"
+
+    elif (
+        model1_as_column == True and alternative == 'greater'
+        or model1_as_column == False and alternative == 'less'
+    )
+        read: "column > row"
+        equivalently: "row < column"
+
+    else: # alternative == 'two-sided'
+        read: "column != row"
+        equivalently: "row != column"
+
+    Args:
+        models_ordered: The models ordered in a meaningful way, generally from best to worst when automatically ordered.
+        confidences: The confidence on the alternative hypothesis, as returned by the pairwise test function.
+        model1_as_column: Whether to put `model1` as column or row in the dataframe.
+        left_to_right: Whether to order the columns from best to worst model as left to right.
+        top_to_bottom: Whether to order the rows from best to worst model as top to bottom.
+            Default column/row ordering is from worst to best model (left to right, top to bottom),
+            so the upper left corner is the worst model compared to itself, and the bottom right corner is the best
+            model compared to itself.
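+
+    Example:
+        A minimal illustrative sketch with made-up confidences (model names and
+        values are invented for demonstration):
+
+        >>> from anomalib.metrics.per_image import format_pairwise_tests_results
+        >>> models_ordered = ("model_a", "model_b")
+        >>> confidences = {
+        ...     ("model_a", "model_b"): 0.85,
+        ...     ("model_b", "model_a"): 0.15,
+        ... }
+        >>> table = format_pairwise_tests_results(models_ordered, confidences)
+        >>> # `table` is a 2x2 DataFrame indexed by model names whose values are the
+        >>> # pairwise confidences (NaN on the diagonal).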
+ + """ + _validate_is_models_ordered(models_ordered) + _validate_is_confidences(confidences) + _joint_validate_models_ordered_and_confidences(models_ordered, confidences) + confidences = deepcopy(confidences) + confidences.update({(model, model): torch.nan for model in models_ordered}) + # `df` stands for `dataframe` + confdf = pd.DataFrame(confidences, index=["confidence"]).T + confdf.index.names = ["model1", "model2"] + confdf = confdf.reset_index() + confdf["model1"] = pd.Categorical(confdf["model1"], categories=models_ordered, ordered=True) + confdf["model2"] = pd.Categorical(confdf["model2"], categories=models_ordered, ordered=True) + # df at this point: 3 columns: model1, model2, confidence + index_model, column_model = ("model2", "model1") if model1_as_column else ("model1", "model2") + confdf = confdf.pivot_table(index=index_model, columns=column_model, values="confidence", dropna=False, sort=False) + # now it is a square dataframe with models as index and columns, and confidence as values + confdf = confdf.sort_index(axis=0, ascending=top_to_bottom) + return confdf.sort_index(axis=1, ascending=left_to_right) diff --git a/src/anomalib/metrics/per_image/utils_numpy.py b/src/anomalib/metrics/per_image/utils_numpy.py new file mode 100644 index 0000000000..736780831c --- /dev/null +++ b/src/anomalib/metrics/per_image/utils_numpy.py @@ -0,0 +1,481 @@ +"""Utility functions for per-image metrics.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from collections import OrderedDict +from enum import Enum + +import matplotlib as mpl +import numpy as np +import scipy +import scipy.stats +from numpy import ndarray + +from . import _validate + +# =========================================== CONSTANTS =========================================== + + +class StatsOutliersPolicy(Enum): + """How to handle outliers in per-image metrics boxplots. Use them? Only high? Only low? Both? + + Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away + from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. + + None | "none": do not include outliers. + "hi": only include high outliers. + "lo": only include low outliers. + "both": include both high and low outliers. + """ + + NONE: str = "none" + HI: str = "hi" + LO: str = "lo" + BOTH: str = "both" + + +class StatsRepeatedPolicy(Enum): + """How to handle repeated values in per-image metrics boxplots (two stats with same value). Avoid them? + + None | "none": do not avoid repeated values, so several stats can have the same value and image index. + "avoid": if a stat has the same value as another stat, the one with the closest then another image, + with the nearest score, is selected. + """ + + NONE: str = "none" + AVOID: str = "avoid" + + +class StatsAlternativeHypothesis(Enum): + """Alternative hypothesis for the statistical tests used to compare per-image metrics.""" + + TWO_SIDED: str = "two-sided" + LESS: str = "less" + GREATER: str = "greater" + + +# =========================================== ARGS VALIDATION =========================================== +def _validate_is_image_class(image_class: int) -> None: + if not isinstance(image_class, int): + msg = f"Expected image class to be an int (0 for 'normal', 1 for 'anomalous'), but got {type(image_class)}." 
+ raise TypeError(msg) + + if image_class not in (0, 1): + msg = f"Expected image class to be either 0 for 'normal' or 1 for 'anomalous', but got {image_class}." + raise ValueError(msg) + + +def _validate_is_per_image_scores(per_image_scores: ndarray) -> None: + if not isinstance(per_image_scores, ndarray): + msg = f"Expected per-image scores to be a numpy array, but got {type(per_image_scores)}." + raise TypeError(msg) + + if per_image_scores.ndim != 1: + msg = f"Expected per-image scores to be 1D, but got {per_image_scores.ndim}D." + raise ValueError(msg) + + +def _validate_is_scores_per_model(scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray]) -> None: + if not isinstance(scores_per_model, dict | OrderedDict): + msg = f"Expected scores per model to be a dictionary or ordered dictionary, but got {type(scores_per_model)}." + raise TypeError(msg) + + if len(scores_per_model) < 2: + msg = f"Expected scores per model to have at least 2 models, but got {len(scores_per_model)}." + raise ValueError(msg) + + first_key_value = None + + for model_name, scores in scores_per_model.items(): + if not isinstance(model_name, str): + msg = f"Expected model name to be a string, but got {type(model_name)} for model {model_name}." + raise TypeError(msg) + + if not isinstance(scores, ndarray): + msg = f"Expected scores to be a numpy array, but got {type(scores)} for model {model_name}." + raise TypeError(msg) + + if scores.ndim != 1: + msg = f"Expected scores to be 1D, but got {scores.ndim}D for model {model_name}." + raise ValueError(msg) + + num_valid_scores = scores[~np.isnan(scores)].shape[0] + + if num_valid_scores < 2: + msg = f"Expected at least 2 scores, but got {num_valid_scores} for model {model_name}." + raise ValueError(msg) + + if first_key_value is None: + first_key_value = (model_name, scores) + continue + + first_model_name, first_scores = first_key_value + + # same shape + if scores.shape != first_scores.shape: + msg = ( + "Expected scores to have the same shape, " + f"but got ({model_name}) {scores.shape} != {first_scores.shape} ({first_model_name})." + ) + raise ValueError(msg) + + # `nan` at the same indices + if (np.isnan(scores) != np.isnan(first_scores)).any(): + msg = ( + "Expected `nan` values, if any, to be at the same indices, " + f"but there are differences between models {model_name} and {first_model_name}." + ) + raise ValueError(msg) + + +# =========================================== FUNCTIONS =========================================== + + +def per_image_scores_stats( + per_image_scores: ndarray, + images_classes: ndarray | None = None, + only_class: int | None = None, + outliers_policy: StatsOutliersPolicy | str = StatsOutliersPolicy.NONE.value, + repeated_policy: StatsRepeatedPolicy | str = StatsRepeatedPolicy.AVOID.value, + repeated_replacement_atol: float = 1e-2, +) -> list[dict[str, str | int | float]]: + """Compute statistics of per-image scores (based on a boxplot's statistics). + + For a single per-image metric collection (1 model, 1 dataset), compute statistics (based on a boxplot) + and find the closest image to each statistic. + + This function uses `matplotlib.cbook.boxplot_stats`, which is the same function used by `matplotlib.pyplot.boxplot`. + + ** OUTLIERS ** + Outliers are defined as in a boxplot, i.e. values that are more than 1.5 times the interquartile range (IQR) away + from the Q1 and Q3 quartiles (respectively low and high outliers). The IQR is the difference between Q3 and Q1. 
+
+    Outliers are handled according to `outliers_policy`:
+        - None | "none": do not include outliers.
+        - "hi": only include high outliers.
+        - "lo": only include low outliers.
+        - "both": include both high and low outliers.
+
+    ** IMAGE INDEX **
+    Each statistic is associated with the image whose score is the closest to the statistic's value.
+
+    ** REPEATED VALUES **
+    It is possible that two stats have the same value (e.g. the median and the 25th percentile can be the same).
+    Such cases are handled according to `repeated_policy`:
+        - None | "none": do not address the issue, so several stats can have the same value and image index.
+        - "avoid": avoid repeated values by iteratively looking for other images with similar score, whose score
+            must be within `repeated_replacement_atol` (absolute tolerance) of the repeated value.
+
+    Args:
+        per_image_scores (ndarray): 1D ndarray of per-image scores.
+        images_classes (ndarray | None):
+            Used to filter statistics to only one class. If None, all images are considered.
+            If given, 1D ndarray of binary image classes (0 for 'normal', 1 for 'anomalous'). Defaults to None.
+        only_class (int | None):
+            Only used if `images_classes` is not None.
+            If not None, only compute statistics for images of the given class.
+            `None` means both image classes are used.
+            Defaults to None.
+        outliers_policy (StatsOutliersPolicy | str): How to handle outlier stats (use them?).
+            See `StatsOutliersPolicy`. Defaults to "none".
+        repeated_policy (StatsRepeatedPolicy | str): How to handle repeated values in boxplot stats
+            (two stats with same value). See `StatsRepeatedPolicy`. Defaults to "avoid".
+        repeated_replacement_atol (float): Absolute tolerance used to replace repeated values. Only used if
+            `repeated_policy` is not None (or 'none'). Defaults to 1e-2 (1%).
+
+    Returns:
+        list[dict[str, str | int | float]]: List of boxplot statistics.
+
+        Each dictionary has the following keys:
+            - 'stat_name': Name of the statistic. Possible values:
+                - 'mean': Mean of the scores.
+                - 'med': Median of the scores.
+                - 'q1': 25th percentile of the scores.
+                - 'q3': 75th percentile of the scores.
+                - 'whishi': Upper whisker value.
+                - 'whislo': Lower whisker value.
+                - 'outlo_i': low outlier value; `i` is a unique index for each low outlier.
+                - 'outhi_j': high outlier value; `j` is a unique index for each high outlier.
+            - 'stat_value': Value of the statistic (same units as `per_image_scores`).
+            - 'image_idx': Index of the image in `per_image_scores` whose score is the closest to the statistic's value.
+            - 'score': The score of the image at index `image_idx` (not necessarily the same as `stat_value`).
+
+        The list is sorted by increasing `stat_value`.
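For instance, a minimal sketch of a call (made-up scores; the import path only assumes this module's location shown in the diff header above):

    import numpy as np

    from anomalib.metrics.per_image.utils_numpy import per_image_scores_stats

    rng = np.random.default_rng(42)
    per_image_scores = rng.uniform(0.0, 1.0, size=50)  # e.g. made-up per-image AUPIMO values

    stats = per_image_scores_stats(
        per_image_scores,
        outliers_policy="both",   # also report low and high outliers, if any
        repeated_policy="avoid",  # try not to assign the same image to two statistics
    )
    # each entry looks like {"stat_name": "med", "stat_value": ..., "image_idx": ..., "score": ...},
    # which makes it easy to pick representative images (e.g. the image closest to the median score).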
+ """ + outliers_policy = StatsOutliersPolicy(outliers_policy) + repeated_policy = StatsRepeatedPolicy(repeated_policy) + _validate_is_per_image_scores(per_image_scores) + + # restrain the images to the class `only_class` if given, else use all images + if images_classes is None: + images_selection_mask = np.ones_like(per_image_scores, dtype=bool) + + elif only_class is not None: + _validate.is_images_classes(images_classes) + _validate.is_same_shape(per_image_scores, images_classes) + _validate_is_image_class(only_class) + images_selection_mask = images_classes == only_class + + else: + images_selection_mask = np.ones_like(per_image_scores, dtype=bool) + + # indexes in `per_image_scores` are referred to as `candidate_idx` + # while the indexes in the original array are referred to as `image_idx` + # - `candidate_idx` works for `per_image_scores` and `candidate2image_idx` (see below) + # - `image_idx` works for `images_classes` and `images_idxs_selected` + per_image_scores = per_image_scores[images_selection_mask] + # converts `candidate_idx` to `image_idx` + candidate2image_idx = np.nonzero(images_selection_mask)[0] + + # function used in `matplotlib.boxplot` + boxplot_stats = mpl.cbook.boxplot_stats(per_image_scores)[0] # [0] is for the only boxplot + + # remove unnecessary keys + boxplot_stats = {name: value for name, value in boxplot_stats.items() if name not in ("iqr", "cilo", "cihi")} + + # unroll `fliers` (outliers), remove unnecessary ones according to `outliers_policy`, + # then add them to `boxplot_stats` with unique keys + outliers = boxplot_stats.pop("fliers") + outliers_lo = outliers[outliers < boxplot_stats["med"]] + outliers_hi = outliers[outliers > boxplot_stats["med"]] + + if outliers_policy in (StatsOutliersPolicy.HI, StatsOutliersPolicy.BOTH): + boxplot_stats = { + **boxplot_stats, + **{f"outhi_{idx:06}": value for idx, value in enumerate(outliers_hi)}, + } + + if outliers_policy in (StatsOutliersPolicy.LO, StatsOutliersPolicy.BOTH): + boxplot_stats = { + **boxplot_stats, + **{f"outlo_{idx:06}": value for idx, value in enumerate(outliers_lo)}, + } + + # state variables for the stateful function `append_record` below + images_idxs_selected: set[int] = set() + records: list[dict[str, str | int | float]] = [] + + def append_record(stat_name: str, stat_value: float) -> None: + candidates_sorted = np.abs(per_image_scores - stat_value).argsort() + candidate_idx = candidates_sorted[0] + image_idx = candidate2image_idx[candidate_idx] + + # handle repeated values + if image_idx not in images_idxs_selected or repeated_policy == StatsRepeatedPolicy.NONE: + pass + + elif repeated_policy == StatsRepeatedPolicy.AVOID: + for other_candidate_idx in candidates_sorted: + other_candidate_image_idx = candidate2image_idx[other_candidate_idx] + if other_candidate_image_idx in images_idxs_selected: + continue + # if the code reaches here, it means that `other_candidate_image_idx` is not in `images_idxs_selected` + # i.e. 
this image has not been selected yet, so it can be used
+                other_candidate_score = per_image_scores[other_candidate_idx]
+                # if the other candidate is not too far from the value, use it
+                # note that the first choice has not changed, so if no other is selected in the loop
+                # it will be the first choice
+                if np.isclose(other_candidate_score, stat_value, atol=repeated_replacement_atol):
+                    candidate_idx = other_candidate_idx
+                    image_idx = other_candidate_image_idx
+                    break
+
+        images_idxs_selected.add(image_idx)
+        records.append(
+            {
+                "stat_name": stat_name,
+                "stat_value": float(stat_value),
+                "image_idx": int(image_idx),
+                "score": float(per_image_scores[candidate_idx]),
+            },
+        )
+
+    # loop over the stats from the lowest to the highest value
+    for stat, val in sorted(boxplot_stats.items(), key=lambda x: x[1]):
+        append_record(stat, val)
+    return sorted(records, key=lambda r: r["score"])
+
+
+def compare_models_pairwise_ttest_rel(
+    scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray],
+    alternative: str,
+    higher_is_better: bool,
+) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]:
+    """Compare all pairs of models using the paired t-test on two related samples (parametric).
+
+    This is a test for the null hypothesis that two repeated samples have identical average (expected) values.
+    In fact, it tests whether the average of the differences between the two samples is significantly different from 0.
+
+    Refs:
+        - `scipy.stats.ttest_rel`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_rel.html
+        - Wikipedia page: https://en.wikipedia.org/wiki/Student's_t-test#Dependent_t-test_for_paired_samples
+
+    ===
+
+    If an ordered dictionary is given, the models are sorted by the order of the dictionary.
+    Otherwise, the models are sorted by average SCORE.
+
+    Args:
+        scores_per_model: Dictionary of `n` models and their per-image scores.
+            key: model name
+            value: 1D array of shape (num_images,). All `nan` values must be at the same positions.
+        higher_is_better: Whether higher values of score are better or worse.
+        alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section.
+            Valid values are the values of `StatsAlternativeHypothesis`: 'two-sided', 'less', or 'greater'.
+
+    Returns:
+        (models_ordered, confidences):
+            - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input).
+
+                Automatic sorting is by average score from best to worst model.
+                Depending on `higher_is_better`, this corresponds to:
+                    - `higher_is_better=True` ==> descending score order
+                    - `higher_is_better=False` ==> ascending score order
+                along the indices from 0 to `n-1`.
+
+            - confidences: Dictionary of confidence values for each pair of models.
+
+                For all pairs of indices i and j from 0 to `n-1` such that i != j:
+                    - key: (models_ordered[i], models_ordered[j])
+                    - value: confidence on the alternative hypothesis.
+
+                For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is:
+                    - if `less`: model[i] < model[j]
+                    - if `greater`: model[i] > model[j]
+                    - if `two-sided`: model[i] != model[j]
+                in terms of average score.
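For instance, a minimal sketch of a call (two hypothetical models with made-up scores; the `nan` values sit at the same positions for both models, as required above; the import path only assumes this module's location shown in the diff header):

    import numpy as np

    from anomalib.metrics.per_image.utils_numpy import compare_models_pairwise_ttest_rel

    scores_per_model = {
        "model_a": np.array([np.nan, 0.90, 0.85, 0.95, 0.80]),
        "model_b": np.array([np.nan, 0.80, 0.70, 0.88, 0.75]),
    }

    models_ordered, confidences = compare_models_pairwise_ttest_rel(
        scores_per_model,
        alternative="greater",  # H1: model[i] > model[j] on average
        higher_is_better=True,
    )
    # plain `dict` input -> models are ordered automatically by average score (best first),
    # here expected to be ("model_a", "model_b");
    # confidences[("model_a", "model_b")] is the confidence that model_a scores higher on average.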
+ """ + _validate_is_scores_per_model(scores_per_model) + StatsAlternativeHypothesis(alternative) + + # remove nan values; list of items keeps the order of the OrderedDict + scores_per_model_nonan_items = [ + (model_name, scores[~np.isnan(scores)]) for model_name, scores in scores_per_model.items() + ] + + # sort models by average value if not an ordered dictionary + # position 0 is assumed the best model + if isinstance(scores_per_model, OrderedDict): + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) + else: + scores_per_model_nonan = OrderedDict( + sorted(scores_per_model_nonan_items, key=lambda kv: kv[1].mean(), reverse=higher_is_better), + ) + + models_ordered = tuple(scores_per_model_nonan.keys()) + models_pairs = list(itertools.permutations(models_ordered, 2)) + confidences: dict[tuple[str, str], float] = {} + for model_i, model_j in models_pairs: + values_i = scores_per_model_nonan[model_i] + values_j = scores_per_model_nonan[model_j] + pvalue = scipy.stats.ttest_rel( + values_i, + values_j, + alternative=alternative, + ).pvalue + confidences[(model_i, model_j)] = 1.0 - float(pvalue) + + return models_ordered, confidences + + +def compare_models_pairwise_wilcoxon( + scores_per_model: dict[str, ndarray] | OrderedDict[str, ndarray], + alternative: str, + higher_is_better: bool, + atol: float | None = 1e-3, +) -> tuple[tuple[str, ...], dict[tuple[str, str], float]]: + """Compare all pairs of models using the Wilcoxon signed-rank test (non-parametric). + + Each comparison of two models is a Wilcoxon signed-rank test (null hypothesis is that they are equal). + + It tests whether the distribution of the differences of scores is symmetric about zero in a non-parametric way. + This is like the non-parametric version of the paired t-test. + + Refs: + - `scipy.stats.wilcoxon`: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html#scipy.stats.wilcoxon + - Wikipedia page: https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test + + === + + If an ordered dictionary is given, the models are sorted by the order of the dictionary. + Otherwise, the models are sorted by average RANK. + + Args: + scores_per_model: Dictionary of `n` models and their per-image scores. + key: model name + value: tensor of shape (num_images,). All `nan` values must be at the same positions. + higher_is_better: Whether higher values of score are better or worse. Defaults to True. + alternative: Alternative hypothesis for the statistical tests. See `confidences` in "Returns" section. + Valid values are `StatsAlternativeHypothesis.ALTERNATIVES`. + atol: Absolute tolerance used to consider two scores as equal. Defaults to 1e-3 (0.1%). + When doing a paired test, if the difference between two scores is below `atol`, the difference is + truncated to 0. If `atol` is None, no truncation is done. + + Returns: + (models_ordered, test_results): + - models_ordered: Models sorted by the user (`OrderedDict` input) or automatically (`dict` input). + + Automatic sorting is from "best to worst" model, which corresponds to ascending average rank + along the indices from 0 to `n-1`. + + - confidences: Dictionary of confidence values for each pair of models. + + For all pairs of indices i and j from 0 to `n-1` such that i != j: + - key: (models_ordered[i], models_ordered[j]) + - value: confidence on the alternative hypothesis. 
+ + For models `models_ordered[i]` and `models_ordered[j]`, the alternative hypothesis is: + - if `less`: model[i] < model[j] + - if `greater`: model[i] > model[j] + - if `two-sided`: model[i] != model[j] + in terms of average ranks (not scores!). + """ + _validate_is_scores_per_model(scores_per_model) + StatsAlternativeHypothesis(alternative) + + # remove nan values; list of items keeps the order of the OrderedDict + scores_per_model_nonan_items = [ + (model_name, scores[~np.isnan(scores)]) for model_name, scores in scores_per_model.items() + ] + + # sort models by average value if not an ordered dictionary + # position 0 is assumed the best model + if isinstance(scores_per_model, OrderedDict): + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items) + else: + # these average ranks will NOT consider `atol` because we want to rank the models anyway + scores_nonan = np.stack([v for _, v in scores_per_model_nonan_items], axis=0) + avg_ranks = scipy.stats.rankdata( + -scores_nonan if higher_is_better else scores_nonan, + method="average", + axis=0, + ).mean(axis=1) + # ascending order, lower score is better --> best to worst model + argsort_avg_ranks = avg_ranks.argsort() + scores_per_model_nonan = OrderedDict(scores_per_model_nonan_items[idx] for idx in argsort_avg_ranks) + + models_ordered = tuple(scores_per_model_nonan.keys()) + models_pairs = list(itertools.permutations(models_ordered, 2)) + confidences: dict[tuple[str, str], float] = {} + for model_i, model_j in models_pairs: + values_i = scores_per_model_nonan[model_i] + values_j = scores_per_model_nonan[model_j] + diff = values_i - values_j + + if atol is not None: + # make the difference null if below the tolerance + diff[np.abs(diff) <= atol] = 0.0 + + # extreme case + if (diff == 0).all(): # noqa: SIM108 + pvalue = 1.0 + else: + pvalue = scipy.stats.wilcoxon(diff, alternative=alternative).pvalue + confidences[(model_i, model_j)] = 1.0 - float(pvalue) + + return models_ordered, confidences diff --git a/tests/unit/data/utils/test_path.py b/tests/unit/data/utils/test_path.py index c3f134b021..5a4b8fee45 100644 --- a/tests/unit/data/utils/test_path.py +++ b/tests/unit/data/utils/test_path.py @@ -76,3 +76,8 @@ def test_no_read_execute_permission() -> None: Path(tmp_dir).chmod(0o222) # Remove read and execute permission with pytest.raises(PermissionError, match=r"Read or execute permissions denied for the path:*"): validate_path(tmp_dir, base_dir=Path(tmp_dir)) + + def test_file_wrongsuffix(self) -> None: + """Test ``validate_path`` raises ValueError for a file with wrong suffix.""" + with pytest.raises(ValueError, match="Path extension is not accepted."): + validate_path("file.png", should_exist=False, accepted_extensions=(".json", ".txt")) diff --git a/tests/unit/metrics/per_image/__init__.py b/tests/unit/metrics/per_image/__init__.py new file mode 100644 index 0000000000..555d67a102 --- /dev/null +++ b/tests/unit/metrics/per_image/__init__.py @@ -0,0 +1,8 @@ +"""Per-Image Metrics Tests.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/metrics/per_image/test_binclf_curve.py b/tests/unit/metrics/per_image/test_binclf_curve.py new file mode 100644 index 0000000000..6b0499bf9a --- /dev/null +++ b/tests/unit/metrics/per_image/test_binclf_curve.py @@ -0,0 +1,531 @@ +"""Tests for per-image binary classification curves using numpy and numba versions.""" + +# Original Code +# 
https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa: SLF001, PT011 + +import numpy as np +import pytest +import torch +from numpy import ndarray +from torch import Tensor + +from anomalib.metrics.per_image import binclf_curve, binclf_curve_numpy +from anomalib.metrics.per_image.binclf_curve_numpy import HAS_NUMBA + +if HAS_NUMBA: + from anomalib.metrics.per_image import _binclf_curve_numba + + +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + """Generate test cases.""" + pred = np.arange(1, 5, dtype=np.float32) + threshs = np.arange(1, 5, dtype=np.float32) + + gt_norm = np.zeros(4).astype(bool) + gt_anom = np.concatenate([np.zeros(2), np.ones(2)]).astype(bool) + + # in the case where thresholds are all unique values in the predictions + expected_norm = np.stack( + [ + np.array([[0, 4], [0, 0]]), + np.array([[1, 3], [0, 0]]), + np.array([[2, 2], [0, 0]]), + np.array([[3, 1], [0, 0]]), + ], + axis=0, + ).astype(int) + expected_anom = np.stack( + [ + np.array([[0, 2], [0, 2]]), + np.array([[1, 1], [0, 2]]), + np.array([[2, 0], [0, 2]]), + np.array([[2, 0], [1, 1]]), + ], + axis=0, + ).astype(int) + + expected_tprs_norm = np.array([np.nan, np.nan, np.nan, np.nan]) + expected_tprs_anom = np.array([1.0, 1.0, 1.0, 0.5]) + expected_tprs = np.stack([expected_tprs_anom, expected_tprs_norm], axis=0).astype(np.float64) + + expected_fprs_norm = np.array([1.0, 0.75, 0.5, 0.25]) + expected_fprs_anom = np.array([1.0, 0.5, 0.0, 0.0]) + expected_fprs = np.stack([expected_fprs_anom, expected_fprs_norm], axis=0).astype(np.float64) + + # in the case where all thresholds are higher than the highest prediction + expected_norm_threshs_too_high = np.stack( + [ + np.array([[4, 0], [0, 0]]), + np.array([[4, 0], [0, 0]]), + np.array([[4, 0], [0, 0]]), + np.array([[4, 0], [0, 0]]), + ], + axis=0, + ).astype(int) + expected_anom_threshs_too_high = np.stack( + [ + np.array([[2, 0], [2, 0]]), + np.array([[2, 0], [2, 0]]), + np.array([[2, 0], [2, 0]]), + np.array([[2, 0], [2, 0]]), + ], + axis=0, + ).astype(int) + + # in the case where all thresholds are lower than the lowest prediction + expected_norm_threshs_too_low = np.stack( + [ + np.array([[0, 4], [0, 0]]), + np.array([[0, 4], [0, 0]]), + np.array([[0, 4], [0, 0]]), + np.array([[0, 4], [0, 0]]), + ], + axis=0, + ).astype(int) + expected_anom_threshs_too_low = np.stack( + [ + np.array([[0, 2], [0, 2]]), + np.array([[0, 2], [0, 2]]), + np.array([[0, 2], [0, 2]]), + np.array([[0, 2], [0, 2]]), + ], + axis=0, + ).astype(int) + + if metafunc.function is test__binclf_one_curve_python or metafunc.function is test__binclf_one_curve_numba: + metafunc.parametrize( + argnames=("pred", "gt", "threshs", "expected"), + argvalues=[ + (pred, gt_anom, threshs[:3], expected_anom[:3]), + (pred, gt_anom, threshs, expected_anom), + (pred, gt_norm, threshs, expected_norm), + (pred, gt_norm, 10 * threshs, expected_norm_threshs_too_high), + (pred, gt_anom, 10 * threshs, expected_anom_threshs_too_high), + (pred, gt_norm, 0.001 * threshs, expected_norm_threshs_too_low), + (pred, gt_anom, 0.001 * threshs, expected_anom_threshs_too_low), + ], + ) + + preds = np.stack([pred, pred], axis=0) + gts = np.stack([gt_anom, gt_norm], axis=0) + binclf_curves = np.stack([expected_anom, expected_norm], axis=0) + binclf_curves_threshs_too_high = np.stack([expected_anom_threshs_too_high, expected_norm_threshs_too_high], axis=0) + binclf_curves_threshs_too_low = 
np.stack([expected_anom_threshs_too_low, expected_norm_threshs_too_low], axis=0) + + if ( + metafunc.function is test__binclf_multiple_curves_python + or metafunc.function is test__binclf_multiple_curves_numba + ): + metafunc.parametrize( + argnames=("preds", "gts", "threshs", "expecteds"), + argvalues=[ + (preds, gts, threshs[:3], binclf_curves[:, :3]), + (preds, gts, threshs, binclf_curves), + ], + ) + + if metafunc.function is test_binclf_multiple_curves: + metafunc.parametrize( + argnames=( + "preds", + "gts", + "threshs", + "expected_binclf_curves", + ), + argvalues=[ + (preds[:1], gts[:1], threshs, binclf_curves[:1]), + (preds, gts, threshs, binclf_curves), + (10 * preds, gts, 10 * threshs, binclf_curves), + ], + ) + metafunc.parametrize( + argnames=("algorithm",), + argvalues=[ + ("python",), + ("numba",), + ], + ) + + if metafunc.function is test_binclf_multiple_curves_validations: + metafunc.parametrize( + argnames=("args", "kwargs", "exception"), + argvalues=[ + # `scores` and `gts` must be 2D + ([preds.reshape(2, 2, 2), gts, threshs], {"algorithm": "numba"}, ValueError), + ([preds, gts.flatten(), threshs], {"algorithm": "numba"}, ValueError), + # `threshs` must be 1D + ([preds, gts, threshs.reshape(2, 2)], {"algorithm": "numba"}, ValueError), + # `scores` and `gts` must have the same shape + ([preds, gts[:1], threshs], {"algorithm": "numba"}, ValueError), + ([preds[:, :2], gts, threshs], {"algorithm": "numba"}, ValueError), + # `scores` be of type float + ([preds.astype(int), gts, threshs], {"algorithm": "numba"}, TypeError), + # `gts` be of type bool + ([preds, gts.astype(int), threshs], {"algorithm": "numba"}, TypeError), + # `threshs` be of type float + ([preds, gts, threshs.astype(int)], {"algorithm": "numba"}, TypeError), + # `threshs` must be sorted in ascending order + ([preds, gts, np.flip(threshs)], {"algorithm": "numba"}, ValueError), + ([preds, gts, np.concatenate([threshs[-2:], threshs[:2]])], {"algorithm": "numba"}, ValueError), + # `threshs` must be unique + ([preds, gts, np.sort(np.concatenate([threshs, threshs]))], {"algorithm": "numba"}, ValueError), + # invalid `algorithm` + ([preds, gts, threshs], {"algorithm": "blurp"}, ValueError), + ], + ) + + # the following tests are for `per_image_binclf_curve()`, which expects + # inputs in image spatial format, i.e. 
(height, width) + preds = preds.reshape(2, 2, 2) + gts = gts.reshape(2, 2, 2) + + per_image_binclf_curves_numpy_argvalues = [ + # `threshs_choice` = "given" + ( + preds, + gts, + "given", + threshs, + None, + threshs, + binclf_curves, + ), + ( + preds, + gts, + "given", + 10 * threshs, + 2, + 10 * threshs, + binclf_curves_threshs_too_high, + ), + ( + preds, + gts, + "given", + 0.01 * threshs, + None, + 0.01 * threshs, + binclf_curves_threshs_too_low, + ), + # `threshs_choice` = 'minmax-linspace'" + ( + preds, + gts, + "minmax-linspace", + None, + len(threshs), + threshs, + binclf_curves, + ), + ( + 2 * preds, + gts.astype(int), # this is ok + "minmax-linspace", + None, + len(threshs), + 2 * threshs, + binclf_curves, + ), + ] + + if metafunc.function is test_per_image_binclf_curve_numpy: + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + "threshs_choice", + "threshs_given", + "num_threshs", + "expected_threshs", + "expected_binclf_curves", + ), + argvalues=per_image_binclf_curves_numpy_argvalues, + ) + + # the test with the torch interface are the same we just convert ndarray to Tensor + if metafunc.function is test_per_image_binclf_curve_torch: + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + "threshs_choice", + "threshs_given", + "num_threshs", + "expected_threshs", + "expected_binclf_curves", + ), + argvalues=[ + tuple(torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in argvals) + for argvals in per_image_binclf_curves_numpy_argvalues + ], + ) + + if metafunc.function is test_per_image_binclf_curve_numpy or metafunc.function is test_per_image_binclf_curve_torch: + metafunc.parametrize( + argnames=("algorithm",), + argvalues=[ + ("python",), + ("numba",), + ], + ) + + if metafunc.function is test_per_image_binclf_curve_numpy_validations: + metafunc.parametrize( + argnames=("args", "exception"), + argvalues=[ + # `scores` and `gts` must be 3D + ([preds.reshape(2, 2, 2, 1), gts], ValueError), + ([preds, gts.flatten()], ValueError), + # `scores` and `gts` must have the same shape + ([preds, gts[:1]], ValueError), + ([preds[:, :1], gts], ValueError), + # `scores` be of type float + ([preds.astype(int), gts], TypeError), + # `gts` be of type bool or int + ([preds, gts.astype(float)], TypeError), + # `threshs` be of type float + ([preds, gts, threshs.astype(int)], TypeError), + ], + ) + metafunc.parametrize( + argnames=("kwargs",), + argvalues=[ + ({"algorithm": "numba", "threshs_choice": "given", "threshs_given": threshs, "num_threshs": None},), + ( + { + "algorithm": "python", + "threshs_choice": "minmax-linspace", + "threshs_given": None, + "num_threshs": len(threshs), + }, + ), + ], + ) + + # same as above but testing other validations + if metafunc.function is test_per_image_binclf_curve_numpy_validations_alt: + metafunc.parametrize( + argnames=("args", "kwargs", "exception"), + argvalues=[ + # invalid `threshs_choice` + ( + [preds, gts], + {"algorithm": "glfrb", "threshs_choice": "given", "threshs_given": threshs, "num_threshs": None}, + ValueError, + ), + ], + ) + + if metafunc.function is test_rate_metrics_numpy: + metafunc.parametrize( + argnames=("binclf_curves", "expected_fprs", "expected_tprs"), + argvalues=[ + (binclf_curves, expected_fprs, expected_tprs), + (10 * binclf_curves, expected_fprs, expected_tprs), + ], + ) + + if metafunc.function is test_rate_metrics_torch: + metafunc.parametrize( + argnames=("binclf_curves", "expected_fprs", "expected_tprs"), + argvalues=[ + (torch.from_numpy(binclf_curves), 
torch.from_numpy(expected_fprs), torch.from_numpy(expected_tprs)), + ], + ) + + +# ================================================================================================== +# LOW-LEVEL FUNCTIONS (PYTHON) + + +def test__binclf_one_curve_python(pred: ndarray, gt: ndarray, threshs: ndarray, expected: ndarray) -> None: + """Test if `_binclf_one_curve_python()` returns the expected values.""" + computed = binclf_curve_numpy._binclf_one_curve_python(pred, gt, threshs) + assert computed.shape == (threshs.size, 2, 2) + assert (computed == expected).all() + + +def test__binclf_multiple_curves_python( + preds: ndarray, + gts: ndarray, + threshs: ndarray, + expecteds: ndarray, +) -> None: + """Test if `_binclf_multiple_curves_python()` returns the expected values.""" + computed = binclf_curve_numpy._binclf_multiple_curves_python(preds, gts, threshs) + assert computed.shape == (preds.shape[0], threshs.size, 2, 2) + assert (computed == expecteds).all() + + +# ================================================================================================== +# LOW-LEVEL FUNCTIONS (NUMBA) + + +def test__binclf_one_curve_numba(pred: ndarray, gt: ndarray, threshs: ndarray, expected: ndarray) -> None: + """Test if `_binclf_one_curve_numba()` returns the expected values.""" + if not HAS_NUMBA: + pytest.skip("Numba is not available.") + computed = _binclf_curve_numba.binclf_one_curve_numba(pred, gt, threshs) + assert computed.shape == (threshs.size, 2, 2) + assert (computed == expected).all() + + +def test__binclf_multiple_curves_numba(preds: ndarray, gts: ndarray, threshs: ndarray, expecteds: ndarray) -> None: + """Test if `_binclf_multiple_curves_python()` returns the expected values.""" + if not HAS_NUMBA: + pytest.skip("Numba is not available.") + computed = _binclf_curve_numba.binclf_multiple_curves_numba(preds, gts, threshs) + assert computed.shape == (preds.shape[0], threshs.size, 2, 2) + assert (computed == expecteds).all() + + +# ================================================================================================== +# API FUNCTIONS (NUMPY) + + +def test_binclf_multiple_curves( + preds: ndarray, + gts: ndarray, + threshs: ndarray, + expected_binclf_curves: ndarray, + algorithm: str, +) -> None: + """Test if `binclf_multiple_curves()` returns the expected values.""" + computed = binclf_curve_numpy.binclf_multiple_curves( + preds, + gts, + threshs, + algorithm=algorithm, + ) + assert computed.shape == expected_binclf_curves.shape + assert (computed == expected_binclf_curves).all() + + # it's ok to have the threhsholds beyond the range of the preds + binclf_curve_numpy.binclf_multiple_curves(preds, gts, 2 * threshs, algorithm=algorithm) + + # or inside the bounds without reaching them + binclf_curve_numpy.binclf_multiple_curves(preds, gts, 0.5 * threshs, algorithm=algorithm) + + # it's also ok to have more threshs than unique values in the preds + # add the values in between the threshs + threshs_unncessary = 0.5 * (threshs[:-1] + threshs[1:]) + threshs_unncessary = np.concatenate([threshs_unncessary, threshs]) + threshs_unncessary = np.sort(threshs_unncessary) + binclf_curve_numpy.binclf_multiple_curves(preds, gts, threshs_unncessary, algorithm=algorithm) + + # or less + binclf_curve_numpy.binclf_multiple_curves(preds, gts, threshs[1:3], algorithm=algorithm) + + +def test_binclf_multiple_curves_validations(args: list, kwargs: dict, exception: Exception) -> None: + """Test if `_binclf_multiple_curves_python()` raises the expected errors.""" + with pytest.raises(exception): + 
binclf_curve_numpy.binclf_multiple_curves(*args, **kwargs) + + +def test_per_image_binclf_curve_numpy( + anomaly_maps: ndarray, + masks: ndarray, + algorithm: str, + threshs_choice: str, + threshs_given: ndarray | None, + num_threshs: int | None, + expected_threshs: ndarray, + expected_binclf_curves: ndarray, +) -> None: + """Test if `per_image_binclf_curve()` returns the expected values.""" + computed_threshs, computed_binclf_curves = binclf_curve_numpy.per_image_binclf_curve( + anomaly_maps, + masks, + algorithm=algorithm, + threshs_choice=threshs_choice, + threshs_given=threshs_given, + num_threshs=num_threshs, + ) + + # threshs + assert computed_threshs.shape == expected_threshs.shape + assert computed_threshs.dtype == computed_threshs.dtype + assert (computed_threshs == expected_threshs).all() + + # binclf_curves + assert computed_binclf_curves.shape == expected_binclf_curves.shape + assert computed_binclf_curves.dtype == expected_binclf_curves.dtype + assert (computed_binclf_curves == expected_binclf_curves).all() + + +def test_per_image_binclf_curve_numpy_validations(args: list, kwargs: dict, exception: Exception) -> None: + """Test if `per_image_binclf_curve()` raises the expected errors.""" + with pytest.raises(exception): + binclf_curve_numpy.per_image_binclf_curve(*args, **kwargs) + + +def test_per_image_binclf_curve_numpy_validations_alt(args: list, kwargs: dict, exception: Exception) -> None: + """Test if `per_image_binclf_curve()` raises the expected errors.""" + test_per_image_binclf_curve_numpy_validations(args, kwargs, exception) + + +def test_rate_metrics_numpy(binclf_curves: ndarray, expected_fprs: ndarray, expected_tprs: ndarray) -> None: + """Test if rate metrics are computed correctly.""" + tprs = binclf_curve_numpy.per_image_tpr(binclf_curves) + fprs = binclf_curve_numpy.per_image_fpr(binclf_curves) + + assert tprs.shape == expected_tprs.shape + assert fprs.shape == expected_fprs.shape + + assert np.allclose(tprs, expected_tprs, equal_nan=True) + assert np.allclose(fprs, expected_fprs, equal_nan=True) + + +# ================================================================================================== +# API FUNCTIONS (TORCH) + + +def test_per_image_binclf_curve_torch( + anomaly_maps: Tensor, + masks: Tensor, + algorithm: str, + threshs_choice: str, + threshs_given: Tensor | None, + num_threshs: int | None, + expected_threshs: Tensor, + expected_binclf_curves: Tensor, +) -> None: + """Test if `per_image_binclf_curve()` returns the expected values.""" + computed_threshs, computed_binclf_curves = binclf_curve.per_image_binclf_curve( + anomaly_maps, + masks, + algorithm=algorithm, + threshs_choice=threshs_choice, + threshs_given=threshs_given, + num_threshs=num_threshs, + ) + + # threshs + assert computed_threshs.shape == expected_threshs.shape + assert computed_threshs.dtype == computed_threshs.dtype + assert (computed_threshs == expected_threshs).all() + + # binclf_curves + assert computed_binclf_curves.shape == expected_binclf_curves.shape + assert computed_binclf_curves.dtype == expected_binclf_curves.dtype + assert (computed_binclf_curves == expected_binclf_curves).all() + + +def test_rate_metrics_torch(binclf_curves: Tensor, expected_fprs: Tensor, expected_tprs: Tensor) -> None: + """Test if rate metrics are computed correctly.""" + tprs = binclf_curve.per_image_tpr(binclf_curves) + fprs = binclf_curve.per_image_fpr(binclf_curves) + + assert tprs.shape == expected_tprs.shape + assert fprs.shape == expected_fprs.shape + + assert torch.allclose(tprs, 
expected_tprs, equal_nan=True) + assert torch.allclose(fprs, expected_fprs, equal_nan=True) diff --git a/tests/unit/metrics/per_image/test_pimo.py b/tests/unit/metrics/per_image/test_pimo.py new file mode 100644 index 0000000000..ce30a13542 --- /dev/null +++ b/tests/unit/metrics/per_image/test_pimo.py @@ -0,0 +1,604 @@ +"""Test `anomalib.metrics.per_image.pimo_numpy`.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch +from numpy import ndarray +from torch import Tensor + +from anomalib.metrics.per_image import pimo, pimo_numpy +from anomalib.metrics.per_image.pimo import AUPIMOResult, PIMOResult + +from .test_utils import assert_statsdict_stuff + + +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + """Generate tests for all functions in this module. + + All functions are parametrized with the same setting: 1 normal and 2 anomalous images. + The anomaly maps are the same for all functions, but the masks are different. + """ + expected_threshs = np.arange(1, 7 + 1, dtype=np.float32) + shape = (1000, 1000) # (H, W), 1 million pixels + + # --- normal --- + # histogram of scores: + # value: 7 6 5 4 3 2 1 + # count: 1 9 90 900 9k 90k 900k + # cumsum: 1 10 100 1k 10k 100k 1M + pred_norm = np.ones(1_000_000, dtype=np.float32) + pred_norm[:100_000] += 1 + pred_norm[:10_000] += 1 + pred_norm[:1_000] += 1 + pred_norm[:100] += 1 + pred_norm[:10] += 1 + pred_norm[:1] += 1 + pred_norm = pred_norm.reshape(shape) + mask_norm = np.zeros_like(pred_norm, dtype=np.int32) + + expected_fpr_norm = np.array([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], dtype=np.float64) + expected_tpr_norm = np.full((7,), np.nan, dtype=np.float64) + + # --- anomalous --- + pred_anom1 = pred_norm.copy() + mask_anom1 = np.ones_like(pred_anom1, dtype=np.int32) + expected_tpr_anom1 = expected_fpr_norm.copy() + + # only the first 100_000 pixels are anomalous + # which corresponds to the first 100_000 highest scores (2 to 7) + pred_anom2 = pred_norm.copy() + mask_anom2 = np.concatenate([np.ones(100_000), np.zeros(900_000)]).reshape(shape).astype(np.int32) + expected_tpr_anom2 = (10 * expected_fpr_norm).clip(0, 1) + + anomaly_maps = np.stack([pred_norm, pred_anom1, pred_anom2], axis=0) + masks = np.stack([mask_norm, mask_anom1, mask_anom2], axis=0) + + expected_shared_fpr = expected_fpr_norm + expected_per_image_tprs = np.stack([expected_tpr_norm, expected_tpr_anom1, expected_tpr_anom2], axis=0) + expected_image_classes = np.array([0, 1, 1], dtype=np.int32) + + if ( + metafunc.function is test_pimo_numpy + or metafunc.function is test_pimo + or metafunc.function is test_aupimo_values_numpy + or metafunc.function is test_aupimo_values + ): + argvalues_arrays = [ + ( + anomaly_maps, + masks, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ), + ( + 10 * anomaly_maps, + masks, + 10 * expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ), + ] + argvalues_tensors = [ + tuple(torch.from_numpy(arg) if isinstance(arg, ndarray) else arg for arg in arvals) + for arvals in argvalues_arrays + ] + argvalues = argvalues_arrays if "numpy" in metafunc.function.__name__ else argvalues_tensors + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + "expected_threshs", + "expected_shared_fpr", + "expected_per_image_tprs", + 
"expected_image_classes", + ), + argvalues=argvalues, + ) + + if metafunc.function is test_aupimo_values_numpy or metafunc.function is test_aupimo_values: + argvalues_arrays = [ + ( + (1e-1, 1.0), + np.array( + [ + np.nan, + # recall: trapezium area = (a + b) * h / 2 + (0.10 + 1.0) * 1 / 2, + (1.0 + 1.0) * 1 / 2, + ], + dtype=np.float64, + ), + ), + ( + (1e-3, 1e-1), + np.array( + [ + np.nan, + # average of two trapezium areas / 2 (normalizing factor) + (((1e-3 + 1e-2) * 1 / 2) + ((1e-2 + 1e-1) * 1 / 2)) / 2, + (((1e-2 + 1e-1) * 1 / 2) + ((1e-1 + 1.0) * 1 / 2)) / 2, + ], + dtype=np.float64, + ), + ), + ( + (1e-5, 1e-4), + np.array( + [ + np.nan, + (1e-5 + 1e-4) * 1 / 2, + (1e-4 + 1e-3) * 1 / 2, + ], + dtype=np.float64, + ), + ), + ] + argvalues_tensors = [ + tuple(torch.from_numpy(arg) if isinstance(arg, ndarray) else arg for arg in arvals) + for arvals in argvalues_arrays + ] + argvalues = argvalues_arrays if "numpy" in metafunc.function.__name__ else argvalues_tensors + metafunc.parametrize( + argnames=( + "fpr_bounds", + "expected_aupimos", # trapezoid surfaces + ), + argvalues=argvalues, + ) + + if metafunc.function is test_aupimo_edge: + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + ), + argvalues=[ + ( + anomaly_maps, + masks, + ), + ( + 10 * anomaly_maps, + masks, + ), + ], + ) + metafunc.parametrize( + argnames=("fpr_bounds",), + argvalues=[ + ((1e-1, 1.0),), + ((1e-3, 1e-2),), + ((1e-5, 1e-4),), + (None,), + ], + ) + + if metafunc.function is test_pimoresult_object or metafunc.function is test_aupimoresult_object: + anomaly_maps = torch.from_numpy(anomaly_maps) + masks = torch.from_numpy(masks) + metafunc.parametrize(argnames=("anomaly_maps", "masks"), argvalues=[(anomaly_maps, masks)]) + metafunc.parametrize(argnames=("paths",), argvalues=[(None,), (["/path/to/a", "/path/to/b", "/path/to/c"],)]) + + +def _do_test_pimo_outputs( + threshs: ndarray | Tensor, + shared_fpr: ndarray | Tensor, + per_image_tprs: ndarray | Tensor, + image_classes: ndarray | Tensor, + expected_threshs: ndarray | Tensor, + expected_shared_fpr: ndarray | Tensor, + expected_per_image_tprs: ndarray | Tensor, + expected_image_classes: ndarray | Tensor, +) -> None: + """Test if the outputs of any of the PIMO interfaces are correct.""" + if isinstance(threshs, Tensor): + assert isinstance(shared_fpr, Tensor) + assert isinstance(per_image_tprs, Tensor) + assert isinstance(image_classes, Tensor) + assert isinstance(expected_threshs, Tensor) + assert isinstance(expected_shared_fpr, Tensor) + assert isinstance(expected_per_image_tprs, Tensor) + assert isinstance(expected_image_classes, Tensor) + allclose = torch.allclose + + elif isinstance(threshs, ndarray): + assert isinstance(shared_fpr, ndarray) + assert isinstance(per_image_tprs, ndarray) + assert isinstance(image_classes, ndarray) + assert isinstance(expected_threshs, ndarray) + assert isinstance(expected_shared_fpr, ndarray) + assert isinstance(expected_per_image_tprs, ndarray) + assert isinstance(expected_image_classes, ndarray) + allclose = np.allclose + + else: + msg = "Expected `threshs` to be a Tensor or ndarray." 
+ raise TypeError(msg) + + assert threshs.ndim == 1 + assert shared_fpr.ndim == 1 + assert per_image_tprs.ndim == 2 + assert tuple(image_classes.shape) == (3,) + + assert allclose(threshs, expected_threshs) + assert allclose(shared_fpr, expected_shared_fpr) + assert allclose(per_image_tprs, expected_per_image_tprs, equal_nan=True) + assert (image_classes == expected_image_classes).all() + + +def test_pimo_numpy( + anomaly_maps: ndarray, + masks: ndarray, + expected_threshs: ndarray, + expected_shared_fpr: ndarray, + expected_per_image_tprs: ndarray, + expected_image_classes: ndarray, +) -> None: + """Test if `pimo()` returns the expected values.""" + threshs, shared_fpr, per_image_tprs, image_classes = pimo_numpy.pimo_curves( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + ) + _do_test_pimo_outputs( + threshs, + shared_fpr, + per_image_tprs, + image_classes, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ) + + +def test_pimo( + anomaly_maps: Tensor, + masks: Tensor, + expected_threshs: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, +) -> None: + """Test if `pimo()` returns the expected values.""" + + def do_assertions(pimoresult: PIMOResult) -> None: + threshs = pimoresult.threshs + shared_fpr = pimoresult.shared_fpr + per_image_tprs = pimoresult.per_image_tprs + image_classes = pimoresult.image_classes + _do_test_pimo_outputs( + threshs, + shared_fpr, + per_image_tprs, + image_classes, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ) + + # functional interface + pimoresult = pimo.pimo_curves( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + ) + do_assertions(pimoresult) + + # metric interface + metric = pimo.PIMO( + num_threshs=7, + binclf_algorithm="numba", + ) + metric.update(anomaly_maps, masks) + pimoresult = metric.compute() + do_assertions(pimoresult) + + +def _do_test_aupimo_outputs( + threshs: ndarray | Tensor, + shared_fpr: ndarray | Tensor, + per_image_tprs: ndarray | Tensor, + image_classes: ndarray | Tensor, + aupimos: ndarray | Tensor, + expected_threshs: ndarray | Tensor, + expected_shared_fpr: ndarray | Tensor, + expected_per_image_tprs: ndarray | Tensor, + expected_image_classes: ndarray | Tensor, + expected_aupimos: ndarray | Tensor, +) -> None: + _do_test_pimo_outputs( + threshs, + shared_fpr, + per_image_tprs, + image_classes, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ) + if isinstance(threshs, Tensor): + assert isinstance(aupimos, Tensor) + assert isinstance(expected_aupimos, Tensor) + allclose = torch.allclose + + elif isinstance(threshs, ndarray): + assert isinstance(aupimos, ndarray) + assert isinstance(expected_aupimos, ndarray) + allclose = np.allclose + assert tuple(aupimos.shape) == (3,) + assert allclose(aupimos, expected_aupimos, equal_nan=True) + + +def test_aupimo_values_numpy( + anomaly_maps: ndarray, + masks: ndarray, + fpr_bounds: tuple[float, float], + expected_threshs: ndarray, + expected_shared_fpr: ndarray, + expected_per_image_tprs: ndarray, + expected_image_classes: ndarray, + expected_aupimos: ndarray, +) -> None: + """Test if `aupimo()` returns the expected values.""" + threshs, shared_fpr, per_image_tprs, image_classes, aupimos, _ = pimo_numpy.aupimo_scores( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + fpr_bounds=fpr_bounds, + force=True, + ) + 
_do_test_aupimo_outputs( + threshs, + shared_fpr, + per_image_tprs, + image_classes, + aupimos, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + expected_aupimos, + ) + + +def test_aupimo_values( + anomaly_maps: ndarray, + masks: ndarray, + fpr_bounds: tuple[float, float], + expected_threshs: ndarray, + expected_shared_fpr: ndarray, + expected_per_image_tprs: ndarray, + expected_image_classes: ndarray, + expected_aupimos: ndarray, +) -> None: + """Test if `aupimo()` returns the expected values.""" + + def do_assertions(pimoresult: PIMOResult, aupimoresult: AUPIMOResult) -> None: + # test metadata + assert aupimoresult.fpr_bounds == fpr_bounds + # recall: this one is not the same as the number of thresholds in the curve + # this is the number of thresholds used to compute the integral in `aupimo()` + # always less because of the integration bounds + assert aupimoresult.num_threshs < 7 + + # test data + # from pimo result + threshs = pimoresult.threshs + shared_fpr = pimoresult.shared_fpr + per_image_tprs = pimoresult.per_image_tprs + image_classes = pimoresult.image_classes + # from aupimo result + aupimos = aupimoresult.aupimos + _do_test_aupimo_outputs( + threshs, + shared_fpr, + per_image_tprs, + image_classes, + aupimos, + expected_threshs, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + expected_aupimos, + ) + thresh_lower_bound = aupimoresult.thresh_lower_bound + thresh_upper_bound = aupimoresult.thresh_upper_bound + assert anomaly_maps.min() <= thresh_lower_bound < thresh_upper_bound <= anomaly_maps.max() + + # functional interface + pimoresult_from_functional, aupimoresult_from_functional = pimo.aupimo_scores( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + fpr_bounds=fpr_bounds, + force=True, + ) + do_assertions(pimoresult_from_functional, aupimoresult_from_functional) + + # metric interface + metric = pimo.AUPIMO( + num_threshs=7, + binclf_algorithm="numba", + fpr_bounds=fpr_bounds, + return_average=False, + force=True, + ) + metric.update(anomaly_maps, masks) + pimoresult_from_metric, aupimoresult_from_metric = metric.compute() + do_assertions(pimoresult_from_metric, aupimoresult_from_metric) + + # metric interface + metric = pimo.AUPIMO( + num_threshs=7, + binclf_algorithm="numba", + fpr_bounds=fpr_bounds, + return_average=True, # only return the average AUPIMO + force=True, + ) + metric.update(anomaly_maps, masks) + metric.compute() + + +def test_aupimo_edge( + anomaly_maps: ndarray, + masks: ndarray, + fpr_bounds: tuple[float, float], +) -> None: + """Test some edge cases.""" + # None is the case of testing the default bounds + fpr_bounds = {"fpr_bounds": fpr_bounds} if fpr_bounds is not None else {} + + # not enough points on the curve + # 10 threshs / 6 decades = 1.6 threshs per decade < 3 + with pytest.raises(RuntimeError): # force=False --> raise error + pimo_numpy.aupimo_scores( + anomaly_maps, + masks, + num_threshs=10, + binclf_algorithm="numba", + force=False, + **fpr_bounds, + ) + + with pytest.warns(RuntimeWarning): # force=True --> warn + pimo_numpy.aupimo_scores( + anomaly_maps, + masks, + num_threshs=10, + binclf_algorithm="numba", + force=True, + **fpr_bounds, + ) + + # default number of points on the curve (300k threshs) should be enough + rng = np.random.default_rng(42) + pimo_numpy.aupimo_scores( + anomaly_maps * rng.uniform(1.0, 1.1, size=anomaly_maps.shape), + masks, + # num_threshs=, + binclf_algorithm="numba", + force=False, + **fpr_bounds, + ) + + +def 
test_pimoresult_object( + anomaly_maps: Tensor, + masks: Tensor, + paths: list[str] | None, +) -> None: + """Test if `PIMOResult` can be converted to other formats and back.""" + optional_kwargs = {} + if paths is not None: + optional_kwargs["paths"] = paths + + pimoresult = pimo.pimo_curves( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + **optional_kwargs, + ) + + _ = pimoresult.num_threshs + _ = pimoresult.num_images + _ = pimoresult.image_classes + + # object -> dict -> object + dic = pimoresult.to_dict() + assert isinstance(dic, dict) + pimoresult_from_dict = PIMOResult.from_dict(dic) + assert isinstance(pimoresult_from_dict, PIMOResult) + # values should be the same + assert torch.allclose(pimoresult_from_dict.threshs, pimoresult.threshs) + assert torch.allclose(pimoresult_from_dict.shared_fpr, pimoresult.shared_fpr) + assert torch.allclose(pimoresult_from_dict.per_image_tprs, pimoresult.per_image_tprs, equal_nan=True) + + # object -> file -> object + with tempfile.TemporaryDirectory() as tmpdir: + file_path = Path(tmpdir) / "pimo.pt" + pimoresult.save(str(file_path)) + assert file_path.exists() + pimoresult_from_load = PIMOResult.load(str(file_path)) + assert isinstance(pimoresult_from_load, PIMOResult) + # values should be the same + assert torch.allclose(pimoresult_from_load.threshs, pimoresult.threshs) + assert torch.allclose(pimoresult_from_load.shared_fpr, pimoresult.shared_fpr) + assert torch.allclose(pimoresult_from_load.per_image_tprs, pimoresult.per_image_tprs, equal_nan=True) + + +def test_aupimoresult_object( + anomaly_maps: Tensor, + masks: Tensor, + paths: list[str] | None, +) -> None: + """Test if `AUPIMOResult` can be converted to other formats and back.""" + optional_kwargs = {} + if paths is not None: + optional_kwargs["paths"] = paths + + _, aupimoresult = pimo.aupimo_scores( + anomaly_maps, + masks, + num_threshs=7, + binclf_algorithm="numba", + fpr_bounds=(1e-5, 1e-4), + force=True, + **optional_kwargs, + ) + + # call properties + _ = aupimoresult.num_images + _ = aupimoresult.image_classes + _ = aupimoresult.fpr_bounds + _ = aupimoresult.thresh_bounds + + # object -> dict -> object + dic = aupimoresult.to_dict() + assert isinstance(dic, dict) + aupimoresult_from_dict = AUPIMOResult.from_dict(dic) + assert isinstance(aupimoresult_from_dict, AUPIMOResult) + # values should be the same + assert aupimoresult_from_dict.fpr_bounds == aupimoresult.fpr_bounds + assert aupimoresult_from_dict.num_threshs == aupimoresult.num_threshs + assert aupimoresult_from_dict.thresh_bounds == aupimoresult.thresh_bounds + assert torch.allclose(aupimoresult_from_dict.aupimos, aupimoresult.aupimos, equal_nan=True) + + # object -> file -> object + with tempfile.TemporaryDirectory() as tmpdir: + file_path = Path(tmpdir) / "aupimo.json" + aupimoresult.save(str(file_path)) + assert file_path.exists() + aupimoresult_from_load = AUPIMOResult.load(str(file_path)) + assert isinstance(aupimoresult_from_load, AUPIMOResult) + # values should be the same + assert aupimoresult_from_load.fpr_bounds == aupimoresult.fpr_bounds + assert aupimoresult_from_load.num_threshs == aupimoresult.num_threshs + assert aupimoresult_from_load.thresh_bounds == aupimoresult.thresh_bounds + assert torch.allclose(aupimoresult_from_load.aupimos, aupimoresult.aupimos, equal_nan=True) + + # statistics + stats = aupimoresult.stats() + assert len(stats) == 6 + + for statdic in stats: + assert_statsdict_stuff(statdic, 2) diff --git a/tests/unit/metrics/per_image/test_utils.py 
b/tests/unit/metrics/per_image/test_utils.py new file mode 100644 index 0000000000..0e712b6584 --- /dev/null +++ b/tests/unit/metrics/per_image/test_utils.py @@ -0,0 +1,308 @@ +"""Test `utils.py`.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict + +import numpy as np +import pytest +import torch +from torch import Tensor + +from anomalib.metrics.per_image import ( + AUPIMOResult, + StatsOutliersPolicy, + StatsRepeatedPolicy, + compare_models_pairwise_ttest_rel, + compare_models_pairwise_wilcoxon, + format_pairwise_tests_results, + per_image_scores_stats, +) + + +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + """Generate test cases.""" + num_images = 100 + # avg is 0.8 + aucs1 = 0.8 * torch.ones(num_images) + # avg ~ 0.7 + aucs2 = torch.linspace(0.6, 0.8, num_images) + # avg ~ 0.6 + aucs3 = torch.sin(torch.linspace(0, torch.pi, num_images)).clip(0, 1) + + mock_aupimoresult_stuff = { + "fpr_lower_bound": 1e-5, + "fpr_upper_bound": 1e-4, + "num_threshs": 1_000, + "thresh_lower_bound": 1.0, + "thresh_upper_bound": 2.0, + } + fake_paths = [f"/path/to/file_{i}" for i in range(num_images)] + scores_per_model_dicts = [ + ({"a": aucs1, "b": aucs2},), + ({"a": aucs1, "b": aucs2, "c": aucs3},), + (OrderedDict([("c", aucs1), ("b", aucs2), ("a", aucs3)]),), + ( + { + "a": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs1}), + "b": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs2}), + "c": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs3}), + }, + ), + ( + { + "a": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs1, "paths": fake_paths}), + "b": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs2, "paths": fake_paths}), + "c": AUPIMOResult(**{**mock_aupimoresult_stuff, "aupimos": aucs3, "paths": fake_paths}), + }, + ), + ] + + if ( + metafunc.function is test_compare_models_pairwise_ttest + or metafunc.function is test_compare_models_pairwise_wilcoxon + ): + metafunc.parametrize(("scores_per_model",), scores_per_model_dicts) + metafunc.parametrize( + ("alternative", "higher_is_better"), + [ + ("two-sided", True), + ("two-sided", False), + ("less", False), + ("greater", True), + # not considering the case (less, true) and (greater, false) because it will break + # some assumptions in the assertions but they are possible + ], + ) + + if metafunc.function is test_format_pairwise_tests_results: + metafunc.parametrize(("scores_per_model",), scores_per_model_dicts[:3]) + + +def assert_statsdict_stuff(statdic: dict, max_image_idx: int) -> None: + """Assert stuff about a `statdic`.""" + assert "stat_name" in statdic + stat_name = statdic["stat_name"] + assert stat_name in ("mean", "med", "q1", "q3", "whishi", "whislo") or stat_name.startswith( + ("outlo_", "outhi_"), + ) + assert "stat_value" in statdic + assert "image_idx" in statdic + image_idx = statdic["image_idx"] + assert 0 <= image_idx <= max_image_idx + + +def test_per_image_scores_stats() -> None: + """Test `per_image_scores_boxplot_stats`.""" + gen = torch.Generator().manual_seed(42) + num_scores = 201 + scores = torch.randn(num_scores, generator=gen) + + stats = per_image_scores_stats(scores) + assert len(stats) == 6 + for statdic in stats: + assert_statsdict_stuff(statdic, num_scores - 1) + + classes = (torch.arange(num_scores) % 3 == 0).to(torch.long) + stats = per_image_scores_stats(scores, classes, only_class=None) + 
assert len(stats) == 6 + stats = per_image_scores_stats(scores, classes, only_class=0) + assert len(stats) == 6 + stats = per_image_scores_stats(scores, classes, only_class=1) + assert len(stats) == 6 + + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.BOTH) + assert len(stats) == 6 + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.LO) + assert len(stats) == 6 + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.HI) + assert len(stats) == 6 + stats = per_image_scores_stats(scores, outliers_policy=StatsOutliersPolicy.NONE) + assert len(stats) == 6 + + # force repeated values + scores = torch.round(scores * 10) / 10 + stats = per_image_scores_stats(scores, repeated_policy=StatsRepeatedPolicy.AVOID) + assert len(stats) == 6 + stats = per_image_scores_stats( + scores, + classes, + repeated_policy=StatsRepeatedPolicy.AVOID, + repeated_replacement_atol=1e-1, + ) + assert len(stats) == 6 + stats = per_image_scores_stats(scores, repeated_policy=StatsRepeatedPolicy.NONE) + assert len(stats) == 6 + + +def test_per_image_scores_stats_specific_values() -> None: + """Test `per_image_scores_boxplot_stats` with specific values.""" + scores = torch.concatenate( + [ + # whislo = min value is 0.0 + torch.tensor([0.0]), + torch.zeros(98), + # q1 value is 0.0 + torch.tensor([0.0]), + torch.linspace(0.01, 0.29, 98), + # med value is 0.3 + torch.tensor([0.3]), + torch.linspace(0.31, 0.69, 99), + # q3 value is 0.7 + torch.tensor([0.7]), + torch.linspace(0.71, 0.99, 99), + # whishi = max value is 1.0 + torch.tensor([1.0]), + ], + ) + + stats = per_image_scores_stats(scores) + assert len(stats) == 6 + + statdict_whislo = stats[0] + statdict_q1 = stats[1] + statdict_med = stats[2] + statdict_mean = stats[3] + statdict_q3 = stats[4] + statdict_whishi = stats[5] + + assert statdict_whislo["stat_name"] == "whislo" + assert np.isclose(statdict_whislo["stat_value"], 0.0) + + assert statdict_q1["stat_name"] == "q1" + assert np.isclose(statdict_q1["stat_value"], 0.0, atol=1e-2) + + assert statdict_med["stat_name"] == "med" + assert np.isclose(statdict_med["stat_value"], 0.3, atol=1e-2) + + assert statdict_mean["stat_name"] == "mean" + assert np.isclose(statdict_mean["stat_value"], 0.3762, atol=1e-2) + + assert statdict_q3["stat_name"] == "q3" + assert np.isclose(statdict_q3["stat_value"], 0.7, atol=1e-2) + + assert statdict_whishi["stat_name"] == "whishi" + assert statdict_whishi["stat_value"] == 1.0 + + +def test_compare_models_pairwise_ttest(scores_per_model: dict, alternative: str, higher_is_better: bool) -> None: + """Test `compare_models_pairwise_ttest`.""" + models_ordered, confidences = compare_models_pairwise_ttest_rel( + scores_per_model, + alternative=alternative, + higher_is_better=higher_is_better, + ) + assert len(confidences) == (len(models_ordered) * (len(models_ordered) - 1)) + + diff = set(scores_per_model.keys()).symmetric_difference(set(models_ordered)) + assert len(diff) == 0 + + if isinstance(scores_per_model, OrderedDict): + assert models_ordered == tuple(scores_per_model.keys()) + + elif len(scores_per_model) == 2: + assert models_ordered == (("a", "b") if higher_is_better else ("b", "a")) + + elif len(scores_per_model) == 3: + assert models_ordered == (("a", "b", "c") if higher_is_better else ("c", "b", "a")) + + if isinstance(next(iter(scores_per_model.values())), AUPIMOResult): + return + + def copy_and_add_nan(scores: Tensor) -> Tensor: + scores = scores.clone() + scores[5:] = torch.nan + return scores + + # removing 
samples should reduce the confidences + scores_per_model["a"] = copy_and_add_nan(scores_per_model["a"]) + scores_per_model["b"] = copy_and_add_nan(scores_per_model["b"]) + if "c" in scores_per_model: + scores_per_model["c"] = copy_and_add_nan(scores_per_model["c"]) + + compare_models_pairwise_ttest_rel( + scores_per_model, + alternative=alternative, + higher_is_better=higher_is_better, + ) + + +def test_compare_models_pairwise_wilcoxon(scores_per_model: dict, alternative: str, higher_is_better: bool) -> None: + """Test `compare_models_pairwise_wilcoxon`.""" + models_ordered, confidences = compare_models_pairwise_wilcoxon( + scores_per_model, + alternative=alternative, + higher_is_better=higher_is_better, + ) + assert len(confidences) == (len(models_ordered) * (len(models_ordered) - 1)) + + diff = set(scores_per_model.keys()).symmetric_difference(set(models_ordered)) + assert len(diff) == 0 + + if isinstance(scores_per_model, OrderedDict): + assert models_ordered == tuple(scores_per_model.keys()) + + elif len(scores_per_model) == 2: + assert models_ordered == (("a", "b") if higher_is_better else ("b", "a")) + + elif len(scores_per_model) == 3: + # this one is not trivial without looking at the data, so no assertions + pass + + if isinstance(next(iter(scores_per_model.values())), AUPIMOResult): + return + + def copy_and_add_nan(scores: Tensor) -> Tensor: + scores = scores.clone() + scores[5:] = torch.nan + return scores + + # removing samples should reduce the confidences + scores_per_model["a"] = copy_and_add_nan(scores_per_model["a"]) + scores_per_model["b"] = copy_and_add_nan(scores_per_model["b"]) + if "c" in scores_per_model: + scores_per_model["c"] = copy_and_add_nan(scores_per_model["c"]) + + compare_models_pairwise_wilcoxon( + scores_per_model, + alternative=alternative, + higher_is_better=higher_is_better, + ) + + +def test_format_pairwise_tests_results(scores_per_model: dict) -> None: + """Test `format_pairwise_tests_results`.""" + models_ordered, confidences = compare_models_pairwise_wilcoxon( + scores_per_model, + alternative="greater", + higher_is_better=True, + ) + confidence_df = format_pairwise_tests_results( + models_ordered, + confidences, + model1_as_column=True, + left_to_right=True, + top_to_bottom=True, + ) + assert tuple(confidence_df.columns.tolist()) == models_ordered + assert tuple(confidence_df.index.tolist()) == models_ordered + + models_ordered, confidences = compare_models_pairwise_ttest_rel( + scores_per_model, + alternative="greater", + higher_is_better=True, + ) + confidence_df = format_pairwise_tests_results( + models_ordered, + confidences, + model1_as_column=True, + left_to_right=True, + top_to_bottom=True, + ) + assert tuple(confidence_df.columns.tolist()) == models_ordered + assert tuple(confidence_df.index.tolist()) == models_ordered diff --git a/third-party-programs.txt b/third-party-programs.txt index 3155b2a930..8aff59c810 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -42,3 +42,7 @@ terms are listed below. 7. CLIP neural network used for deep feature extraction in AI-VAD model Copyright (c) 2022 @openai, https://github.com/openai/CLIP. SPDX-License-Identifier: MIT + +8. AUPIMO metric implementation is based on the original code + Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo + SPDX-License-Identifier: MIT \ No newline at end of file