Commit 29cb912

PIMO (#1726)
* update
* test binclf curves numpy and numba and fixes
* correct some docstrings
* torch interface and tests
* torch interface and tests
* constants regrouped in dataclass as class vars
* result class was unnecessary for per_image_binclf_curve
* factorize function _get_threshs_minmax_linspace
* small docs fixes
* add pimo numpy version and test
* move validation
* add `shared_fpr_metric` option
* add pimo torch functional version and test
* add torchmetrics interface and test
* renames and put things in init
* validate inputs in result objects
* result objects to from dict and tests
* add save and load methods to result objects and test
* refactor validations and minor changes
* test result objects' properties
* minor refactors
* add missing docstrings
* minor vocabulary fix for consistency
* add per image scores statistics and test it
* refactor constants notation
* add stats tests and test it
* change the meaning of AUPIMO.num_thresh
* interface to format pairwise test results
* improve doc
* add optional `paths` to result objects and some minor fixes and refactors
* remove frozen from dataclasses and some done todos
* review headers
* doc modifs
* refactor `score_less_than_thresh` in `_binclf_one_curve_python`
* correct license comments
* fix doc
* numba as extra requirement
* refactor copyrights from jpcbertoldo
* remove from __future__ import annotations
* refactor validations names
* dedupe file path validation
* fix tests
* Add todo
* refactor enums
* only logger.warning
* refactor test imports
* refactor docs
* refactor some docs
* correct pre commit errors
* remove author tag
* add third party program
* Update src/anomalib/metrics/per_image/pimo.py
* move HAS_NUMBA
* remove PIMOSharedFPRMetric
* make torchmetrics compute avg by dft
* pre-commit hooks corrections
* correct numpy.trapezoid

---------

Signed-off-by: jpcbertoldo <[email protected]>
Co-authored-by: Samet Akcay <[email protected]>
1 parent cdd338c commit 29cb912

18 files changed: +4974 -2 lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -84,7 +84,8 @@ test = [
     "coverage[toml]",
     "tox",
 ]
-full = ["anomalib[core,openvino,loggers,notebooks]"]
+extra = ["numba>=0.58.1"]
+full = ["anomalib[core,openvino,loggers,notebooks,extra]"]
 dev = ["anomalib[full,docs,test]"]
 
 [project.scripts]
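
Because numba now lives in the optional `extra` group, code that uses it has to cope with the package being absent. The commit message mentions a `HAS_NUMBA` flag, but its definition is not shown in this part of the diff, so the following is only a minimal sketch of how such a guard could look. The helper `pick_binclf_algorithm` and its `"python"`/`"numba"` strings are hypothetical names chosen for illustration, not anomalib's API.

```python
import importlib.util

# Assumption: a module-level flag that records whether numba can be imported.
# find_spec returns None when the package is not installed.
HAS_NUMBA = importlib.util.find_spec("numba") is not None


def pick_binclf_algorithm(requested: str = "numba") -> str:
    """Return the algorithm to run, falling back to pure Python without numba.

    Hypothetical helper: the names "python" and "numba" are illustrative only.
    """
    if requested not in ("python", "numba"):
        msg = f"Unknown algorithm: {requested!r}"
        raise ValueError(msg)
    if requested == "numba" and not HAS_NUMBA:
        # The optional dependency is missing, so use the slower fallback.
        return "python"
    return requested
```

With this shape, installing the `extra` group only changes which branch runs; callers do not need their own import checks.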

src/anomalib/data/utils/path.py

Lines changed: 13 additions & 1 deletion

@@ -142,13 +142,20 @@ def contains_non_printable_characters(path: str | Path) -> bool:
     return not printable_pattern.match(str(path))
 
 
-def validate_path(path: str | Path, base_dir: str | Path | None = None, should_exist: bool = True) -> Path:
+def validate_path(
+    path: str | Path,
+    base_dir: str | Path | None = None,
+    should_exist: bool = True,
+    accepted_extensions: tuple[str, ...] | None = None,
+) -> Path:
     """Validate the path.
 
     Args:
         path (str | Path): Path to validate.
         base_dir (str | Path): Base directory to restrict file access.
         should_exist (bool): If True, do not raise an exception if the path does not exist.
+        accepted_extensions (tuple[str, ...] | None): Accepted extensions for the path. An exception is raised if the
+            path does not have one of the accepted extensions. If None, no check is performed. Defaults to None.
 
     Returns:
         Path: Validated path.
@@ -213,6 +220,11 @@ def validate_path(path: str | Path, base_dir: str | Path | None = None, should_e
         msg = f"Read or execute permissions denied for the path: {path}"
         raise PermissionError(msg)
 
+    # Check if the path has one of the accepted extensions
+    if accepted_extensions is not None and path.suffix not in accepted_extensions:
+        msg = f"Path extension is not accepted. Accepted extensions: {accepted_extensions}. Path: {path}"
+        raise ValueError(msg)
+
     return path
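
The new check compares `path.suffix` against the tuple, which has a few consequences worth seeing in isolation: the suffix includes the leading dot, only the last suffix is considered, and the comparison is case-sensitive. The sketch below reproduces just that condition in a standalone function; `check_extension` is a hypothetical name used here for illustration and is not part of anomalib.

```python
from __future__ import annotations

from pathlib import Path


def check_extension(path: str | Path, accepted_extensions: tuple[str, ...] | None = None) -> Path:
    """Standalone copy of the extension check added to `validate_path` (illustrative only)."""
    path = Path(path)
    # `Path.suffix` is the last dot-component, including the dot (e.g. ".json").
    if accepted_extensions is not None and path.suffix not in accepted_extensions:
        msg = f"Path extension is not accepted. Accepted extensions: {accepted_extensions}. Path: {path}"
        raise ValueError(msg)
    return path


# Accepted: the suffix ".json" is in the tuple.
check_extension("results.json", accepted_extensions=(".json",))

# Rejected: the comparison is case-sensitive and expects the leading dot.
for bad in ("results.JSON", "results.csv"):
    try:
        check_extension(bad, accepted_extensions=(".json",))
    except ValueError as err:
        print(err)
```

Callers therefore need to pass extensions with the dot (`".json"`, not `"json"`), and a multi-part name such as `archive.tar.gz` is matched on `".gz"` alone.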

src/anomalib/metrics/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -11,6 +11,7 @@
 import torchmetrics
 from omegaconf import DictConfig, ListConfig
 
+from . import per_image
 from .anomaly_score_distribution import AnomalyScoreDistribution
 from .aupr import AUPR
 from .aupro import AUPRO
@@ -19,6 +20,7 @@
 from .f1_max import F1Max
 from .f1_score import F1Score
 from .min_max import MinMax
+from .per_image import AUPIMO, PIMO, aupimo_scores, pimo_curves
 from .precision_recall_curve import BinaryPrecisionRecallCurve
 from .pro import PRO
 from .threshold import F1AdaptiveThreshold, ManualThreshold
@@ -35,6 +37,11 @@
     "ManualThreshold",
     "MinMax",
     "PRO",
+    "per_image",
+    "pimo_curves",
+    "aupimo_scores",
+    "PIMO",
+    "AUPIMO",
 ]
 
 logger = logging.getLogger(__name__)
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+"""Per-Image Metrics."""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .binclf_curve import per_image_binclf_curve, per_image_fpr, per_image_tpr
+from .binclf_curve_numpy import BinclfAlgorithm, BinclfThreshsChoice
+from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult, aupimo_scores, pimo_curves
+from .utils import (
+    compare_models_pairwise_ttest_rel,
+    compare_models_pairwise_wilcoxon,
+    format_pairwise_tests_results,
+    per_image_scores_stats,
+)
+from .utils_numpy import StatsOutliersPolicy, StatsRepeatedPolicy
+
+__all__ = [
+    # constants
+    "BinclfAlgorithm",
+    "BinclfThreshsChoice",
+    "StatsOutliersPolicy",
+    "StatsRepeatedPolicy",
+    # result classes
+    "PIMOResult",
+    "AUPIMOResult",
+    # functional interfaces
+    "per_image_binclf_curve",
+    "per_image_fpr",
+    "per_image_tpr",
+    "pimo_curves",
+    "aupimo_scores",
+    # torchmetrics interfaces
+    "PIMO",
+    "AUPIMO",
+    # utils
+    "compare_models_pairwise_ttest_rel",
+    "compare_models_pairwise_wilcoxon",
+    "format_pairwise_tests_results",
+    "per_image_scores_stats",
+]
Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+"""Binary classification matrix curve (NUMBA implementation of low level functions).
+
+Details: `.binclf_curve`.
+"""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numba
+import numpy as np
+from numpy import ndarray
+
+
+@numba.jit(nopython=True)
+def binclf_one_curve_numba(scores: ndarray, gts: ndarray, threshs: ndarray) -> ndarray:
+    """One binary classification matrix at each threshold (NUMBA implementation).
+
+    This does the same as `_binclf_one_curve_python` but with numba using just-in-time compilation.
+
+    Note: VALIDATION IS NOT DONE HERE! Make sure to validate the arguments before calling this function.
+
+    Args:
+        scores (ndarray): Anomaly scores (D,).
+        gts (ndarray): Binary (bool) ground truth of shape (D,).
+        threshs (ndarray): Sequence of thresholds in ascending order (K,).
+
+    Returns:
+        ndarray: Binary classification matrix curve (K, 2, 2)
+
+    Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`.
+    """
+    num_th = len(threshs)
+
+    # POSITIVES
+    scores_pos = scores[gts]
+    # the sorting is very important for the algorithm to work and the speedup
+    scores_pos = np.sort(scores_pos)
+    # start counting with lowest th, so everything is predicted as positive (this variable is updated in the loop)
+    num_pos = current_count_tp = len(scores_pos)
+
+    tps = np.empty((num_th,), dtype=np.int64)
+
+    # NEGATIVES
+    # same thing but for the negative samples
+    scores_neg = scores[~gts]
+    scores_neg = np.sort(scores_neg)
+    num_neg = current_count_fp = len(scores_neg)
+
+    fps = np.empty((num_th,), dtype=np.int64)
+
+    # it will progressively drop the scores that are below the current th
+    for thidx, th in enumerate(threshs):
+        num_drop = 0
+        num_scores = len(scores_pos)
+        while num_drop < num_scores and scores_pos[num_drop] < th:  # ! scores_pos !
+            num_drop += 1
+        # ---
+        scores_pos = scores_pos[num_drop:]
+        current_count_tp -= num_drop
+        tps[thidx] = current_count_tp
+
+        # same with the negatives
+        num_drop = 0
+        num_scores = len(scores_neg)
+        while num_drop < num_scores and scores_neg[num_drop] < th:  # ! scores_neg !
+            num_drop += 1
+        # ---
+        scores_neg = scores_neg[num_drop:]
+        current_count_fp -= num_drop
+        fps[thidx] = current_count_fp
+
+    fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps
+    tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps
+
+    # sequence of dimensions is (threshs, true class, predicted class) (see docstring)
+    return np.stack(
+        (
+            np.stack((tns, fps), axis=-1),
+            np.stack((fns, tps), axis=-1),
+        ),
+        axis=-1,
+    ).transpose(0, 2, 1)
+
+
+@numba.jit(nopython=True, parallel=True)
+def binclf_multiple_curves_numba(scores_batch: ndarray, gts_batch: ndarray, threshs: ndarray) -> ndarray:
+    """Multiple binary classification matrix at each threshold (NUMBA implementation).
+
+    This does the same as `_binclf_multiple_curves_python` but with numba,
+    using parallelization and just-in-time compilation.
+
+    Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function.
+
+    Args:
+        scores_batch (ndarray): Anomaly scores (N, D,).
+        gts_batch (ndarray): Binary (bool) ground truth of shape (N, D,).
+        threshs (ndarray): Sequence of thresholds in ascending order (K,).
+
+    Returns:
+        ndarray: Binary classification matrix curves (N, K, 2, 2)
+
+    Details: `anomalib.metrics.per_image.binclf_curve_numpy.binclf_multiple_curves`.
+    """
+    num_imgs = scores_batch.shape[0]
+    num_th = len(threshs)
+    ret = np.empty((num_imgs, num_th, 2, 2), dtype=np.int64)
+    for imgidx in numba.prange(num_imgs):
+        scoremap = scores_batch[imgidx]
+        mask = gts_batch[imgidx]
+        ret[imgidx] = binclf_one_curve_numba(scoremap, mask, threshs)
+    return ret
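
For reviewing the counting logic above without numba installed, the same confusion matrices can be computed with plain NumPy. The sketch below is not part of this commit and is not the `_binclf_one_curve_python` reference the docstrings mention; it is an independent cross-check that swaps the incremental "drop scores below the threshold" loop for `np.searchsorted`. The function name `binclf_one_curve_searchsorted` is hypothetical.

```python
import numpy as np


def binclf_one_curve_searchsorted(scores: np.ndarray, gts: np.ndarray, threshs: np.ndarray) -> np.ndarray:
    """Confusion matrix per threshold, shape (K, 2, 2), laid out as [[TN, FP], [FN, TP]].

    A score counts as predicted positive when `score >= th`, matching the
    strict `< th` drop condition in the numba loop.
    """
    scores_pos = np.sort(scores[gts])
    scores_neg = np.sort(scores[~gts])

    # With side="left", searchsorted returns how many sorted values are strictly below th.
    fns = np.searchsorted(scores_pos, threshs, side="left")
    tns = np.searchsorted(scores_neg, threshs, side="left")
    tps = scores_pos.size - fns
    fps = scores_neg.size - tns

    # Axis order is (threshold, true class, predicted class).
    return np.stack(
        (np.stack((tns, fps), axis=-1), np.stack((fns, tps), axis=-1)),
        axis=1,
    ).astype(np.int64)


scores = np.array([0.1, 0.4, 0.35, 0.8])
gts = np.array([False, False, True, True])
threshs = np.array([0.0, 0.3, 0.5, 0.9])
curve = binclf_one_curve_searchsorted(scores, gts, threshs)
print(curve[2])  # threshold 0.5 -> [[2 0], [1 1]]
```

At threshold 0.5 both negatives (0.1, 0.4) fall below it, giving TN=2 and FP=0, while one positive (0.35) is missed and one (0.8) is detected, giving FN=1 and TP=1. Unlike the numba loop, this variant does not rely on `threshs` being sorted in ascending order.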
