diff --git a/pyproject.toml b/pyproject.toml
index 2893ad20c4..6a1af889b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ core = [
     "open-clip-torch>=2.23.0,<2.26.1",
 ]
 openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
+vlm = ["ollama", "transformers"]
 loggers = [
     "comet-ml>=3.31.7",
     "gradio>=4",
@@ -84,7 +85,7 @@ test = [
     "coverage[toml]",
     "tox",
 ]
-full = ["anomalib[core,openvino,loggers,notebooks]"]
+full = ["anomalib[core,openvino,loggers,notebooks,vlm]"]
 dev = ["anomalib[full,docs,test]"]
 
 [project.scripts]
diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py
index 69243d07c3..f10954205b 100644
--- a/src/anomalib/models/__init__.py
+++ b/src/anomalib/models/__init__.py
@@ -34,6 +34,7 @@
     Rkde,
     Stfpm,
     Uflow,
+    VlmAd,
     WinClip,
 )
 from .video import AiVad
@@ -62,6 +63,7 @@ class UnknownModelError(ModuleNotFoundError):
     "Stfpm",
     "Uflow",
     "AiVad",
+    "VlmAd",
     "WinClip",
     "Llm",
     "Llmollama",
diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py
index 9b18c9b3b9..fabfbded9a 100644
--- a/src/anomalib/models/image/__init__.py
+++ b/src/anomalib/models/image/__init__.py
@@ -24,6 +24,7 @@
 from .rkde import Rkde
 from .stfpm import Stfpm
 from .uflow import Uflow
+from .vlm_ad import VlmAd
 from .winclip import WinClip
 
 __all__ = [
@@ -44,6 +45,7 @@
     "Rkde",
     "Stfpm",
     "Uflow",
+    "VlmAd",
     "WinClip",
     "Llm",
     "Llmollama",
diff --git a/src/anomalib/models/image/vlm_ad/__init__.py b/src/anomalib/models/image/vlm_ad/__init__.py
new file mode 100644
index 0000000000..46ab8e0fee
--- /dev/null
+++ b/src/anomalib/models/image/vlm_ad/__init__.py
@@ -0,0 +1,8 @@
+"""Visual Anomaly Model."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .lightning_model import VlmAd
+
+__all__ = ["VlmAd"]
diff --git a/src/anomalib/models/image/vlm_ad/backends/__init__.py b/src/anomalib/models/image/vlm_ad/backends/__init__.py
new file mode 100644
index 0000000000..c1653ece3b
--- /dev/null
+++ b/src/anomalib/models/image/vlm_ad/backends/__init__.py
@@ -0,0 +1,9 @@
+"""VLM backends."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .base import Backend
+from .ollama import Ollama
+
+__all__ = ["Backend", "Ollama"]
diff --git a/src/anomalib/models/image/vlm_ad/backends/base.py b/src/anomalib/models/image/vlm_ad/backends/base.py
new file mode 100644
index 0000000000..7e27c2c74a
--- /dev/null
+++ b/src/anomalib/models/image/vlm_ad/backends/base.py
@@ -0,0 +1,23 @@
+"""Base backend."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class Backend(ABC):
+    """Base backend."""
+
+    @abstractmethod
+    def __init__(self, api_key: str | None = None) -> None:
+        """Initialize the backend."""
+
+    @abstractmethod
+    def add_reference_images(self, image: str | Path) -> None:
+        """Add reference images for k-shot."""
+
+    @abstractmethod
+    def predict(self, image: str | Path) -> str:
+        """Predict the anomaly label."""
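Review note: the `Backend` ABC above is the extension seam for further VLM providers. A minimal sketch of what a third-party implementation would look like against this interface; `EchoBackend` and its trivial logic are hypothetical and not part of this PR:

```python
from pathlib import Path

from anomalib.models.image.vlm_ad.backends import Backend


class EchoBackend(Backend):
    """Hypothetical no-op backend, shown only to illustrate the interface."""

    def __init__(self, api_key: str | None = None) -> None:
        """Store the (unused) API key and an empty reference list."""
        self.api_key = api_key
        self._ref_images: list[str] = []

    def add_reference_images(self, image: str | Path) -> None:
        """Remember reference image paths for few-shot prompting."""
        self._ref_images.append(str(image))

    def predict(self, image: str | Path) -> str:
        """A real backend would query a VLM here and return its raw reply."""
        return "NO: placeholder backend, no model was queried."
```

Any concrete backend only has to honour the string contract that `VlmAd.validation_step` parses, i.e. a reply beginning with 'YES' or 'NO'.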
diff --git a/src/anomalib/models/image/vlm_ad/backends/ollama.py b/src/anomalib/models/image/vlm_ad/backends/ollama.py
new file mode 100644
index 0000000000..41df554fda
--- /dev/null
+++ b/src/anomalib/models/image/vlm_ad/backends/ollama.py
@@ -0,0 +1,89 @@
+"""Ollama backend.
+
+Assumes that the Ollama service is running in the background.
+See: https://github.com/ollama/ollama
+Ensure that Ollama is running; on Linux: `ollama serve`.
+"""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+
+from anomalib.utils.exceptions import try_import
+
+from .base import Backend
+
+if try_import("ollama"):
+    from ollama import chat
+    from ollama._client import _encode_image
+else:
+    chat = None
+    _encode_image = None
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Prompt:
+    """Ollama prompt."""
+
+    few_shot: str
+    predict: str
+
+
+class Ollama(Backend):
+    """Ollama backend."""
+
+    def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None:
+        """Initialize the Ollama backend."""
+        if api_key:
+            logger.warning("API key is not required for Ollama backend.")
+        self.model_name: str = model_name
+        self._ref_images_encoded: list[str] = []
+
+    def add_reference_images(self, image: str | Path) -> None:
+        """Encode the image to base64."""
+        self._ref_images_encoded.append(_encode_image(image))
+
+    @property
+    def prompt(self) -> Prompt:
+        """Get the Ollama prompt."""
+        return Prompt(
+            predict=(
+                "You are given an image. It is either normal or anomalous.\n"
+                "First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
+                "Then give the reason for your decision.\n"
+                "For example, 'YES: The image has a crack on the wall.'"
+            ),
+            few_shot=(
+                "These are a few examples of normal pictures without any anomalies."
+                " You have to use these to determine if the image I provide in the next"
+                " chat is normal or anomalous."
+            ),
+        )
+
+    def predict(self, image: str | Path) -> str:
+        """Predict the anomaly label."""
+        if not chat:
+            msg = "Ollama is not installed. Please install it using `pip install ollama`."
+            raise ImportError(msg)
+        image_encoded = _encode_image(image)
+        messages = []
+
+        # Few-shot: prepend the normal reference images before the query.
+        if self._ref_images_encoded:
+            messages.append({
+                "role": "user",
+                "images": self._ref_images_encoded,
+                "content": self.prompt.few_shot,
+            })
+
+        messages.append({"role": "user", "images": [image_encoded], "content": self.prompt.predict})
+
+        response = chat(
+            model=self.model_name,
+            messages=messages,
+        )
+        return response["message"]["content"].strip()
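Review note: the message shape built in `predict()` can be smoke-tested directly against the client library. A sketch, assuming `ollama serve` is running, `ollama pull llava` has been done, and `cracked_wall.png` is a placeholder for any local image path:

```python
from ollama import chat

# One user message with attached images, the same structure the backend sends.
# The Python client base64-encodes local file paths passed via "images".
response = chat(
    model="llava",
    messages=[
        {
            "role": "user",
            "content": "You are given an image. Is it normal or anomalous?",
            "images": ["cracked_wall.png"],  # hypothetical local path
        },
    ],
)
print(response["message"]["content"])
```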
diff --git a/src/anomalib/models/image/vlm_ad/lightning_model.py b/src/anomalib/models/image/vlm_ad/lightning_model.py
new file mode 100644
index 0000000000..1a7ffffc55
--- /dev/null
+++ b/src/anomalib/models/image/vlm_ad/lightning_model.py
@@ -0,0 +1,88 @@
+"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from enum import Enum
+
+import torch
+from torch.utils.data import DataLoader
+
+from anomalib import LearningType
+from anomalib.models import AnomalyModule
+
+from .backends import Backend, Ollama
+
+logger = logging.getLogger(__name__)
+
+
+class VlmAdBackend(Enum):
+    """Supported VLM backends."""
+
+    OLLAMA = "ollama"
+
+
+class VlmAd(AnomalyModule):
+    """Visual anomaly model."""
+
+    def __init__(
+        self,
+        backend: VlmAdBackend | str = VlmAdBackend.OLLAMA,
+        api_key: str | None = None,
+        k_shot: int = 3,
+    ) -> None:
+        super().__init__()
+        self.k_shot = k_shot
+        backend = VlmAdBackend(backend)
+        self.vlm_backend: Backend = self._setup_vlm(backend, api_key)
+
+    @staticmethod
+    def _setup_vlm(backend: VlmAdBackend, api_key: str | None) -> Backend:
+        match backend:
+            case VlmAdBackend.OLLAMA:
+                return Ollama(api_key=api_key)
+            case _:
+                msg = f"Unsupported VLM backend: {backend}"
+                raise ValueError(msg)
+
+    def _setup(self) -> None:
+        if self.k_shot:
+            logger.info("Collecting reference images from training dataset.")
+            dataloader = self.trainer.datamodule.train_dataloader()
+            self.collect_reference_images(dataloader)
+
+    def collect_reference_images(self, dataloader: DataLoader) -> None:
+        """Collect reference images for few-shot inference."""
+        count = 0
+        for batch in dataloader:
+            for img_path in batch["image_path"]:
+                self.vlm_backend.add_reference_images(img_path)
+                count += 1
+                if count == self.k_shot:
+                    return
+
+    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict:
+        """Validation step."""
+        del args, kwargs  # These variables are not used.
+        responses = [self.vlm_backend.predict(img_path) for img_path in batch["image_path"]]
+
+        batch["str_output"] = responses
+        batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
+        return batch
+
+    @property
+    def learning_type(self) -> LearningType:
+        """The learning type of the model."""
+        return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT
+
+    @property
+    def trainer_arguments(self) -> dict[str, int | float]:
+        """Doesn't need training."""
+        return {}
+
+    @staticmethod
+    def configure_transforms(image_size: tuple[int, int] | None = None) -> None:
+        """This model does not require any transforms."""
+        if image_size is not None:
+            logger.warning("Ignoring image_size argument as each backend has its own transforms.")
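Review note: end to end, the model is expected to be driven like any other anomalib model. A rough sketch under the usual `Engine`/`MVTec` workflow; it assumes a local MVTec copy and a running Ollama service, and none of these lines are part of the PR:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import VlmAd

# k_shot=3: _setup() collects three normal reference images from the train
# dataloader before validation; k_shot=0 switches to zero-shot prompting.
datamodule = MVTec(category="bottle")
model = VlmAd(backend="ollama", k_shot=3)
engine = Engine()
engine.validate(model=model, datamodule=datamodule)
```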
diff --git a/src/anomalib/utils/visualization/image.py b/src/anomalib/utils/visualization/image.py
index 0edc5b1c29..f7ff593df1 100644
--- a/src/anomalib/utils/visualization/image.py
+++ b/src/anomalib/utils/visualization/image.py
@@ -3,7 +3,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import textwrap
 from collections.abc import Iterator
 from enum import Enum
 from pathlib import Path
@@ -44,8 +43,7 @@ def __init__(
         self,
         image: np.ndarray,
         pred_score: float,
-        pred_label: None,
-        text_descr: str | None = None,
+        pred_label: str | None,
         anomaly_map: np.ndarray | None = None,
         gt_mask: np.ndarray | None = None,
         pred_mask: np.ndarray | None = None,
@@ -54,8 +52,6 @@
         box_labels: np.ndarray | None = None,
         normalize: bool = False,
     ) -> None:
-        self.text_descr = text_descr
-
         self.anomaly_map = anomaly_map
         self.box_labels = box_labels
         self.gt_boxes = gt_boxes
@@ -71,7 +67,9 @@
         if anomaly_map is not None:
             self.heat_map = superimpose_anomaly_map(
-                self.anomaly_map, self.image, normalize=normalize
+                self.anomaly_map,
+                self.image,
+                normalize=normalize,
             )
 
         if self.gt_mask is not None and self.gt_mask.max() <= 1.0:
@@ -81,7 +79,10 @@
         if self.pred_mask is not None and self.pred_mask.max() <= 1.0:
             self.pred_mask *= 255
             self.segmentations = mark_boundaries(
-                self.image, self.pred_mask, color=(1, 0, 0), mode="thick"
+                self.image,
+                self.pred_mask,
+                color=(1, 0, 0),
+                mode="thick",
             )
             if self.segmentations.max() <= 1.0:
                 self.segmentations = (self.segmentations * 255).astype(np.uint8)
@@ -101,28 +102,11 @@ def __repr__(self) -> str:
             f"anomaly_map={self.anomaly_map}, gt_mask={self.gt_mask}, "
             f"gt_boxes={self.gt_boxes}, pred_boxes={self.pred_boxes}, box_labels={self.box_labels}"
         )
-        repr_str += (
-            f", pred_mask={self.pred_mask}" if self.pred_mask is not None else ""
-        )
+        repr_str += f", pred_mask={self.pred_mask}" if self.pred_mask is not None else ""
         repr_str += f", heat_map={self.heat_map}" if self.heat_map is not None else ""
-        repr_str += (
-            f", segmentations={self.segmentations}"
-            if self.segmentations is not None
-            else ""
-        )
-        repr_str += (
-            f", normal_boxes={self.normal_boxes}"
-            if self.normal_boxes is not None
-            else ""
-        )
-        repr_str += (
-            f", anomalous_boxes={self.anomalous_boxes}"
-            if self.anomalous_boxes is not None
-            else ""
-        )
-        repr_str += (
-            f", text_descr={self.text_descr}" if self.text_descr is not None else ""
-        )
+        repr_str += f", segmentations={self.segmentations}" if self.segmentations is not None else ""
+        repr_str += f", normal_boxes={self.normal_boxes}" if self.normal_boxes is not None else ""
+        repr_str += f", anomalous_boxes={self.anomalous_boxes}" if self.anomalous_boxes is not None else ""
         repr_str += ")"
         return repr_str
@@ -173,13 +157,17 @@ def _visualize_batch(self, batch: dict) -> Iterator[GeneratorResult]:
                 height, width = batch["image"].shape[-2:]
                 image = (read_image(path=batch["image_path"][i]) * 255).astype(np.uint8)
                 image = cv2.resize(
-                    image, dsize=(width, height), interpolation=cv2.INTER_AREA
+                    image,
+                    dsize=(width, height),
+                    interpolation=cv2.INTER_AREA,
                 )
             elif "video_path" in batch:
                 height, width = batch["image"].shape[-2:]
                 image = batch["original_image"][i].squeeze().cpu().numpy()
                 image = cv2.resize(
-                    image, dsize=(width, height), interpolation=cv2.INTER_AREA
+                    image,
+                    dsize=(width, height),
+                    interpolation=cv2.INTER_AREA,
                 )
             else:
                 msg = "Batch must have either 'image_path' or 'video_path' defined."
@@ -195,39 +183,20 @@
             image_result = ImageResult(
                 image=image,
-                text_descr=batch["str_output"][i] if "str_output" in batch else None,
-                pred_score=(
-                    batch["pred_scores"][i].cpu().numpy().item()
-                    if "pred_scores" in batch
-                    else None
-                ),
-                pred_label=(
-                    batch["pred_labels"][i].cpu().numpy().item()
-                    if "pred_labels" in batch
-                    else None
-                ),
-                anomaly_map=(
-                    batch["anomaly_maps"][i].cpu().numpy()
-                    if "anomaly_maps" in batch
-                    else None
-                ),
-                pred_mask=(
-                    batch["pred_masks"][i].squeeze().int().cpu().numpy()
-                    if "pred_masks" in batch
-                    else None
-                ),
-                gt_mask=(
-                    batch["mask"][i].squeeze().int().cpu().numpy()
-                    if "mask" in batch
-                    else None
-                ),
+                pred_score=(batch["pred_scores"][i].cpu().numpy().item() if "pred_scores" in batch else None),
+                pred_label=(batch["pred_labels"][i].cpu().numpy().item() if "pred_labels" in batch else None),
+                anomaly_map=(batch["anomaly_maps"][i].cpu().numpy() if "anomaly_maps" in batch else None),
+                pred_mask=(batch["pred_masks"][i].squeeze().int().cpu().numpy() if "pred_masks" in batch else None),
+                gt_mask=(batch["mask"][i].squeeze().int().cpu().numpy() if "mask" in batch else None),
                 gt_boxes=batch["boxes"][i].cpu().numpy() if "boxes" in batch else None,
                 pred_boxes=batch["pred_boxes"][i].cpu().numpy() if "pred_boxes" in batch else None,
                 box_labels=batch["box_labels"][i].cpu().numpy() if "box_labels" in batch else None,
                 normalize=self.normalize,
             )
-            yield GeneratorResult(
-                image=self.visualize_image(image_result), file_name=file_name
-            )
+            yield GeneratorResult(
+                image=self.visualize_image(image_result),
+                file_name=file_name,
+            )
 
     def visualize_image(self, image_result: ImageResult) -> np.ndarray:
         """Generate the visualization for an image.
@@ -272,7 +241,9 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray:
                     color=(255, 0, 0),
                 )
                 image_grid.add_image(
-                    image=gt_image, color_map="gray", title="Ground Truth"
+                    image=gt_image,
+                    color_map="gray",
+                    title="Ground Truth",
                 )
             else:
                 image_grid.add_image(image_result.image, "Image")
@@ -282,7 +253,9 @@
                 color=(0, 255, 0),
             )
             pred_image = draw_boxes(
-                pred_image, image_result.anomalous_boxes, color=(255, 0, 0)
+                pred_image,
+                image_result.anomalous_boxes,
+                color=(255, 0, 0),
             )
             image_grid.add_image(pred_image, "Predictions")
         if self.task == TaskType.SEGMENTATION:
@@ -293,14 +266,19 @@
             image_grid.add_image(image_result.image, "Image")
             if image_result.gt_mask is not None:
                 image_grid.add_image(
-                    image=image_result.gt_mask, color_map="gray", title="Ground Truth"
+                    image=image_result.gt_mask,
+                    color_map="gray",
+                    title="Ground Truth",
                 )
             image_grid.add_image(image_result.heat_map, "Predicted Heat Map")
             image_grid.add_image(
-                image=image_result.pred_mask, color_map="gray", title="Predicted Mask"
+                image=image_result.pred_mask,
+                color_map="gray",
+                title="Predicted Mask",
             )
             image_grid.add_image(
-                image=image_result.segmentations, title="Segmentation Result"
+                image=image_result.segmentations,
+                title="Segmentation Result",
             )
         elif self.task == TaskType.CLASSIFICATION:
             image_grid.add_image(image_result.image, title="Image")
@@ -308,23 +286,23 @@
             image_grid.add_image(image_result.heat_map, "Predicted Heat Map")
             if image_result.pred_label:
                 image_classified = add_anomalous_label(
-                    image_result.image, image_result.pred_score
+                    image_result.image,
+                    image_result.pred_score,
                 )
             else:
                 image_classified = add_normal_label(
-                    image_result.image, 1 - image_result.pred_score
+                    image_result.image,
+                    1 - image_result.pred_score,
                 )
             image_grid.add_image(image=image_classified, title="Prediction")
         elif self.task == TaskType.EXPLANATION:
-            description = ""
-            if image_result.text_descr:
-                description = image_result.text_descr
-
             image_classified = add_normal_label(
-                image_result.image, 1 - image_result.pred_score
+                image_result.image,
+                1 - image_result.pred_score,
             )
             image_grid.add_image(
-                image_classified, title="Explanation of Image", description=description
+                image_classified,
+                title="Explanation of Image",
             )
 
         return image_grid.generate()
@@ -365,11 +343,13 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray:
         if self.task == TaskType.CLASSIFICATION:
             if image_result.pred_label:
                 image_classified = add_anomalous_label(
-                    image_result.image, image_result.pred_score
+                    image_result.image,
+                    image_result.pred_score,
                 )
             else:
                 image_classified = add_normal_label(
-                    image_result.image, 1 - image_result.pred_score
+                    image_result.image,
+                    1 - image_result.pred_score,
                 )
             return image_classified
         msg = f"Unknown task type: {self.task}"
@@ -393,7 +373,6 @@ def add_image(
         image: np.ndarray,
         title: str | None = None,
         color_map: str | None = None,
-        description: str | None = None,
     ) -> None:
         """Add an image to the grid.
@@ -406,7 +385,6 @@ def add_image( "image": image, "title": title, "color_map": color_map, - "descr": description, } self.images.append(image_data) @@ -431,18 +409,7 @@ def generate(self) -> np.ndarray: axis.axes.yaxis.set_visible(b=False) axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255) if image_dict["title"] is not None: - # axis.title.set_text(image_dict["title"]) - pass - if image_dict["descr"] is not None: - # Wrap the text - # wrapped_text = textwrap.fill(image_dict["descr"][0]['response'], width=100/num_cols) # Adjust 'width' based on your subplot size and preference - wrapped_text = textwrap.fill( - image_dict["descr"], - width=70 // num_cols, - ) # Adjust 'width' based on your subplot size and preference - axis.set_title(wrapped_text, fontsize=10) - - self.figure.subplots_adjust(top=0.7) + axis.title.set_text(image_dict["title"]) self.figure.canvas.draw() # convert canvas to numpy array to prepare for visualization with opencv
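Review note on the scoring path: `VlmAd.validation_step` keys off only the first character of the VLM reply, so casing matters. A quick illustration of the rule as written (hypothetical replies, no Ollama call involved):

```python
import torch

replies = [
    "YES: there is a crack on the wall.",  # scored anomalous
    "NO: the image looks normal.",         # scored normal
    "yes, this one is defective",          # lowercase, so scored normal
]
scores = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in replies])
print(scores)  # tensor([1., 0., 0.])
```

A case-insensitive check such as `r.upper().startswith("YES")` may be worth considering before merge.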