Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a298a3a
Dataclasses and post-processing refactor (#2098)
djdameln Sep 2, 2024
2efb4af
Merge main and resolve conflicts (#2287)
samet-akcay Sep 2, 2024
333baec
Rename Item to DatasetItem (#2289)
samet-akcay Sep 2, 2024
1f9a7e0
πŸ“š Add docstrings to dataclasses (#2292)
samet-akcay Sep 6, 2024
3f76f11
Move datamodules to datamodule sub-package
samet-akcay Sep 10, 2024
1b4ff81
Move datamodules to datamodule sub-package
samet-akcay Sep 10, 2024
c26c8c5
Split datamodules and datasets
samet-akcay Sep 10, 2024
d4623eb
Restructure dataclasses to data
samet-akcay Sep 10, 2024
795d0e7
Fix relative imports
samet-akcay Sep 10, 2024
a3727ae
Use absolute imports
samet-akcay Sep 10, 2024
19439f4
Add datasets dir
samet-akcay Sep 10, 2024
be5811f
Add relative imports for torch datasets
samet-akcay Sep 10, 2024
77cbf64
restructure datamodule tests
samet-akcay Sep 11, 2024
897b23b
Refactor and restructure anomalib.data (#2302)
samet-akcay Sep 11, 2024
127d101
Add a new logic to ImageValidator
samet-akcay Sep 11, 2024
2f7141a
Merge upstream and resolve conflicts
samet-akcay Sep 12, 2024
bc8d1d5
Add ImageBatchValidator and update ImageBatch
samet-akcay Sep 12, 2024
b6df346
Convert asserts to errors and add docstrings in ImageBatchValidator
samet-akcay Sep 12, 2024
6038460
Add VideoItem
samet-akcay Sep 12, 2024
7ebc2de
Add VideoBatch and VideoBatchValidator
samet-akcay Sep 12, 2024
21f428c
Move depth datamodules tests to the depth folder
samet-akcay Sep 12, 2024
882312f
Add DepthItem, DepthBatch, DepthValidator and DepthBatchValidator
samet-akcay Sep 12, 2024
efefe97
Update video frames validation
samet-akcay Sep 12, 2024
7d38244
Add numpy image and validators
samet-akcay Sep 12, 2024
d24b56b
Add numpy depth and validators
samet-akcay Sep 12, 2024
798c37b
Add numpy videos
samet-akcay Sep 12, 2024
92c912d
Convert private _validate to public validate
samet-akcay Sep 12, 2024
fa4cbf2
Add missing copyright, and add relative imports
samet-akcay Sep 12, 2024
45c513d
rebase and resolve conflicts
samet-akcay Sep 13, 2024
d7da3e5
fix cfa test
samet-akcay Sep 13, 2024
8952fc5
fix pred label tests
samet-akcay Sep 13, 2024
9c8c6bb
Run all the tests even if they fail
samet-akcay Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions src/anomalib/data/dataclasses/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,28 +128,28 @@ class _InputFields(Generic[T, ImageT, MaskT, PathT], ABC):
methods.
"""

image: FieldDescriptor[ImageT] = FieldDescriptor(validator_name="_validate_image")
gt_label: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_gt_label")
gt_mask: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="_validate_gt_mask")
mask_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_mask_path")
image: FieldDescriptor[ImageT] = FieldDescriptor(validator_name="validate_image")
gt_label: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_gt_label")
gt_mask: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="validate_gt_mask")
mask_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="validate_mask_path")

@abstractmethod
def _validate_image(self, image: ImageT) -> ImageT:
def validate_image(self, image: ImageT) -> ImageT:
"""Validate the image."""
raise NotImplementedError

@abstractmethod
def _validate_gt_mask(self, gt_mask: MaskT) -> MaskT | None:
def validate_gt_mask(self, gt_mask: MaskT) -> MaskT | None:
"""Validate the ground truth mask."""
raise NotImplementedError

@abstractmethod
def _validate_mask_path(self, mask_path: PathT) -> PathT | None:
def validate_mask_path(self, mask_path: PathT) -> PathT | None:
"""Validate the mask path."""
raise NotImplementedError

@abstractmethod
def _validate_gt_label(self, gt_label: T) -> T | None:
def validate_gt_label(self, gt_label: T) -> T | None:
"""Validate the ground truth label."""
raise NotImplementedError

Expand All @@ -163,7 +163,7 @@ class _ImageInputFields(Generic[PathT], ABC):
with disk-stored image datasets, facilitating custom data loading strategies.

The ``image_path`` field uses a ``FieldDescriptor`` with a validation method.
Subclasses must implement ``_validate_image_path`` to ensure path validity
Subclasses must implement ``validate_image_path`` to ensure path validity
according to specific Anomalib model or dataset requirements.

This class is designed to complement ``_InputFields`` for comprehensive
Expand All @@ -172,7 +172,7 @@ class _ImageInputFields(Generic[PathT], ABC):
Examples:
Assuming a concrete implementation ``DummyImageInput``:
>>> class DummyImageInput(_ImageInputFields):
... def _validate_image_path(self, image_path):
... def validate_image_path(self, image_path):
... return image_path # Implement actual validation
... # Implement other required methods

Expand All @@ -190,10 +190,10 @@ class _ImageInputFields(Generic[PathT], ABC):
methods.
"""

image_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_image_path")
image_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="validate_image_path")

@abstractmethod
def _validate_image_path(self, image_path: PathT) -> PathT | None:
def validate_image_path(self, image_path: PathT) -> PathT | None:
"""Validate the image path."""
raise NotImplementedError

Expand All @@ -217,7 +217,7 @@ class _VideoInputFields(Generic[T, ImageT, MaskT, PathT], ABC):
Assuming a concrete implementation ``DummyVideoInput``:

>>> class DummyVideoInput(_VideoInputFields):
... def _validate_original_image(self, original_image):
... def validate_original_image(self, original_image):
... return original_image # Implement actual validation
... # Implement other required methods

Expand All @@ -243,34 +243,34 @@ class _VideoInputFields(Generic[T, ImageT, MaskT, PathT], ABC):
methods.
"""

original_image: FieldDescriptor[ImageT | None] = FieldDescriptor(validator_name="_validate_original_image")
video_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_video_path")
target_frame: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_target_frame")
frames: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_frames")
last_frame: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_last_frame")
original_image: FieldDescriptor[ImageT | None] = FieldDescriptor(validator_name="validate_original_image")
video_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="validate_video_path")
target_frame: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_target_frame")
frames: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_frames")
last_frame: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_last_frame")

@abstractmethod
def _validate_original_image(self, original_image: ImageT) -> ImageT | None:
def validate_original_image(self, original_image: ImageT) -> ImageT | None:
"""Validate the original image."""
raise NotImplementedError

@abstractmethod
def _validate_video_path(self, video_path: PathT) -> PathT | None:
def validate_video_path(self, video_path: PathT) -> PathT | None:
"""Validate the video path."""
raise NotImplementedError

@abstractmethod
def _validate_target_frame(self, target_frame: T) -> T | None:
def validate_target_frame(self, target_frame: T) -> T | None:
"""Validate the target frame."""
raise NotImplementedError

@abstractmethod
def _validate_frames(self, frames: T) -> T | None:
def validate_frames(self, frames: T) -> T | None:
"""Validate the frames."""
raise NotImplementedError

@abstractmethod
def _validate_last_frame(self, last_frame: T) -> T | None:
def validate_last_frame(self, last_frame: T) -> T | None:
"""Validate the last frame."""
raise NotImplementedError

Expand All @@ -293,9 +293,9 @@ class _DepthInputFields(Generic[T, PathT], _ImageInputFields[PathT], ABC):
Assuming a concrete implementation ``DummyDepthInput``:

>>> class DummyDepthInput(_DepthInputFields):
... def _validate_depth_map(self, depth_map):
... def validate_depth_map(self, depth_map):
... return depth_map # Implement actual validation
... def _validate_depth_path(self, depth_path):
... def validate_depth_path(self, depth_path):
... return depth_path # Implement actual validation
... # Implement other required methods

Expand All @@ -316,16 +316,16 @@ class _DepthInputFields(Generic[T, PathT], _ImageInputFields[PathT], ABC):
methods.
"""

depth_map: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_depth_map")
depth_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_depth_path")
depth_map: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_depth_map")
depth_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="validate_depth_path")

@abstractmethod
def _validate_depth_map(self, depth_map: ImageT) -> ImageT | None:
def validate_depth_map(self, depth_map: ImageT) -> ImageT | None:
"""Validate the depth map."""
raise NotImplementedError

@abstractmethod
def _validate_depth_path(self, depth_path: PathT) -> PathT | None:
def validate_depth_path(self, depth_path: PathT) -> PathT | None:
"""Validate the depth path."""
raise NotImplementedError

Expand All @@ -345,13 +345,13 @@ class _OutputFields(Generic[T, MaskT], ABC):
Assuming a concrete implementation ``DummyOutput``:

>>> class DummyOutput(_OutputFields):
... def _validate_anomaly_map(self, anomaly_map):
... def validate_anomaly_map(self, anomaly_map):
... return anomaly_map # Implement actual validation
... def _validate_pred_score(self, pred_score):
... def validate_pred_score(self, pred_score):
... return pred_score # Implement actual validation
... def _validate_pred_mask(self, pred_mask):
... def validate_pred_mask(self, pred_mask):
... return pred_mask # Implement actual validation
... def _validate_pred_label(self, pred_label):
... def validate_pred_label(self, pred_label):
... return pred_label # Implement actual validation

>>> # Create an output instance with predictions
Expand All @@ -374,28 +374,28 @@ class _OutputFields(Generic[T, MaskT], ABC):
methods.
"""

anomaly_map: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="_validate_anomaly_map")
pred_score: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_pred_score")
pred_mask: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="_validate_pred_mask")
pred_label: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_pred_label")
anomaly_map: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="validate_anomaly_map")
pred_score: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_pred_score")
pred_mask: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="validate_pred_mask")
pred_label: FieldDescriptor[T | None] = FieldDescriptor(validator_name="validate_pred_label")

@abstractmethod
def _validate_anomaly_map(self, anomaly_map: MaskT) -> MaskT | None:
def validate_anomaly_map(self, anomaly_map: MaskT) -> MaskT | None:
"""Validate the anomaly map."""
raise NotImplementedError

@abstractmethod
def _validate_pred_score(self, pred_score: T) -> T | None:
def validate_pred_score(self, pred_score: T) -> T | None:
"""Validate the predicted score."""
raise NotImplementedError

@abstractmethod
def _validate_pred_mask(self, pred_mask: MaskT) -> MaskT | None:
def validate_pred_mask(self, pred_mask: MaskT) -> MaskT | None:
"""Validate the predicted mask."""
raise NotImplementedError

@abstractmethod
def _validate_pred_label(self, pred_label: T) -> T | None:
def validate_pred_label(self, pred_label: T) -> T | None:
"""Validate the predicted label."""
raise NotImplementedError

Expand Down Expand Up @@ -477,7 +477,7 @@ class _GenericItem(
Assuming a concrete implementation ``DummyItem``:

>>> class DummyItem(_GenericItem):
... def _validate_image(self, image):
... def validate_image(self, image):
... return image # Implement actual validation
... # Implement other required methods

Expand Down Expand Up @@ -522,7 +522,7 @@ class _GenericBatch(
Assuming a concrete implementation ``DummyBatch``:

>>> class DummyBatch(_GenericBatch):
... def _validate_image(self, image):
... def validate_image(self, image):
... return image # Implement actual validation
... # Implement other required methods

Expand Down
132 changes: 132 additions & 0 deletions src/anomalib/data/dataclasses/numpy/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,135 @@

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import numpy as np

from anomalib.data.dataclasses.generic import BatchIterateMixin, _DepthInputFields
from anomalib.data.dataclasses.numpy.base import NumpyBatch, NumpyItem
from anomalib.data.validators.numpy.depth import NumpyDepthBatchValidator, NumpyDepthValidator


@dataclass
class NumpyDepthItem(_DepthInputFields[np.ndarray, str], NumpyItem):
"""Dataclass for a single depth item in Anomalib datasets using numpy arrays.

This class combines _DepthInputFields and NumpyItem for depth-based anomaly detection.
It includes depth-specific fields and validation methods to ensure proper formatting
for Anomalib's depth-based models.
"""

@staticmethod
def validate_image(image: np.ndarray) -> np.ndarray:
"""Validate the image."""
return NumpyDepthValidator.validate_image(image)

@staticmethod
def validate_gt_label(gt_label: np.ndarray | None) -> np.ndarray | None:
"""Validate the ground truth label."""
return NumpyDepthValidator.validate_gt_label(gt_label)

@staticmethod
def validate_gt_mask(gt_mask: np.ndarray | None) -> np.ndarray | None:
"""Validate the ground truth mask."""
return NumpyDepthValidator.validate_gt_mask(gt_mask)

@staticmethod
def validate_mask_path(mask_path: str | None) -> str | None:
"""Validate the mask path."""
return NumpyDepthValidator.validate_mask_path(mask_path)

@staticmethod
def validate_anomaly_map(anomaly_map: np.ndarray | None) -> np.ndarray | None:
"""Validate the anomaly map."""
return NumpyDepthValidator.validate_anomaly_map(anomaly_map)

@staticmethod
def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction score."""
return NumpyDepthValidator.validate_pred_score(pred_score)

@staticmethod
def validate_pred_mask(pred_mask: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction mask."""
return NumpyDepthValidator.validate_pred_mask(pred_mask)

@staticmethod
def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction label."""
return NumpyDepthValidator.validate_pred_label(pred_label)

@staticmethod
def validate_image_path(image_path: str | None) -> str | None:
"""Validate the image path."""
return NumpyDepthValidator.validate_image_path(image_path)

@staticmethod
def validate_depth_map(depth_map: np.ndarray | None) -> np.ndarray | None:
"""Validate the depth map."""
return NumpyDepthValidator.validate_depth_map(depth_map)

@staticmethod
def validate_depth_path(depth_path: str | None) -> str | None:
"""Validate the depth path."""
return NumpyDepthValidator.validate_depth_path(depth_path)


class NumpyDepthBatch(
BatchIterateMixin[NumpyDepthItem],
_DepthInputFields[np.ndarray, list[str]],
NumpyBatch,
):
"""Dataclass for a batch of depth items in Anomalib datasets using numpy arrays."""

item_class = NumpyDepthItem

@staticmethod
def validate_image(image: np.ndarray) -> np.ndarray:
"""Validate the image."""
return NumpyDepthBatchValidator.validate_image(image)

def validate_gt_label(self, gt_label: np.ndarray | None) -> np.ndarray | None:
"""Validate the ground truth label."""
return NumpyDepthBatchValidator.validate_gt_label(gt_label, self.batch_size)

def validate_gt_mask(self, gt_mask: np.ndarray | None) -> np.ndarray | None:
"""Validate the ground truth mask."""
return NumpyDepthBatchValidator.validate_gt_mask(gt_mask, self.batch_size)

def validate_mask_path(self, mask_path: list[str] | None) -> list[str] | None:
"""Validate the mask path."""
return NumpyDepthBatchValidator.validate_mask_path(mask_path, self.batch_size)

def validate_anomaly_map(self, anomaly_map: np.ndarray | None) -> np.ndarray | None:
"""Validate the anomaly map."""
return NumpyDepthBatchValidator.validate_anomaly_map(anomaly_map, self.batch_size)

@staticmethod
def validate_pred_score(pred_score: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction score."""
return NumpyDepthBatchValidator.validate_pred_score(pred_score)

def validate_pred_mask(self, pred_mask: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction mask."""
return NumpyDepthBatchValidator.validate_pred_mask(pred_mask, self.batch_size)

@staticmethod
def validate_pred_label(pred_label: np.ndarray | None) -> np.ndarray | None:
"""Validate the prediction label."""
return NumpyDepthBatchValidator.validate_pred_label(pred_label)

@staticmethod
def validate_image_path(image_path: list[str] | None) -> list[str] | None:
"""Validate the image path."""
return NumpyDepthBatchValidator.validate_image_path(image_path)

def validate_depth_map(self, depth_map: np.ndarray | None) -> np.ndarray | None:
"""Validate the depth map."""
return NumpyDepthBatchValidator.validate_depth_map(depth_map, self.batch_size)

@staticmethod
def validate_depth_path(depth_path: list[str] | None) -> list[str] | None:
"""Validate the depth path."""
return NumpyDepthBatchValidator.validate_depth_path(depth_path)
Loading