Skip to content

Commit f7ece91

Browse files
🔨 Minor Refactor (#2345)
Refactor Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 834c777 commit f7ece91

File tree

9 files changed

+97
-145
lines changed

9 files changed

+97
-145
lines changed

src/anomalib/models/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
Fastflow,
2525
Fre,
2626
Ganomaly,
27-
Llava,
28-
Llavanext,
2927
Padim,
3028
Patchcore,
3129
ReverseDistillation,
@@ -63,8 +61,6 @@ class UnknownModelError(ModuleNotFoundError):
6361
"AiVad",
6462
"VlmAd",
6563
"WinClip",
66-
"Llava",
67-
"Llavanext",
6864
]
6965

7066
logger = logging.getLogger(__name__)

src/anomalib/models/image/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from .fastflow import Fastflow
1515
from .fre import Fre
1616
from .ganomaly import Ganomaly
17-
from .llava import Llava
18-
from .llava_next import Llavanext
1917
from .padim import Padim
2018
from .patchcore import Patchcore
2119
from .reverse_distillation import ReverseDistillation
@@ -45,6 +43,4 @@
4543
"Uflow",
4644
"VlmAd",
4745
"WinClip",
48-
"Llava",
49-
"Llavanext",
5046
]

src/anomalib/models/image/vlm_ad/backends/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66
from abc import ABC, abstractmethod
77
from pathlib import Path
88

9+
from anomalib.models.image.vlm_ad.utils import Prompt
10+
911

1012
class Backend(ABC):
1113
"""Base backend."""
1214

1315
@abstractmethod
14-
def __init__(self, model_name: str, api_key: str | None = None) -> None:
16+
def __init__(self, model_name: str) -> None:
1517
"""Initialize the backend."""
1618

1719
@abstractmethod
1820
def add_reference_images(self, image: str | Path) -> None:
1921
"""Add reference images for k-shot."""
2022

2123
@abstractmethod
22-
def predict(self, image: str | Path) -> str:
24+
def predict(self, image: str | Path, prompt: Prompt) -> str:
2325
"""Predict the anomaly label."""

src/anomalib/models/image/vlm_ad/backends/chat_gpt.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from pathlib import Path
99
from typing import TYPE_CHECKING
1010

11+
from anomalib.models.image.vlm_ad.utils import Prompt
1112
from anomalib.utils.exceptions import try_import
1213

1314
from .base import Backend
14-
from .dataclasses import Prompt
1515

1616
if try_import("openai"):
1717
from openai import OpenAI
@@ -27,11 +27,8 @@
2727
class ChatGPT(Backend):
2828
"""ChatGPT backend."""
2929

30-
def __init__(self, api_key: str | None = None, model_name: str = "gpt-4o-mini") -> None:
30+
def __init__(self, api_key: str, model_name: str) -> None:
3131
"""Initialize the ChatGPT backend."""
32-
if api_key is None:
33-
msg = "API key is required for ChatGPT backend."
34-
raise ValueError(msg)
3532
self.api_key = api_key
3633
self._ref_images_encoded: list[str] = []
3734
self.model_name: str = model_name
@@ -51,30 +48,30 @@ def add_reference_images(self, image: str | Path) -> None:
5148
"""Add reference images for k-shot."""
5249
self._ref_images_encoded.append(self._encode_image_to_url(image))
5350

54-
def predict(self, image: str | Path) -> str:
51+
def predict(self, image: str | Path, prompt: Prompt) -> str:
5552
"""Predict the anomaly label."""
5653
image_encoded = self._encode_image_to_url(image)
5754
messages = []
5855

5956
# few-shot
6057
if len(self._ref_images_encoded) > 0:
61-
messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images_encoded))
58+
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))
6259

63-
messages.append(self._generate_message(content=self.prompt.predict, images=[image_encoded]))
60+
messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))
6461

6562
response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name)
6663
return response.choices[0].message.content
6764

6865
@staticmethod
6966
def _generate_message(content: str, images: list[str] | None) -> dict:
7067
"""Generate a message."""
71-
message = {"role": "user"}
72-
if images is None:
73-
message["content"] = content
68+
message: dict[str, list[dict] | str] = {"role": "user"}
69+
if images is not None:
70+
_content: list[dict[str, str | dict]] = [{"type": "text", "text": content}]
71+
_content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images])
72+
message["content"] = _content
7473
else:
75-
message["content"] = [{"type": "text", "text": content}]
76-
for image in images:
77-
message["content"].append({"type": "image_url", "image_url": {"url": image}})
74+
message["content"] = content
7875
return message
7976

8077
def _encode_image_to_url(self, image: str | Path) -> str:
@@ -89,20 +86,3 @@ def _encode_image_to_base_64(image: str | Path) -> str:
8986
"""Encode the image to base64."""
9087
image = Path(image)
9188
return base64.b64encode(image.read_bytes()).decode("utf-8")
92-
93-
@property
94-
def prompt(self) -> Prompt:
95-
"""Get the Ollama prompt."""
96-
return Prompt(
97-
predict=(
98-
"You are given an image. It is either normal or anomalous."
99-
"First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
100-
"Then give the reason for your decision.\n"
101-
"For example, 'YES: The image has a crack on the wall.'"
102-
),
103-
few_shot=(
104-
"These are a few examples of normal picture without any anomalies."
105-
" You have to use these to determine if the image I provide in the next"
106-
" chat is normal or anomalous."
107-
),
108-
)

src/anomalib/models/image/vlm_ad/backends/dataclasses.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/anomalib/models/image/vlm_ad/backends/huggingface.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
import logging
7-
from enum import Enum
87
from pathlib import Path
98

109
from PIL import Image
1110
from transformers.modeling_utils import PreTrainedModel
1211

12+
from anomalib.models.image.vlm_ad.utils import Prompt
1313
from anomalib.utils.exceptions import try_import
1414

1515
from .base import Backend
16-
from .dataclasses import Prompt
1716

1817
if try_import("transformers"):
1918
import transformers
@@ -26,49 +25,25 @@
2625
logger = logging.getLogger(__name__)
2726

2827

29-
class LlavaNextModels(Enum):
30-
"""Available models."""
31-
32-
VICUNA_7B = "llava-hf/llava-v1.6-vicuna-7b-hf"
33-
VICUNA_13B = "llava-hf/llava-v1.6-vicuna-13b-hf"
34-
MISTRAL_7B = "llava-hf/llava-v1.6-mistral-7b-hf"
35-
36-
3728
class Huggingface(Backend):
3829
"""Huggingface backend."""
3930

4031
def __init__(
4132
self,
33+
model_name: str,
4234
api_key: str | None = None,
43-
model_name: str | LlavaNextModels = LlavaNextModels.VICUNA_7B,
4435
) -> None:
4536
"""Initialize the Huggingface backend."""
4637
if api_key:
4738
logger.warning("API key is not required for Huggingface backend.")
48-
self.model_name: str = LlavaNextModels(model_name).value
39+
self.model_name: str = model_name
4940
self._ref_images: list[str] = []
5041
self._processor: ProcessorMixin | None = None
5142
self._model: PreTrainedModel | None = None
5243

53-
@property
54-
def prompt(self) -> Prompt:
55-
"""Get the Ollama prompt."""
56-
return Prompt(
57-
predict=(
58-
"You are given an image. It is either normal or anomalous."
59-
" First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
60-
"Then give the reason for your decision.\n"
61-
"For example, 'YES: The image has a crack on the wall.'"
62-
),
63-
few_shot=(
64-
"These are a few examples of normal picture without any anomalies."
65-
" You have to use these to determine if the image I provide in the next"
66-
" chat is normal or anomalous."
67-
),
68-
)
69-
7044
@property
7145
def processor(self) -> ProcessorMixin:
46+
"""Get the Huggingface processor."""
7247
if self._processor is None:
7348
if transformers is None:
7449
msg = "transformers is not installed."
@@ -78,41 +53,41 @@ def processor(self) -> ProcessorMixin:
7853

7954
@property
8055
def model(self) -> PreTrainedModel:
56+
"""Get the Huggingface model."""
8157
if self._model is None:
8258
if transformers is None:
8359
msg = "transformers is not installed."
8460
raise ValueError(msg)
85-
self._model: PreTrainedModel = transformers.LlavaNextForConditionalGeneration.from_pretrained(
86-
self.model_name,
87-
)
61+
self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name)
8862
return self._model
8963

9064
@staticmethod
9165
def _generate_message(content: str, images: list[str] | None) -> dict:
9266
"""Generate a message."""
93-
message = {"role": "user"}
94-
message["content"] = [{"type": "text", "text": content}]
67+
message: dict[str, str | list[dict]] = {"role": "user"}
68+
_content: list[dict[str, str]] = [{"type": "text", "text": content}]
9569
if images is not None:
96-
for _ in images:
97-
message["content"].append({"type": "image"})
70+
_content.extend([{"type": "image"} for _ in images])
71+
message["content"] = _content
9872
return message
9973

10074
def add_reference_images(self, image: str | Path) -> None:
75+
"""Add reference images for k-shot."""
10176
self._ref_images.append(Image.open(image))
10277

103-
def predict(self, image_path: str | Path) -> str:
78+
def predict(self, image_path: str | Path, prompt: Prompt) -> str:
10479
"""Predict the anomaly label."""
10580
image = Image.open(image_path)
106-
messages = []
81+
messages: list[dict] = []
10782

10883
if len(self._ref_images) > 0:
109-
messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images))
84+
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images))
11085

111-
messages.append(self._generate_message(content=self.prompt.predict, images=[image]))
112-
prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]
86+
messages.append(self._generate_message(content=prompt.predict, images=[image]))
87+
processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]
11388

11489
images = [*self._ref_images, image]
115-
inputs = self.processor(images, prompt, return_tensors="pt", padding=True).to(self.model.device)
90+
inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device)
11691
outputs = self.model.generate(**inputs, max_new_tokens=100)
11792
result = self.processor.decode(outputs[0], skip_special_tokens=True)
11893
print(result)

src/anomalib/models/image/vlm_ad/backends/ollama.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import logging
1212
from pathlib import Path
1313

14+
from anomalib.models.image.vlm_ad.utils import Prompt
1415
from anomalib.utils.exceptions import try_import
1516

1617
from .base import Backend
17-
from .dataclasses import Prompt
1818

1919
if try_import("ollama"):
2020
from ollama import chat
@@ -28,43 +28,24 @@
2828
class Ollama(Backend):
2929
"""Ollama backend."""
3030

31-
def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None:
31+
def __init__(self, model_name: str) -> None:
3232
"""Initialize the Ollama backend."""
33-
if api_key:
34-
logger.warning("API key is not required for Ollama backend.")
3533
self.model_name: str = model_name
3634
self._ref_images_encoded: list[str] = []
3735

3836
def add_reference_images(self, image: str | Path) -> None:
3937
"""Encode the image to base64."""
4038
self._ref_images_encoded.append(_encode_image(image))
4139

42-
@property
43-
def prompt(self) -> Prompt:
44-
"""Get the Ollama prompt."""
45-
return Prompt(
46-
predict=(
47-
"You are given an image. It is either normal or anomalous."
48-
"First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
49-
"Then give the reason for your decision.\n"
50-
"For example, 'YES: The image has a crack on the wall.'"
51-
),
52-
few_shot=(
53-
"These are a few examples of normal picture without any anomalies."
54-
" You have to use these to determine if the image I provide in the next"
55-
" chat is normal or anomalous."
56-
),
57-
)
58-
5940
@staticmethod
6041
def _generate_message(content: str, images: list[str] | None) -> dict:
6142
"""Generate a message."""
62-
message = {"role": "user", "content": content}
43+
message: dict[str, str | list[str]] = {"role": "user", "content": content}
6344
if images:
6445
message["images"] = images
6546
return message
6647

67-
def predict(self, image: str | Path) -> str:
48+
def predict(self, image: str | Path, prompt: Prompt) -> str:
6849
"""Predict the anomaly label."""
6950
if not chat:
7051
msg = "Ollama is not installed. Please install it using `pip install ollama`."
@@ -74,9 +55,9 @@ def predict(self, image: str | Path) -> str:
7455

7556
# few-shot
7657
if len(self._ref_images_encoded) > 0:
77-
messages.append(self._generate_message(content=self.prompt.few_shot, images=self._ref_images_encoded))
58+
messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))
7859

79-
messages.append(self._generate_message(content=self.prompt.predict, images=[image_encoded]))
60+
messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))
8061

8162
response = chat(
8263
model=self.model_name,

0 commit comments

Comments
 (0)