1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add `VlmAd` metric by [Bepitic](https://github.com/Bepitic) and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2344
- Add `AUPIMO` tutorials notebooks in https://github.com/openvinotoolkit/anomalib/pull/2330 and https://github.com/openvinotoolkit/anomalib/pull/2336
- Add `AUPIMO` metric by [jpcbertoldo](https://github.com/jpcbertoldo) in https://github.com/openvinotoolkit/anomalib/pull/1726 and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2329

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -56,6 +56,7 @@ core = [
"open-clip-torch>=2.23.0,<2.26.1",
]
openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
vlm = ["ollama", "openai", "python-dotenv", "transformers"]
loggers = [
"comet-ml>=3.31.7",
"gradio>=4",
@@ -84,7 +85,7 @@ test = [
"coverage[toml]",
"tox",
]
full = ["anomalib[core,openvino,loggers,notebooks]"]
full = ["anomalib[core,openvino,loggers,notebooks, vlm]"]
dev = ["anomalib[full,docs,test]"]

[project.scripts]
5 changes: 2 additions & 3 deletions src/anomalib/callbacks/metrics.py
@@ -78,9 +78,8 @@ def setup(
elif self.task == TaskType.CLASSIFICATION:
pixel_metric_names = []
logger.warning(
"Cannot perform pixel-level evaluation when task type is classification. "
"Ignoring the following pixel-level metrics: %s",
self.pixel_metric_names,
"Cannot perform pixel-level evaluation when task type is {self.task.value}. "
f"Ignoring the following pixel-level metrics: {self.pixel_metric_names}",
)
else:
pixel_metric_names = (
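Note that both literals carry the `f` prefix: Python concatenates adjacent string literals before any formatting happens, so a bare first literal would leave `{self.task.value}` in the logged message verbatim. A quick illustration of the pitfall (variable names here are illustrative, not from this PR):

```python
task = "classification"
pixel_metric_names = ["AUPRO"]

# Missing prefix on the first literal: the placeholder survives as text.
wrong = "task type is {task}. " f"metrics: {pixel_metric_names}"
# Prefixing both literals interpolates everything as intended.
right = f"task type is {task}. " f"metrics: {pixel_metric_names}"

print(wrong)  # task type is {task}. metrics: ['AUPRO']
print(right)  # task type is classification. metrics: ['AUPRO']
```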
16 changes: 12 additions & 4 deletions src/anomalib/engine/engine.py
@@ -32,7 +32,7 @@
from anomalib.utils.normalization import NormalizationMethod
from anomalib.utils.path import create_versioned_dir
from anomalib.utils.types import NORMALIZATION, THRESHOLD
from anomalib.utils.visualization import ImageVisualizer
from anomalib.utils.visualization import BaseVisualizer, ExplanationVisualizer, ImageVisualizer

logger = logging.getLogger(__name__)

@@ -322,7 +322,7 @@ def _setup_trainer(self, model: AnomalyModule) -> None:
self._cache.update(model)

# Setup anomalib callbacks to be used with the trainer
self._setup_anomalib_callbacks()
self._setup_anomalib_callbacks(model)

# Temporarily set devices to 1 to avoid issues with multiple processes
self._cache.args["devices"] = 1
@@ -405,7 +405,7 @@ def _setup_transform(
if not getattr(dataloader.dataset, "transform", None):
dataloader.dataset.transform = transform

def _setup_anomalib_callbacks(self) -> None:
def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None:
"""Set up callbacks for the trainer."""
_callbacks: list[Callback] = []

@@ -432,9 +432,17 @@ def _setup_anomalib_callbacks(self) -> None:
_callbacks.append(_ThresholdCallback(self.threshold))
_callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names))

visualizer: BaseVisualizer

# TODO(ashwinvaidya17): temporary; remove once the visualizer gets its planned overhaul  # noqa: TD003
if model.__class__.__name__ == "VlmAd":
visualizer = ExplanationVisualizer()
else:
visualizer = ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE)

_callbacks.append(
_VisualizationCallback(
visualizers=ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE),
visualizers=visualizer,
save=True,
root=self._cache.args["default_root_dir"] / "images",
),
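With this hunk, the engine inspects the model class and wires up `ExplanationVisualizer` for `VlmAd`, while every other model keeps the existing `ImageVisualizer`. A minimal sketch of exercising the new dispatch, assuming a local MVTec dataset and that `VlmAd`'s constructor defaults are usable as-is (its signature is not shown in this diff):

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import VlmAd

datamodule = MVTec(category="bottle")  # assumes the dataset is available locally
model = VlmAd()  # the engine swaps in ExplanationVisualizer for this class

engine = Engine()
engine.test(model=model, datamodule=datamodule)
```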
2 changes: 2 additions & 0 deletions src/anomalib/models/__init__.py
@@ -30,6 +30,7 @@
Rkde,
Stfpm,
Uflow,
VlmAd,
WinClip,
)
from .video import AiVad
@@ -58,6 +59,7 @@ class UnknownModelError(ModuleNotFoundError):
"Stfpm",
"Uflow",
"AiVad",
"VlmAd",
"WinClip",
]

2 changes: 2 additions & 0 deletions src/anomalib/models/image/__init__.py
@@ -20,6 +20,7 @@
from .rkde import Rkde
from .stfpm import Stfpm
from .uflow import Uflow
from .vlm_ad import VlmAd
from .winclip import WinClip

__all__ = [
@@ -40,5 +41,6 @@
"Rkde",
"Stfpm",
"Uflow",
"VlmAd",
"WinClip",
]
8 changes: 8 additions & 0 deletions src/anomalib/models/image/vlm_ad/__init__.py
@@ -0,0 +1,8 @@
"""Visual Anomaly Model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import VlmAd

__all__ = ["VlmAd"]
11 changes: 11 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/__init__.py
@@ -0,0 +1,11 @@
"""VLM backends."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .base import Backend
from .chat_gpt import ChatGPT
from .huggingface import Huggingface
from .ollama import Ollama

__all__ = ["Backend", "ChatGPT", "Huggingface", "Ollama"]
30 changes: 30 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/base.py
@@ -0,0 +1,30 @@
"""Base backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from pathlib import Path

from anomalib.models.image.vlm_ad.utils import Prompt


class Backend(ABC):
"""Base backend."""

@abstractmethod
def __init__(self, model_name: str) -> None:
"""Initialize the backend."""

@abstractmethod
def add_reference_images(self, image: str | Path) -> None:
"""Add reference images for k-shot."""

@abstractmethod
def predict(self, image: str | Path, prompt: Prompt) -> str:
"""Predict the anomaly label."""

@property
@abstractmethod
def num_reference_images(self) -> int:
"""Get the number of reference images."""
109 changes: 109 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/chat_gpt.py
@@ -0,0 +1,109 @@
"""ChatGPT backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import base64
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

from dotenv import load_dotenv
from lightning_utilities.core.imports import package_available

from anomalib.models.image.vlm_ad.utils import Prompt

from .base import Backend

if package_available("openai"):
from openai import OpenAI
else:
OpenAI = None

if TYPE_CHECKING:
from openai.types.chat import ChatCompletion

logger = logging.getLogger(__name__)


class ChatGPT(Backend):
"""ChatGPT backend."""

def __init__(self, model_name: str, api_key: str | None = None) -> None:
"""Initialize the ChatGPT backend."""
self._ref_images_encoded: list[str] = []
self.model_name: str = model_name
self._client: OpenAI | None = None
self.api_key = self._get_api_key(api_key)

@property
def client(self) -> OpenAI:
"""Get the OpenAI client."""
if OpenAI is None:
msg = "OpenAI is not installed. Please install it to use ChatGPT backend."
raise ImportError(msg)
if self._client is None:
self._client = OpenAI(api_key=self.api_key)
return self._client

def add_reference_images(self, image: str | Path) -> None:
"""Add reference images for k-shot."""
self._ref_images_encoded.append(self._encode_image_to_url(image))

@property
def num_reference_images(self) -> int:
"""Get the number of reference images."""
return len(self._ref_images_encoded)

def predict(self, image: str | Path, prompt: Prompt) -> str:
"""Predict the anomaly label."""
image_encoded = self._encode_image_to_url(image)
messages = []

# few-shot
if len(self._ref_images_encoded) > 0:
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))

messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))

response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name)
return response.choices[0].message.content

@staticmethod
def _generate_message(content: str, images: list[str] | None) -> dict:
"""Generate a message."""
message: dict[str, list[dict] | str] = {"role": "user"}
if images is not None:
_content: list[dict[str, str | dict]] = [{"type": "text", "text": content}]
_content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images])
message["content"] = _content
else:
message["content"] = content
return message

def _encode_image_to_url(self, image: str | Path) -> str:
"""Encode the image to base64 and embed in url string."""
image_path = Path(image)
extension = image_path.suffix.lstrip(".")  # suffix includes the leading dot; strip it for the MIME type
base64_encoded = self._encode_image_to_base_64(image_path)
return f"data:image/{extension};base64,{base64_encoded}"

@staticmethod
def _encode_image_to_base_64(image: str | Path) -> str:
"""Encode the image to base64."""
image = Path(image)
return base64.b64encode(image.read_bytes()).decode("utf-8")

def _get_api_key(self, api_key: str | None = None) -> str:
if api_key is None:
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
msg = (
f"OpenAI API key must be provided to use {self.model_name}."
" Please provide the API key in the constructor, or set the OPENAI_API_KEY environment variable"
" or in a `.env` file."
)
raise ValueError(msg)
return api_key
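The backend is also usable standalone. A usage sketch, assuming `OPENAI_API_KEY` is set in the environment or a `.env` file, that `Prompt` is a simple container with `few_shot` and `predict` fields (as its use above suggests), and with illustrative model and file names:

```python
from anomalib.models.image.vlm_ad.backends import ChatGPT
from anomalib.models.image.vlm_ad.utils import Prompt

backend = ChatGPT(model_name="gpt-4o-mini")  # model name is illustrative
backend.add_reference_images("ref_normal.png")  # optional: triggers the few-shot branch

prompt = Prompt(
    few_shot="These reference images show normal parts.",
    predict="Does the last image contain a defect? Answer YES or NO and explain.",
)
print(backend.predict("query.png", prompt))
```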
96 changes: 96 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/huggingface.py
@@ -0,0 +1,96 @@
"""Huggingface backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path

from lightning_utilities.core.imports import package_available
from PIL import Image

from anomalib.models.image.vlm_ad.utils import Prompt

from .base import Backend

if package_available("transformers"):
import transformers
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import ProcessorMixin
else:
transformers = None


logger = logging.getLogger(__name__)


class Huggingface(Backend):
"""Huggingface backend."""

def __init__(
self,
model_name: str,
) -> None:
"""Initialize the Huggingface backend."""
self.model_name: str = model_name
self._ref_images: list[Image.Image] = []
self._processor: ProcessorMixin | None = None
self._model: PreTrainedModel | None = None

@property
def processor(self) -> ProcessorMixin:
"""Get the Huggingface processor."""
if self._processor is None:
if transformers is None:
msg = "transformers is not installed."
raise ValueError(msg)
self._processor = transformers.LlavaNextProcessor.from_pretrained(self.model_name)
return self._processor

@property
def model(self) -> PreTrainedModel:
"""Get the Huggingface model."""
if self._model is None:
if transformers is None:
msg = "transformers is not installed."
raise ValueError(msg)
self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name)
return self._model

@staticmethod
def _generate_message(content: str, images: list[str] | None) -> dict:
"""Generate a message."""
message: dict[str, str | list[dict]] = {"role": "user"}
_content: list[dict[str, str]] = [{"type": "text", "text": content}]
if images is not None:
_content.extend([{"type": "image"} for _ in images])
message["content"] = _content
return message

def add_reference_images(self, image: str | Path) -> None:
"""Add reference images for k-shot."""
self._ref_images.append(Image.open(image))

@property
def num_reference_images(self) -> int:
"""Get the number of reference images."""
return len(self._ref_images)

def predict(self, image_path: str | Path, prompt: Prompt) -> str:
"""Predict the anomaly label."""
image = Image.open(image_path)
messages: list[dict] = []

if len(self._ref_images) > 0:
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images))

messages.append(self._generate_message(content=prompt.predict, images=[image]))
processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]

images = [*self._ref_images, image]
inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device)
outputs = self.model.generate(**inputs, max_new_tokens=100)
result = self.processor.decode(outputs[0], skip_special_tokens=True)
logger.debug(result)
return result
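As with the ChatGPT backend, a standalone usage sketch; the checkpoint name and file name are illustrative, the `Prompt` fields are assumed as above, and the model download is several gigabytes:

```python
from anomalib.models.image.vlm_ad.backends import Huggingface
from anomalib.models.image.vlm_ad.utils import Prompt

backend = Huggingface(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
prompt = Prompt(
    few_shot="These reference images show normal parts.",
    predict="Does this image contain a defect? Answer YES or NO and explain.",
)
print(backend.predict("query.png", prompt))
```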