Skip to content
59 changes: 42 additions & 17 deletions src/hyrax/verbs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ class Visualize(Verb):
cli_name = "visualize"
add_parser_kwargs = {}

# Dataset groups that the Visualize verb knows about.
# REQUIRED_SPLITS must be present in the data request configuration.
# OPTIONAL_SPLITS are used when present but do not cause an error if absent.
# Note: "umap" results may also be required in the future.
REQUIRED_SPLITS = ("infer",)
OPTIONAL_SPLITS = ()

@staticmethod
def setup_parser(parser: ArgumentParser):
"""CLI not implemented for this verb"""
Expand Down Expand Up @@ -78,7 +85,8 @@ def run(
from holoviews.streams import Lasso, Params, RangeXY, SelectionXY, Tap
from scipy.spatial import KDTree

from hyrax.data_sets.inference_dataset import InferenceDataSet
from hyrax.data_sets.result_factories import load_results_dataset
from hyrax.pytorch_ignite import setup_dataset

if self.config["data_set"]["object_id_column_name"]:
self.object_id_column_name = self.config["data_set"]["object_id_column_name"]
Expand Down Expand Up @@ -106,10 +114,25 @@ def run(
)

# Get the umap data and put it in a kdtree for indexing.
self.umap_results = InferenceDataSet(self.config, results_dir=input_dir, verb="umap")
logger.info(f"Rendering UMAP from the following directory: {self.umap_results.results_dir}")
self.umap_results = load_results_dataset(self.config, results_dir=input_dir, verb="umap")

# Build a DataProvider from the live config for metadata access.
# This avoids implicit coupling between result datasets and their original data sources.
datasets = setup_dataset(self.config)
if not set(Visualize.REQUIRED_SPLITS).intersection(set(datasets.keys())):
required_keys = ", ".join(sorted(Visualize.REQUIRED_SPLITS))
available_keys = ", ".join(sorted(datasets.keys())) or "<none>"
msg = (
f"Visualize requires dataset entries {required_keys} in the data request configuration "
f"Available dataset keys: {available_keys}"
)
raise RuntimeError(msg)
# NOTE: this presently depends on only a single required
# split, but here is where the data provider logic would be
# extended if needed.
self.metadata_provider = datasets[Visualize.REQUIRED_SPLITS[0]]

available_fields = self.umap_results.metadata_fields()
available_fields = self.metadata_provider.metadata_fields()
for field in fields.copy():
if field not in available_fields:
logger.warning(f"Field {field} is unavailable for this dataset")
Expand Down Expand Up @@ -142,7 +165,7 @@ def run(
if self.color_column:
try:
# Check if column exists
available_fields = self.umap_results.metadata_fields()
available_fields = self.metadata_provider.metadata_fields()
if self.color_column not in available_fields:
logger.warning(
f"Column '{self.color_column}' not found in dataset."
Expand All @@ -154,7 +177,7 @@ def run(
all_indices = list(range(len(self.umap_results)))

# Extract metadata for the specified column
metadata = self.umap_results.metadata(all_indices, [self.color_column])
metadata = self.metadata_provider.metadata(all_indices, [self.color_column])
self.color_values = metadata[self.color_column]
logger.info(f"Successfully loaded color values from column '{self.color_column}'")
import numpy as np
Expand Down Expand Up @@ -308,7 +331,7 @@ def visible_points(self, x_range: Union[tuple, list], y_range: Union[tuple, list

if np.any(np.isinf([x_range, y_range])):
# Show all points without filtering
points = np.array([point.numpy() for point in self.umap_results])
points = np.array([np.asarray(point) for point in self.umap_results])
point_indices = list(range(len(self.umap_results)))
else:
# Use existing filtering logic
Expand Down Expand Up @@ -358,7 +381,7 @@ def update_points(self, **kwargs) -> None:
self.points, self.points_id, self.points_idx = self.poly_select_points(kwargs["geometry"])
elif self._called_tap(kwargs):
_, idx = self.tree.query([kwargs["x"], kwargs["y"]])
self.points = np.array([self.umap_results[idx].numpy()])
self.points = np.array([np.asarray(self.umap_results[idx])])
self.points_id = np.array([list(self.umap_results.ids())[idx]])
self.points_idx = np.array([idx])
elif self._called_box_select(kwargs):
Expand Down Expand Up @@ -419,7 +442,7 @@ def poly_select_points(self, geometry) -> tuple[npt.ArrayLike, npt.ArrayLike, np
# Coarse grain the points within the axis-aligned bounding box of the geometry
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(geometry)
point_indexes_coarse = self.box_select_indexes([xmin, xmax], [ymin, ymax])
points_coarse = self.umap_results[point_indexes_coarse].numpy()
points_coarse = np.asarray(self.umap_results[point_indexes_coarse])

tri = Delaunay(geometry)
mask = tri.find_simplex(points_coarse) != -1
Expand Down Expand Up @@ -456,7 +479,7 @@ def box_select_points(

indexes = self.box_select_indexes(x_range, y_range)
ids = np.array(list(self.umap_results.ids()))[indexes]
points = self.umap_results[indexes].numpy()
points = np.asarray(self.umap_results[indexes])
return points, ids, indexes

def box_select_indexes(self, x_range: Union[tuple, list], y_range: Union[tuple, list]):
Expand Down Expand Up @@ -495,7 +518,7 @@ def _inside_box(pt):
return x > xmin and x < xmax and y > ymin and y < ymax

# Filter for points properly inside the box
return [i for i in indexes if _inside_box(self.umap_results[i].numpy())]
return [i for i in indexes if _inside_box(np.asarray(self.umap_results[i]))]

def selected_objects(self, **kwargs):
"""
Expand Down Expand Up @@ -527,7 +550,7 @@ def _table_from_points(self):

# These are the rest of the columns, pulled from metadata
try:
metadata = self.umap_results.metadata(self.points_idx, self.data_fields)
metadata = self.metadata_provider.metadata(self.points_idx, self.data_fields)
except Exception as e:
# Leave in this try/catch because some notebook implementations don't
# allow us to return an exception to the console.
Expand All @@ -551,8 +574,10 @@ def _bounding_box(points):
return (xmin, xmax, ymin, ymax)

def _even_aspect_bounding_box(self):
import numpy as np

# Bring aspect ratio to 1:1 by expanding the smaller axis range
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(point.numpy() for point in self.umap_results)
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(np.asarray(point) for point in self.umap_results)

x_dim = xmax - xmin
x_center = (xmax + xmin) / 2.0
Expand Down Expand Up @@ -585,7 +610,7 @@ def get_selected_df(self):

df = pd.DataFrame(self.points, columns=["x", "y"])
df[self.object_id_column_name] = self.points_id
meta = self.umap_results.metadata(self.points_idx, self.data_fields)
meta = self.metadata_provider.metadata(self.points_idx, self.data_fields)
meta_df = pd.DataFrame(meta, columns=self.data_fields)

cols = [self.object_id_column_name, "x", "y"] + self.data_fields
Expand Down Expand Up @@ -656,7 +681,7 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
sampled_ids = [id_map[idx] for idx in chosen_idx]

# Get metadata - this is in the same order as chosen_idx
meta = self.umap_results.metadata(
meta = self.metadata_provider.metadata(
chosen_idx, [self.object_id_column_name, self.filename_column_name]
)

Expand Down Expand Up @@ -696,12 +721,12 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
if len(self.torch_tensor_bands) == 1:
# Single-band extraction
band_idx = self.torch_tensor_bands[0]
arr = tensor[band_idx].numpy()
arr = np.asarray(tensor[band_idx])
else:
# RGB extraction (3 bands)
rgb_arrays = []
for band_idx in self.torch_tensor_bands:
rgb_arrays.append(tensor[band_idx].numpy())
rgb_arrays.append(np.asarray(tensor[band_idx]))
# Stack along new axis to create (H, W, 3) RGB array
arr = np.stack(rgb_arrays, axis=-1)
else:
Expand Down