| 1 | +"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification.""" |
| 2 | + |
| 3 | +# Copyright (C) 2024 Intel Corporation |
| 4 | +# SPDX-License-Identifier: Apache-2.0 |
| 5 | + |
| 6 | +import logging |
| 7 | +from enum import Enum |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch.utils.data import DataLoader |
| 11 | + |
| 12 | +from anomalib import LearningType |
| 13 | +from anomalib.models import AnomalyModule |
| 14 | + |
| 15 | +from .backends import Backend, Ollama |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class VlmAdBackend(Enum): |
| 21 | + """Supported VLM backends.""" |
| 22 | + |
| 23 | + OLLAMA = "ollama" |
| 24 | + |
| 25 | + |
class VlmAd(AnomalyModule):
    """Visual anomaly model that uses a Vision Language Model (VLM) backend for zero/few-shot classification."""

    def __init__(
        self,
        backend: VlmAdBackend | str = VlmAdBackend.OLLAMA,
        api_key: str | None = None,
        k_shot: int = 3,
    ) -> None:
        super().__init__()
        self.k_shot = k_shot
        backend = VlmAdBackend(backend)  # Accept either the enum or its string value, e.g. "ollama".
        self.vlm_backend: Backend = self._setup_vlm(backend, api_key)

    @staticmethod
    def _setup_vlm(backend: VlmAdBackend, api_key: str | None) -> Backend:
        """Instantiate the requested VLM backend."""
        match backend:
            case VlmAdBackend.OLLAMA:
                return Ollama()  # api_key is not needed by the local Ollama backend.
            case _:
                msg = f"Unsupported VLM backend: {backend}"
                raise ValueError(msg)

    def _setup(self) -> None:
        if self.k_shot:
            # Few-shot mode: pass normal training images to the backend as references.
            logger.info("Collecting reference images from training dataset.")
            dataloader = self.trainer.datamodule.train_dataloader()
            self.collect_reference_images(dataloader)

    def collect_reference_images(self, dataloader: DataLoader) -> None:
        """Collect reference images for few-shot inference."""
        count = 0
        for batch in dataloader:
            for img_path in batch["image_path"]:
                self.vlm_backend.add_reference_images(img_path)
                count += 1
                if count == self.k_shot:
                    # Stop as soon as k_shot reference images have been collected.
                    return

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict:
        """Validation step."""
        del args, kwargs  # These variables are not used.
        responses = [self.vlm_backend.predict(img_path) for img_path in batch["image_path"]]

        batch["str_output"] = responses
        # Map the textual response to a binary anomaly score: responses beginning with "Y" score 1.0, else 0.0.
        batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
        return batch

    @property
    def learning_type(self) -> LearningType:
        """The learning type of the model."""
        return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT

    @property
    def trainer_arguments(self) -> dict[str, int | float]:
        """The model does not require training, so no trainer arguments are needed."""
        return {}

    @staticmethod
    def configure_transforms(image_size: tuple[int, int] | None = None) -> None:
        """This model does not require any transforms."""
        if image_size is not None:
            logger.warning("Ignoring image_size argument as each backend has its own transforms.")
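

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): one way this
    # model could be run with anomalib's ``Engine``. The dataset choice (MVTec,
    # category "bottle") and the validate call are assumptions for the example;
    # any anomalib datamodule should work, and Ollama must be running locally.
    from anomalib.data import MVTec
    from anomalib.engine import Engine

    datamodule = MVTec(category="bottle")
    model = VlmAd(backend="ollama", k_shot=3)  # few-shot mode with 3 reference images
    engine = Engine()
    engine.validate(model=model, datamodule=datamodule)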