2 changes: 0 additions & 2 deletions pyproject.toml
@@ -166,8 +166,6 @@ lint.ignore = [
     "A005", # Module is shadowing a Python built-in
     "B909", # Mutation to loop iterable during iteration
     "PLR6301", # could be a function, class method or static method
-    "PLW1514", # Add explicit encoding argument
-    "PLR6201", # Convert to set
     "PLC2701", # Private name import
     "PLC0415", # import should be at the top of the file
     "PLR0917", # Too many positional arguments
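Note: dropping PLW1514 and PLR6201 from lint.ignore turns those two ruff rules on, which is what drives most of the mechanical edits below. A minimal sketch of what each rule asks for (hypothetical snippet, not taken from the repository):

from pathlib import Path

# PLW1514: open() should receive an explicit encoding instead of relying
# on the platform default.
with Path("config.json").open(encoding="utf-8") as file:  # was: .open()
    contents = file.read()

# PLR6201: membership tests should use a set literal rather than a tuple
# or list; a set probes by hash and signals membership intent.
suffix = ".png"
is_image = suffix in {".bmp", ".png"}  # was: suffix in (".bmp", ".png")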
6 changes: 3 additions & 3 deletions src/anomalib/callbacks/checkpoint.py
@@ -35,10 +35,10 @@ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
         Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow
         saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models).
         """
-        is_zero_or_few_shot = trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]
+        is_zero_or_few_shot = trainer.lightning_module.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}
         return (
             bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
-            or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING]  # don't save anything during non-fit
+            or trainer.state.fn not in {TrainerFn.FITTING, TrainerFn.VALIDATING}  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or (self._last_global_step_saved == trainer.global_step and not is_zero_or_few_shot)
         )
@@ -52,7 +52,7 @@ def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end

-        if trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+        if trainer.lightning_module.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             return False

         return super()._should_save_on_train_epoch_end(trainer)
10 changes: 5 additions & 5 deletions src/anomalib/cli/cli.py
@@ -149,7 +149,7 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None:
         parser.add_argument("--metrics.image", type=list[str] | str | None, default=None)
         parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
         parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
-        if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"):
+        if hasattr(parser, "subcommand") and parser.subcommand not in {"export", "predict"}:
             parser.link_arguments("task", "data.init_args.task")
         parser.add_argument(
             "--default_root_dir",
@@ -273,7 +273,7 @@ def _set_install_subcommand(self, action_subcommand: _ActionSubCommands) -> None:
     def before_instantiate_classes(self) -> None:
         """Modify the configuration to properly instantiate classes and sets up tiler."""
         subcommand = self.config["subcommand"]
-        if subcommand in (*self.subcommands(), "train", "predict"):
+        if subcommand in {*self.subcommands(), "train", "predict"}:
             self.config[subcommand] = update_config(self.config[subcommand])

     def instantiate_classes(self) -> None:
@@ -283,7 +283,7 @@ def instantiate_classes(self) -> None:
         But for subcommands we do not want to instantiate any trainer specific classes such as datamodule, model, etc
         This is because the subcommand is responsible for instantiating and executing code based on the passed config
         """
-        if self.config["subcommand"] in (*self.subcommands(), "predict"):  # trainer commands
+        if self.config["subcommand"] in {*self.subcommands(), "predict"}:  # trainer commands
             # since all classes are instantiated, the LightningCLI also creates an unused ``Trainer`` object.
             # the minor change here is that engine is instantiated instead of trainer
             self.config_init = self.parser.instantiate_classes(self.config)
@@ -296,7 +296,7 @@ def instantiate_classes(self) -> None:
         else:
             self.config_init = self.parser.instantiate_classes(self.config)
             subcommand = self.config["subcommand"]
-            if subcommand in ("train", "export"):
+            if subcommand in {"train", "export"}:
                 self.instantiate_engine()
             if "model" in self.config_init[subcommand]:
                 self.model = self._get(self.config_init, "model")
@@ -352,7 +352,7 @@ def _run_subcommand(self) -> None:

             install_kwargs = self.config.get("install", {})
             anomalib_install(**install_kwargs)
-        elif self.config["subcommand"] in (*self.subcommands(), "train", "export", "predict"):
+        elif self.config["subcommand"] in {*self.subcommands(), "train", "export", "predict"}:
             fn = getattr(self.engine, self.subcommand)
             fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand)
             fn(**fn_kwargs)
4 changes: 2 additions & 2 deletions src/anomalib/cli/install.py
@@ -54,12 +54,12 @@ def anomalib_install(option: str = "full", verbose: bool = False) -> int:

     # Parse requirements into torch and other requirements.
     # This is done to parse the correct version of torch (cpu/cuda).
-    torch_requirement, other_requirements = parse_requirements(requirements, skip_torch=option not in ("full", "core"))
+    torch_requirement, other_requirements = parse_requirements(requirements, skip_torch=option not in {"full", "core"})

     # Get install args for torch to install it from a specific index-url
     install_args: list[str] = []
     torch_install_args = []
-    if option in ("full", "core") and torch_requirement is not None:
+    if option in {"full", "core"} and torch_requirement is not None:
         torch_install_args = get_torch_install_args(torch_requirement)

     # Combine torch and other requirements.
4 changes: 2 additions & 2 deletions src/anomalib/cli/utils/help_formatter.py
@@ -65,7 +65,7 @@ def get_verbosity_subcommand() -> dict:
         {'subcommand': 'train', 'help': True, 'verbosity': 1}
     """
     arguments: dict = {"subcommand": None, "help": False, "verbosity": 2}
-    if len(sys.argv) >= 2 and sys.argv[1] not in ("--help", "-h"):
+    if len(sys.argv) >= 2 and sys.argv[1] not in {"--help", "-h"}:
         arguments["subcommand"] = sys.argv[1]
     if "--help" in sys.argv or "-h" in sys.argv:
         arguments["help"] = True
@@ -252,7 +252,7 @@ def format_help(self) -> str:
         """
         with self.console.capture() as capture:
             section = self._root_section
-            if self.subcommand in REQUIRED_ARGUMENTS and self.verbosity_level in (0, 1) and len(section.rich_items) > 1:
+            if self.subcommand in REQUIRED_ARGUMENTS and self.verbosity_level in {0, 1} and len(section.rich_items) > 1:
                 contents = render_guide(self.subcommand)
                 for content in contents:
                     self.console.print(content)
6 changes: 3 additions & 3 deletions src/anomalib/cli/utils/installation.py
@@ -134,7 +134,7 @@ def get_cuda_version() -> str | None:
     # Check $CUDA_HOME/version.json file.
     version_file = Path(cuda_home) / "version.json"
     if version_file.is_file():
-        with Path(version_file).open() as file:
+        with Path(version_file).open(encoding="utf-8") as file:
             data = json.load(file)
             cuda_version = data.get("cuda", {}).get("version", None)
             if cuda_version is not None:
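The added encoding="utf-8" is the PLW1514 fix in action: without it, Path.open() falls back to the platform's preferred encoding (often cp1252 on Windows), which can mis-decode UTF-8 JSON. A minimal sketch, assuming a hypothetical version.json on disk:

import json
from pathlib import Path

version_file = Path("version.json")  # hypothetical path
# An explicit encoding makes the read deterministic across platforms.
with version_file.open(encoding="utf-8") as file:
    data = json.load(file)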
@@ -319,7 +319,7 @@ def get_torch_install_args(requirement: str | Requirement) -> list[str]:
     )
     install_args: list[str] = []

-    if platform.system() in ("Linux", "Windows"):
+    if platform.system() in {"Linux", "Windows"}:
         # Get the hardware suffix (eg., +cpu, +cu116 and +cu118 etc.)
         hardware_suffix = get_hardware_suffix(with_available_torch_build=True, torch_version=version)

@@ -339,7 +339,7 @@
             torch_version,
             torchvision_requirement,
         ]
-    elif platform.system() in ("macos", "Darwin"):
+    elif platform.system() in {"macos", "Darwin"}:
         torch_version = str(requirement)
         install_args += [torch_version]
     else:
2 changes: 1 addition & 1 deletion src/anomalib/cli/utils/openvino.py
@@ -25,7 +25,7 @@ def add_openvino_export_arguments(parser: ArgumentParser) -> None:
         ov_parser = get_common_cli_parser()
         # remove redundant keys from mo keys
         for arg in ov_parser._actions:  # noqa: SLF001
-            if arg.dest in ("help", "input_model", "output_dir"):
+            if arg.dest in {"help", "input_model", "output_dir"}:
                 continue
             group.add_argument(f"--ov_args.{arg.dest}", type=arg.type, default=arg.default, help=arg.help)
     else:
2 changes: 1 addition & 1 deletion src/anomalib/data/image/btech.py
@@ -81,7 +81,7 @@ def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFrame:
     path = validate_path(path)

     samples_list = [
-        (str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in (".bmp", ".png")
+        (str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in {".bmp", ".png"}
     ]
     if not samples_list:
         msg = f"Found 0 images in {path}"
2 changes: 1 addition & 1 deletion src/anomalib/data/utils/download.py
@@ -299,7 +299,7 @@ def extract(file_name: Path, root: Path) -> None:
                 zip_file.extract(file_info, root)

     # Safely extract tar files.
-    elif file_name.suffix in (".tar", ".gz", ".xz", ".tgz"):
+    elif file_name.suffix in {".tar", ".gz", ".xz", ".tgz"}:
         with tarfile.open(file_name) as tar_file:
             members = tar_file.getmembers()
             safe_members = [member for member in members if not is_file_potentially_dangerous(member.name)]
2 changes: 1 addition & 1 deletion src/anomalib/data/utils/tiler.py
@@ -181,7 +181,7 @@ def __init__(
                 msg,
             )

-        if self.mode not in (ImageUpscaleMode.PADDING, ImageUpscaleMode.INTERPOLATION):
+        if self.mode not in {ImageUpscaleMode.PADDING, ImageUpscaleMode.INTERPOLATION}:
            msg = f"Unknown tiling mode {self.mode}. Available modes are padding and interpolation"
            raise ValueError(msg)
7 changes: 2 additions & 5 deletions src/anomalib/deploy/inferencers/openvino_inferencer.py
@@ -118,7 +118,7 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any,
             model = core.read_model(model=path[0], weights=path[1])
         else:
             path = path if isinstance(path, Path) else Path(path)
-            if path.suffix in (".bin", ".xml"):
+            if path.suffix in {".bin", ".xml"}:
                 if path.suffix == ".bin":
                     bin_path, xml_path = path, path.with_suffix(".xml")
                 elif path.suffix == ".xml":
@@ -199,7 +199,4 @@ def predict(
         predictions = self.model(image)
         pred_dict = self.post_process(predictions)

-        return NumpyImageBatch(
-            image=image,
-            **pred_dict,
-        )
+        return NumpyImageBatch(image=image, **pred_dict)
4 changes: 2 additions & 2 deletions src/anomalib/deploy/inferencers/torch_inferencer.py
@@ -68,7 +68,7 @@ def _get_device(device: str) -> torch.device:
         Returns:
             torch.device: Device to use for inference.
         """
-        if device not in ("auto", "cpu", "cuda", "gpu"):
+        if device not in {"auto", "cpu", "cuda", "gpu"}:
             msg = f"Unknown device {device}"
             raise ValueError(msg)

@@ -90,7 +90,7 @@ def _load_checkpoint(self, path: str | Path) -> dict:
         if isinstance(path, str):
             path = Path(path)

-        if path.suffix not in (".pt", ".pth"):
+        if path.suffix not in {".pt", ".pth"}:
             msg = f"Unknown torch checkpoint file format {path.suffix}. Make sure you save the Torch model."
             raise ValueError(msg)
10 changes: 5 additions & 5 deletions src/anomalib/engine/engine.py
@@ -9,7 +9,7 @@
 from typing import Any

 import torch
-from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
+from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -358,7 +358,7 @@ def _setup_transform(

     def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None:
         """Set up callbacks for the trainer."""
-        _callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()]
+        _callbacks: list[Callback] = []

         # Add ModelCheckpoint if it is not in the callbacks list.
         has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"])
@@ -418,7 +418,7 @@ def _should_run_validation(
             bool: Whether it is needed to run a validation sequence.
         """
         # validation before predict is only necessary for zero-/few-shot models
-        if model.learning_type not in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+        if model.learning_type not in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             return False
         # check if a checkpoint path is provided
         return ckpt_path is None
@@ -472,7 +472,7 @@ def fit(
         self._setup_trainer(model)
         self._setup_dataset_task(train_dataloaders, val_dataloaders, datamodule)
         self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path)
-        if model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+        if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
             self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path)
         else:
@@ -795,7 +795,7 @@ def train(
             datamodule,
         )
         self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path)
-        if model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
+        if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
             # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
             self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule)
         else:
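Since the engine no longer injects RichProgressBar and RichModelSummary by default, users who want the rich output back can pass the callbacks in themselves. A minimal sketch, assuming Engine forwards its callbacks argument to the underlying Trainer as it does for the other callbacks in this diff:

from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar

from anomalib.engine import Engine

# Restore the rich progress bar and model summary that were previously added by default.
engine = Engine(callbacks=[RichProgressBar(), RichModelSummary()])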
4 changes: 2 additions & 2 deletions src/anomalib/metrics/threshold/__init__.py
@@ -3,8 +3,8 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from .base import BaseThreshold
+from .base import BaseThreshold, Threshold
 from .f1_adaptive_threshold import F1AdaptiveThreshold
 from .manual_threshold import ManualThreshold

-__all__ = ["BaseThreshold", "F1AdaptiveThreshold", "ManualThreshold"]
+__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"]
30 changes: 25 additions & 5 deletions src/anomalib/metrics/threshold/base.py
@@ -3,14 +3,22 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from abc import ABC
+import warnings

 import torch
 from torchmetrics import Metric


-class BaseThreshold(Metric, ABC):
-    """Base class for thresholding metrics."""
+class Threshold(Metric):
+    """Base class for thresholding metrics.
+
+    This class serves as the foundation for all threshold-based metrics in the system.
+    It inherits from torchmetrics.Metric and provides a common interface for
+    threshold computation and updates.
+
+    Subclasses should implement the `compute` and `update` methods to define
+    specific threshold calculation logic.
+    """

     def __init__(self, **kwargs) -> None:
         super().__init__(**kwargs)
@@ -21,7 +29,7 @@ def compute(self) -> torch.Tensor:
         Returns:
             Value of the optimal threshold.
         """
-        msg = "Subclass of BaseAnomalyScoreThreshold must implement the compute method"
+        msg = "Subclass of Threshold must implement the compute method"
         raise NotImplementedError(msg)

     def update(self, *args, **kwargs) -> None:  # noqa: ARG002
@@ -31,5 +39,17 @@ def update(self, *args, **kwargs) -> None:  # noqa: ARG002
             *args: Any positional arguments.
             **kwargs: Any keyword arguments.
         """
-        msg = "Subclass of BaseAnomalyScoreThreshold must implement the update method"
+        msg = "Subclass of Threshold must implement the update method"
         raise NotImplementedError(msg)
+
+
+class BaseThreshold(Threshold):
+    """Alias for Threshold class for backward compatibility."""
+
+    def __init__(self, **kwargs) -> None:
+        warnings.warn(
+            "BaseThreshold is deprecated and will be removed in a future version. Use Threshold instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(**kwargs)
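To illustrate the new API: subclasses override update and compute, while the old BaseThreshold name still works but warns. A minimal sketch using the public imports added in this PR:

import torch

from anomalib.metrics.threshold import BaseThreshold, Threshold

class FixedThreshold(Threshold):
    """Toy threshold that always returns a constant value."""

    def update(self, *args, **kwargs) -> None:
        pass  # nothing to accumulate

    def compute(self) -> torch.Tensor:
        return torch.tensor(0.5)

FixedThreshold().compute()  # tensor(0.5000)
BaseThreshold()  # emits DeprecationWarning pointing to Threshold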
4 changes: 2 additions & 2 deletions src/anomalib/metrics/threshold/f1_adaptive_threshold.py
@@ -9,12 +9,12 @@

 from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve

-from .base import BaseThreshold
+from .base import Threshold

 logger = logging.getLogger(__name__)


-class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, BaseThreshold):
+class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
     """Anomaly Score Threshold.

     This class computes/stores the threshold that determines the anomalous label
4 changes: 2 additions & 2 deletions src/anomalib/metrics/threshold/manual_threshold.py
@@ -5,10 +5,10 @@

 import torch

-from .base import BaseThreshold
+from .base import Threshold


-class ManualThreshold(BaseThreshold):
+class ManualThreshold(Threshold):
     """Initialize Manual Threshold.

     Args:
6 changes: 3 additions & 3 deletions src/anomalib/models/components/base/anomaly_module.py
@@ -20,7 +20,7 @@

 from anomalib import LearningType
 from anomalib.dataclasses import Batch, InferenceBatch
-from anomalib.metrics.threshold import BaseThreshold
+from anomalib.metrics.threshold import Threshold
 from anomalib.post_processing import OneClassPostProcessor, PostProcessor

 from .export_mixin import ExportMixin
@@ -157,7 +157,7 @@ def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars:

         return super()._save_to_state_dict(destination, prefix, keep_vars)

-    def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> BaseThreshold:
+    def _get_instance(self, state_dict: OrderedDict[str, Any], dict_key: str) -> Threshold:
         """Get the threshold class from the ``state_dict``."""
         class_path = state_dict.pop(dict_key)
         module = importlib.import_module(".".join(class_path.split(".")[:-1]))
@@ -292,7 +292,7 @@ def from_config(
         model_parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION)
         model_parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
         model_parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
-        model_parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
+        model_parser.add_argument("--metrics.threshold", type=Threshold | str, default="F1AdaptiveThreshold")
         model_parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True)
         args = ["--config", str(config_path)]
         for key, value in kwargs.items():
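_get_instance resolves the dotted class path stored in the checkpoint back to a class object, which is why only its return annotation needed to change. A standalone sketch of that resolution pattern (load_class is a hypothetical helper name):

import importlib

def load_class(class_path: str) -> type:
    """Resolve a dotted path like 'pkg.module.ClassName' to the class object."""
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

threshold_cls = load_class("anomalib.metrics.threshold.F1AdaptiveThreshold")
threshold = threshold_cls()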
4 changes: 2 additions & 2 deletions src/anomalib/models/components/sampling/k_center_greedy.py
@@ -9,9 +9,9 @@

 import torch
 from torch.nn import functional as F  # noqa: N812
+from tqdm import tqdm

 from anomalib.models.components.dimensionality_reduction import SparseRandomProjection
-from anomalib.utils.rich import safe_track


 class KCenterGreedy:
@@ -98,7 +98,7 @@ def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]:

         selected_coreset_idxs: list[int] = []
         idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
-        for _ in safe_track(sequence=range(self.coreset_size), description="Selecting Coreset Indices."):
+        for _ in tqdm(range(self.coreset_size), desc="Selecting Coreset Indices."):
             self.update_distances(cluster_centers=[idx])
             idx = self.get_new_idx()
             if idx in selected_idxs:
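The swap from rich's safe_track to tqdm keeps the loop shape identical; only the keyword changes (description becomes desc). For reference, a minimal equivalent loop:

from tqdm import tqdm

for _ in tqdm(range(10), desc="Selecting Coreset Indices."):
    pass  # tqdm renders a textual progress bar on stderr by default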