Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions src/anomalib/models/components/base/anomalib_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import lightning.pytorch as pl
import torch
from lightning.pytorch import Callback
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torchvision.transforms.v2 import Compose, Normalize, Resize
Expand Down Expand Up @@ -143,7 +142,6 @@ def __init__(
self.visualizer = self._resolve_visualizer(visualizer)

self._input_size: tuple[int, int] | None = None
self._is_setup = False

@property
def name(self) -> str:
Expand All @@ -154,33 +152,6 @@ def name(self) -> str:
"""
return self.__class__.__name__

def setup(self, stage: str | None = None) -> None:
"""Set up the model if not already done.

This method ensures the model is built by calling ``_setup()`` if needed.

Args:
stage (str | None, optional): Current stage of training.
Defaults to ``None``.
"""
if getattr(self, "model", None) is None or not self._is_setup:
self._setup()
if isinstance(stage, TrainerFn):
# only set the flag if the stage is a TrainerFn, which means the
# setup has been called from a trainer
self._is_setup = True

def _setup(self) -> None:
"""Set up the model architecture.

This method should be overridden by subclasses to build their model
architecture. It is called by ``setup()`` when the model needs to be
initialized.

This is useful when the model cannot be fully initialized in ``__init__``
because it requires data-dependent parameters.
"""

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure callbacks for the model.

Expand Down
8 changes: 7 additions & 1 deletion src/anomalib/models/image/winclip/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def __init__(
self.class_name = class_name
self.k_shot = k_shot
self.few_shot_source = Path(few_shot_source) if few_shot_source else None
self.is_setup = False

def _setup(self) -> None:
def setup(self, stage: str) -> None:
"""Setup WinCLIP model.

This method:
Expand All @@ -137,6 +138,10 @@ def _setup(self) -> None:
Note:
This hook is called before the model is moved to the target device.
"""
del stage
if self.is_setup:
return

# get class name
self.class_name = self._get_class_name()
ref_images = None
Expand All @@ -158,6 +163,7 @@ def _setup(self) -> None:

# call setup to initialize the model
self.model.setup(self.class_name, ref_images)
self.is_setup = True

def _get_class_name(self) -> str:
"""Get the class name used in the prompt ensemble.
Expand Down