Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from anomalib.utils.normalization import NormalizationMethod
from anomalib.utils.path import create_versioned_dir
from anomalib.utils.types import NORMALIZATION, THRESHOLD
from anomalib.utils.visualization import ImageVisualizer
from anomalib.utils.visualization import BaseVisualizer, ExplanationVisualizer, ImageVisualizer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -432,9 +432,15 @@ def _setup_anomalib_callbacks(self) -> None:
_callbacks.append(_ThresholdCallback(self.threshold))
_callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names))

visualizer: BaseVisualizer
if self.task == TaskType.EXPLANATION:
visualizer = ExplanationVisualizer()
else:
visualizer = ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE)

_callbacks.append(
_VisualizationCallback(
visualizers=ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE),
visualizers=visualizer,
save=True,
root=self._cache.args["default_root_dir"] / "images",
),
Expand Down
3 changes: 1 addition & 2 deletions src/anomalib/models/image/vlm_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
"""Validation step."""
del args, kwargs # These variables are not used.
responses = [(self.vlm_backend.predict(img_path, self.prompt)) for img_path in batch["image_path"]]

batch["str_output"] = responses
batch["explanation"] = responses
batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
return batch

Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/utils/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# SPDX-License-Identifier: Apache-2.0

from .base import BaseVisualizer, GeneratorResult, VisualizationStep
from .explanation import ExplanationVisualizer
from .image import ImageResult, ImageVisualizer
from .metrics import MetricsVisualizer

__all__ = [
"BaseVisualizer",
"ExplanationVisualizer",
"ImageResult",
"ImageVisualizer",
"GeneratorResult",
Expand Down
106 changes: 106 additions & 0 deletions src/anomalib/utils/visualization/explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Explanation visualization generator.

Note: This is a temporary visualizer, and will be replaced with the new visualizer in the future.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterator
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from .base import BaseVisualizer, GeneratorResult, VisualizationStep


class ExplanationVisualizer(BaseVisualizer):
"""Explanation visualization generator."""

def __init__(self) -> None:
super().__init__(visualize_on=VisualizationStep.BATCH)
self.padding = 3
self.font = ImageFont.load_default(size=16)

def generate(self, **kwargs) -> Iterator[GeneratorResult]:
"""Generate images and return them as an iterator."""
outputs = kwargs.get("outputs", None)
if outputs is None:
msg = "Outputs must be provided to generate images."
raise ValueError(msg)
return self._visualize_batch(outputs)

def _visualize_batch(self, batch: dict) -> Iterator[GeneratorResult]:
"""Visualize batch of images."""
batch_size = batch["image"].shape[0]
height, width = batch["image"].shape[-2:]
for i in range(batch_size):
image = batch["image"][i]
explanation = batch["explanation"][i]
file_name = Path(batch["image_path"][i])
image = Image.open(file_name)
image = image.resize((width, height))
image = self._draw_image(width, height, image=image, explanation=explanation)
yield GeneratorResult(image=image, file_name=file_name)

def _draw_image(self, width: int, height: int, image: Image, explanation: str) -> np.ndarray:
text_canvas: Image = self._get_explanation_image(width, height, image, explanation)
label_canvas: Image = self._get_label_image(explanation)

final_width = max(text_canvas.size[0], width)
final_height = height + text_canvas.size[1]
combined_image = Image.new("RGB", (final_width, final_height), (255, 255, 255))
combined_image.paste(image, (self.padding, 0))
combined_image.paste(label_canvas, (10, 10))
combined_image.paste(text_canvas, (0, height))
return np.array(combined_image)

def _get_label_image(self, explanation: str) -> Image:
# Draw label
# Can't use pred_labels as it is computed from the pred_scores using image_threshold. It gives incorrect value.
# So, using explanation. This will probably change with the new design.
label = "Anomalous" if explanation.startswith("Y") else "Normal"
label_color = "red" if label == "Anomalous" else "green"
label_canvas = Image.new("RGB", (100, 20), color=label_color)
draw = ImageDraw.Draw(label_canvas)
draw.text((0, 0), label, font=self.font, fill="white", align="center")
return label_canvas

def _get_explanation_image(self, width: int, height: int, image: Image, explanation: str) -> Image:
# compute wrap width
text_canvas = Image.new("RGB", (width, height), color="white")
dummy_image = ImageDraw.Draw(image)
text_bbox = dummy_image.textbbox((0, 0), explanation, font=self.font, align="center")
text_canvas_width = text_bbox[2] - text_bbox[0] + self.padding

# split lines based on the width
lines = list(explanation.split("\n"))
line_with_max_len = max(lines, key=len)
new_width = int(width * len(line_with_max_len) // text_canvas_width)

# wrap text based on the new width
lines = []
current_line: list[str] = []
for word in explanation.split(" "):
test_line = " ".join([*current_line, word])
if len(test_line) <= new_width:
current_line.append(word)
else:
lines.append(" ".join(current_line))
current_line = [word]
lines.append(" ".join(current_line))
wrapped_lines = "\n".join(lines)

# recompute height
dummy_image = Image.new("RGB", (new_width, height), color="white")
draw = ImageDraw.Draw(dummy_image)
text_bbox = draw.textbbox((0, 0), wrapped_lines, font=self.font, align="center")
new_width = int(text_bbox[2] - text_bbox[0] + self.padding)
new_height = int(text_bbox[3] - text_bbox[1] + self.padding)

# Final text image
text_canvas = Image.new("RGB", (new_width, new_height), color="white")
draw = ImageDraw.Draw(text_canvas)
draw.text((self.padding // 2, 0), wrapped_lines, font=self.font, fill="black", align="center")
return text_canvas