Skip to content

Commit 2f3d616

Browse files
authored
🔨 Replace imgaug with Native PyTorch Transforms (#2436)
* Add multi random choice transform * Add DRAEMAugmenter class and Perlin noise generation to new_perlin.py - Introduced DRAEMAugmenter for advanced image augmentations using torchvision v2. - Implemented various augmentation techniques including ColorJitter, RandomAdjustSharpness, and custom transformations. - Added functionality for comparing augmentation methods and visualizing results. - Included utility functions for metrics computation and image processing. - Established logging for better traceability of operations. This commit enhances the image processing capabilities within the Anomalib framework, facilitating more robust anomaly detection workflows. * Add the new perlin noise Signed-off-by: Samet Akcay <[email protected]> * Add the new perlin noise Signed-off-by: Samet Akcay <[email protected]> * add generate_perlin_noise relative import Signed-off-by: Samet Akcay <[email protected]> * add tiffile as a dependency Signed-off-by: Samet Akcay <[email protected]> * Remove upper bound from wandb Signed-off-by: Samet Akcay <[email protected]> * Added skimage Signed-off-by: Samet Akcay <[email protected]> * add scikit-learn as a dependency Signed-off-by: Samet Akcay <[email protected]> * limit ollama to < 0.4.0 as it has breaking changes Signed-off-by: Samet Akcay <[email protected]> * Fix data generators in test helpers Signed-off-by: Samet Akcay <[email protected]> * Update the perlin augmenters Signed-off-by: Samet Akcay <[email protected]> * Fix numpy validator tests caused by numpy upgrade Signed-off-by: Samet Akcay <[email protected]> * Fix CS-Flow tests Signed-off-by: Samet Akcay <[email protected]> * Fix the tests Signed-off-by: Samet Akcay <[email protected]> --------- Signed-off-by: Samet Akcay <[email protected]>
1 parent c16f51e commit 2f3d616

File tree

16 files changed

+453
-342
lines changed

16 files changed

+453
-342
lines changed

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ core = [
4242
"av>=10.0.0",
4343
"einops>=0.3.2",
4444
"freia>=0.2",
45-
"imgaug==0.4.0",
4645
"kornia>=0.6.6",
4746
"matplotlib>=3.4.3",
4847
"opencv-python>=4.5.3.56",
4948
"pandas>=1.1.0",
49+
"scikit-image", # NOTE: skimage should be removed as part of dependency cleanup
50+
"tifffile",
5051
"timm",
5152
"lightning>=2.2",
5253
"torch>=2",
@@ -57,12 +58,12 @@ core = [
5758
"open-clip-torch>=2.23.0,<2.26.1",
5859
]
5960
openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
60-
vlm = ["ollama", "openai", "python-dotenv","transformers"]
61+
vlm = ["ollama<0.4.0", "openai", "python-dotenv","transformers"]
6162
loggers = [
6263
"comet-ml>=3.31.7",
6364
"gradio>=4",
6465
"tensorboard",
65-
"wandb>=0.12.17,<=0.15.9",
66+
"wandb",
6667
"mlflow >=1.0.0",
6768
]
6869
notebooks = ["gitpython", "ipykernel", "ipywidgets", "notebook"]

src/anomalib/data/transforms/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
from .center_crop import ExportableCenterCrop
7+
from .multi_random_choice import MultiRandomChoice
78

8-
__all__ = ["ExportableCenterCrop"]
9+
__all__ = ["ExportableCenterCrop", "MultiRandomChoice"]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Multi random choice transform."""
2+
3+
# Copyright (C) 2024 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from collections.abc import Callable, Sequence
7+
8+
import torch
9+
from torchvision.transforms import v2
10+
11+
12+
class MultiRandomChoice(v2.Transform):
13+
"""Apply multiple transforms randomly picked from a list.
14+
15+
This transform does not support torchscript.
16+
17+
Args:
18+
transforms (sequence or torch.nn.Module): List of transformations to choose from.
19+
probabilities (list[float] | None, optional): Probability of each transform being picked.
20+
If None (default), all transforms have equal probability. If provided, probabilities
21+
will be normalized to sum to 1.
22+
num_transforms (int): Maximum number of transforms to apply at once.
23+
Defaults to ``1``.
24+
fixed_num_transforms (bool): If ``True``, always applies exactly ``num_transforms`` transforms.
25+
If ``False``, randomly picks between 1 and ``num_transforms``.
26+
Defaults to ``False``.
27+
28+
Examples:
29+
>>> import torchvision.transforms.v2 as v2
30+
>>> transforms = [
31+
... v2.RandomHorizontalFlip(p=1.0),
32+
... v2.ColorJitter(brightness=0.5),
33+
... v2.RandomRotation(10),
34+
... ]
35+
>>> # Apply 1-2 random transforms with equal probability
36+
>>> transform = MultiRandomChoice(transforms, num_transforms=2)
37+
38+
>>> # Always apply exactly 2 transforms with custom probabilities
39+
>>> transform = MultiRandomChoice(
40+
... transforms,
41+
... probabilities=[0.5, 0.3, 0.2],
42+
... num_transforms=2,
43+
... fixed_num_transforms=True
44+
... )
45+
"""
46+
47+
def __init__(
48+
self,
49+
transforms: Sequence[Callable],
50+
probabilities: list[float] | None = None,
51+
num_transforms: int = 1,
52+
fixed_num_transforms: bool = False,
53+
) -> None:
54+
if not isinstance(transforms, Sequence):
55+
msg = "Argument transforms should be a sequence of callables"
56+
raise TypeError(msg)
57+
58+
if probabilities is None:
59+
probabilities = [1.0] * len(transforms)
60+
elif len(probabilities) != len(transforms):
61+
msg = f"Length of p doesn't match the number of transforms: {len(probabilities)} != {len(transforms)}"
62+
raise ValueError(msg)
63+
64+
super().__init__()
65+
66+
self.transforms = transforms
67+
total = sum(probabilities)
68+
self.probabilities = [probability / total for probability in probabilities]
69+
70+
self.num_transforms = num_transforms
71+
self.fixed_num_transforms = fixed_num_transforms
72+
73+
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
74+
"""Apply randomly selected transforms to the input."""
75+
# First determine number of transforms to apply
76+
num_transforms = (
77+
self.num_transforms if self.fixed_num_transforms else int(torch.randint(self.num_transforms, (1,)) + 1)
78+
)
79+
# Get transforms
80+
idx = torch.multinomial(torch.tensor(self.probabilities), num_transforms).tolist()
81+
transform = v2.Compose([self.transforms[i] for i in idx])
82+
return transform(*inputs)

src/anomalib/data/utils/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
# Copyright (C) 2022-2024 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
from .augmenter import Augmenter
76
from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
87
from .download import DownloadInfo, download_and_extract
9-
from .generators import random_2d_perlin
8+
from .generators import generate_perlin_noise
109
from .image import (
1110
generate_output_image_filename,
1211
get_image_filenames,
@@ -30,7 +29,7 @@
3029
"generate_output_image_filename",
3130
"get_image_filenames",
3231
"get_image_height_and_width",
33-
"random_2d_perlin",
32+
"generate_perlin_noise",
3433
"read_image",
3534
"read_mask",
3635
"read_depth_image",
@@ -42,7 +41,6 @@
4241
"TestSplitMode",
4342
"LabelName",
4443
"DirType",
45-
"Augmenter",
4644
"masks_to_boxes",
4745
"boxes_to_masks",
4846
"boxes_to_anomaly_maps",

src/anomalib/data/utils/augmenter.py

Lines changed: 0 additions & 171 deletions
This file was deleted.

src/anomalib/data/utils/generators/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
# Copyright (C) 2022 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55

6-
from .perlin import random_2d_perlin
6+
from .perlin import PerlinAnomalyGenerator, generate_perlin_noise
77

8-
__all__ = ["random_2d_perlin"]
8+
__all__ = ["PerlinAnomalyGenerator", "generate_perlin_noise"]

0 commit comments

Comments
 (0)