diff --git a/src/anomalib/models/components/base/anomalib_module.py b/src/anomalib/models/components/base/anomalib_module.py index b5fc6a57cf..3cd20c356b 100644 --- a/src/anomalib/models/components/base/anomalib_module.py +++ b/src/anomalib/models/components/base/anomalib_module.py @@ -50,7 +50,6 @@ import lightning.pytorch as pl import torch from lightning.pytorch import Callback -from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn from torchvision.transforms.v2 import Compose, Normalize, Resize @@ -143,7 +142,6 @@ def __init__( self.visualizer = self._resolve_visualizer(visualizer) self._input_size: tuple[int, int] | None = None - self._is_setup = False @property def name(self) -> str: @@ -154,33 +152,6 @@ def name(self) -> str: """ return self.__class__.__name__ - def setup(self, stage: str | None = None) -> None: - """Set up the model if not already done. - - This method ensures the model is built by calling ``_setup()`` if needed. - - Args: - stage (str | None, optional): Current stage of training. - Defaults to ``None``. - """ - if getattr(self, "model", None) is None or not self._is_setup: - self._setup() - if isinstance(stage, TrainerFn): - # only set the flag if the stage is a TrainerFn, which means the - # setup has been called from a trainer - self._is_setup = True - - def _setup(self) -> None: - """Set up the model architecture. - - This method should be overridden by subclasses to build their model - architecture. It is called by ``setup()`` when the model needs to be - initialized. - - This is useful when the model cannot be fully initialized in ``__init__`` - because it requires data-dependent parameters. - """ - def configure_callbacks(self) -> Sequence[Callback] | Callback: """Configure callbacks for the model. diff --git a/src/anomalib/models/image/winclip/lightning_model.py b/src/anomalib/models/image/winclip/lightning_model.py index e078f60e50..1bdf7686db 100644 --- a/src/anomalib/models/image/winclip/lightning_model.py +++ b/src/anomalib/models/image/winclip/lightning_model.py @@ -125,8 +125,9 @@ def __init__( self.class_name = class_name self.k_shot = k_shot self.few_shot_source = Path(few_shot_source) if few_shot_source else None + self.is_setup = False - def _setup(self) -> None: + def setup(self, stage: str) -> None: """Setup WinCLIP model. This method: @@ -137,6 +138,10 @@ def _setup(self) -> None: Note: This hook is called before the model is moved to the target device. """ + del stage + if self.is_setup: + return + # get class name self.class_name = self._get_class_name() ref_images = None @@ -158,6 +163,7 @@ def _setup(self) -> None: # call setup to initialize the model self.model.setup(self.class_name, ref_images) + self.is_setup = True def _get_class_name(self) -> str: """Get the class name used in the prompt ensemble.