Skip to content

Commit 7593023

Browse files
gitosaurus and Copilot
authored and committed
visualize verb should find all its inputs from config. (#713)
* `visualize` verb should find all its inputs from config. Rather that assuming that the dataset provides its own breadcrumbs to its input, use the session-wide config to locate that data and its necessary metadata. * Insist that "infer" be in the data request * Use setup_dataset from pytorch_ignite in visualize verb * Replace .numpy() calls with np.asarray() at interface boundaries in visualize.py * Add REQUIRED_SPLITS and OPTIONAL_SPLITS to Visualize verb --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: gitosaurus <6794831+gitosaurus@users.noreply.github.com>
1 parent f1ad1b7 commit 7593023

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

src/hyrax/verbs/visualize.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ class Visualize(Verb):
2020
cli_name = "visualize"
2121
add_parser_kwargs = {}
2222

23+
# Dataset groups that the Visualize verb knows about.
24+
# REQUIRED_SPLITS must be present in the data request configuration.
25+
# OPTIONAL_SPLITS are used when present but do not cause an error if absent.
26+
# Note: "umap" results may also be required in the future.
27+
REQUIRED_SPLITS = ("infer",)
28+
OPTIONAL_SPLITS = ()
29+
2330
@staticmethod
2431
def setup_parser(parser: ArgumentParser):
2532
"""CLI not implemented for this verb"""
@@ -78,7 +85,8 @@ def run(
7885
from holoviews.streams import Lasso, Params, RangeXY, SelectionXY, Tap
7986
from scipy.spatial import KDTree
8087

81-
from hyrax.data_sets.inference_dataset import InferenceDataSet
88+
from hyrax.data_sets.result_factories import load_results_dataset
89+
from hyrax.pytorch_ignite import setup_dataset
8290

8391
if self.config["data_set"]["object_id_column_name"]:
8492
self.object_id_column_name = self.config["data_set"]["object_id_column_name"]
@@ -106,10 +114,25 @@ def run(
106114
)
107115

108116
# Get the umap data and put it in a kdtree for indexing.
109-
self.umap_results = InferenceDataSet(self.config, results_dir=input_dir, verb="umap")
110-
logger.info(f"Rendering UMAP from the following directory: {self.umap_results.results_dir}")
117+
self.umap_results = load_results_dataset(self.config, results_dir=input_dir, verb="umap")
118+
119+
# Build a DataProvider from the live config for metadata access.
120+
# This avoids implicit coupling between result datasets and their original data sources.
121+
datasets = setup_dataset(self.config)
122+
if not set(Visualize.REQUIRED_SPLITS).intersection(set(datasets.keys())):
123+
required_keys = ", ".join(sorted(Visualize.REQUIRED_SPLITS))
124+
available_keys = ", ".join(sorted(datasets.keys())) or "<none>"
125+
msg = (
126+
f"Visualize requires dataset entries {required_keys} in the data request configuration "
127+
f"Available dataset keys: {available_keys}"
128+
)
129+
raise RuntimeError(msg)
130+
# NOTE: this presently depends on only a single required
131+
# split, but here is where the data provider logic would be
132+
# extended if needed.
133+
self.metadata_provider = datasets[Visualize.REQUIRED_SPLITS[0]]
111134

112-
available_fields = self.umap_results.metadata_fields()
135+
available_fields = self.metadata_provider.metadata_fields()
113136
for field in fields.copy():
114137
if field not in available_fields:
115138
logger.warning(f"Field {field} is unavailable for this dataset")
@@ -142,7 +165,7 @@ def run(
142165
if self.color_column:
143166
try:
144167
# Check if column exists
145-
available_fields = self.umap_results.metadata_fields()
168+
available_fields = self.metadata_provider.metadata_fields()
146169
if self.color_column not in available_fields:
147170
logger.warning(
148171
f"Column '{self.color_column}' not found in dataset."
@@ -154,7 +177,7 @@ def run(
154177
all_indices = list(range(len(self.umap_results)))
155178

156179
# Extract metadata for the specified column
157-
metadata = self.umap_results.metadata(all_indices, [self.color_column])
180+
metadata = self.metadata_provider.metadata(all_indices, [self.color_column])
158181
self.color_values = metadata[self.color_column]
159182
logger.info(f"Successfully loaded color values from column '{self.color_column}'")
160183
import numpy as np
@@ -308,7 +331,7 @@ def visible_points(self, x_range: Union[tuple, list], y_range: Union[tuple, list
308331

309332
if np.any(np.isinf([x_range, y_range])):
310333
# Show all points without filtering
311-
points = np.array([point.numpy() for point in self.umap_results])
334+
points = np.array([np.asarray(point) for point in self.umap_results])
312335
point_indices = list(range(len(self.umap_results)))
313336
else:
314337
# Use existing filtering logic
@@ -358,7 +381,7 @@ def update_points(self, **kwargs) -> None:
358381
self.points, self.points_id, self.points_idx = self.poly_select_points(kwargs["geometry"])
359382
elif self._called_tap(kwargs):
360383
_, idx = self.tree.query([kwargs["x"], kwargs["y"]])
361-
self.points = np.array([self.umap_results[idx].numpy()])
384+
self.points = np.array([np.asarray(self.umap_results[idx])])
362385
self.points_id = np.array([list(self.umap_results.ids())[idx]])
363386
self.points_idx = np.array([idx])
364387
elif self._called_box_select(kwargs):
@@ -419,7 +442,7 @@ def poly_select_points(self, geometry) -> tuple[npt.ArrayLike, npt.ArrayLike, np
419442
# Coarse grain the points within the axis-aligned bounding box of the geometry
420443
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(geometry)
421444
point_indexes_coarse = self.box_select_indexes([xmin, xmax], [ymin, ymax])
422-
points_coarse = self.umap_results[point_indexes_coarse].numpy()
445+
points_coarse = np.asarray(self.umap_results[point_indexes_coarse])
423446

424447
tri = Delaunay(geometry)
425448
mask = tri.find_simplex(points_coarse) != -1
@@ -456,7 +479,7 @@ def box_select_points(
456479

457480
indexes = self.box_select_indexes(x_range, y_range)
458481
ids = np.array(list(self.umap_results.ids()))[indexes]
459-
points = self.umap_results[indexes].numpy()
482+
points = np.asarray(self.umap_results[indexes])
460483
return points, ids, indexes
461484

462485
def box_select_indexes(self, x_range: Union[tuple, list], y_range: Union[tuple, list]):
@@ -495,7 +518,7 @@ def _inside_box(pt):
495518
return x > xmin and x < xmax and y > ymin and y < ymax
496519

497520
# Filter for points properly inside the box
498-
return [i for i in indexes if _inside_box(self.umap_results[i].numpy())]
521+
return [i for i in indexes if _inside_box(np.asarray(self.umap_results[i]))]
499522

500523
def selected_objects(self, **kwargs):
501524
"""
@@ -527,7 +550,7 @@ def _table_from_points(self):
527550

528551
# These are the rest of the columns, pulled from metadata
529552
try:
530-
metadata = self.umap_results.metadata(self.points_idx, self.data_fields)
553+
metadata = self.metadata_provider.metadata(self.points_idx, self.data_fields)
531554
except Exception as e:
532555
# Leave in this try/catch beause some notebook implementations dont
533556
# allow us to return an exception to the console.
@@ -551,8 +574,10 @@ def _bounding_box(points):
551574
return (xmin, xmax, ymin, ymax)
552575

553576
def _even_aspect_bounding_box(self):
577+
import numpy as np
578+
554579
# Bring aspect ratio to 1:1 by expanding the smaller axis range
555-
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(point.numpy() for point in self.umap_results)
580+
(xmin, xmax, ymin, ymax) = Visualize._bounding_box(np.asarray(point) for point in self.umap_results)
556581

557582
x_dim = xmax - xmin
558583
x_center = (xmax + xmin) / 2.0
@@ -585,7 +610,7 @@ def get_selected_df(self):
585610

586611
df = pd.DataFrame(self.points, columns=["x", "y"])
587612
df[self.object_id_column_name] = self.points_id
588-
meta = self.umap_results.metadata(self.points_idx, self.data_fields)
613+
meta = self.metadata_provider.metadata(self.points_idx, self.data_fields)
589614
meta_df = pd.DataFrame(meta, columns=self.data_fields)
590615

591616
cols = [self.object_id_column_name, "x", "y"] + self.data_fields
@@ -656,7 +681,7 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
656681
sampled_ids = [id_map[idx] for idx in chosen_idx]
657682

658683
# Get metadata - this is in the same order as chosen_idx
659-
meta = self.umap_results.metadata(
684+
meta = self.metadata_provider.metadata(
660685
chosen_idx, [self.object_id_column_name, self.filename_column_name]
661686
)
662687

@@ -696,12 +721,12 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
696721
if len(self.torch_tensor_bands) == 1:
697722
# Single-band extraction
698723
band_idx = self.torch_tensor_bands[0]
699-
arr = tensor[band_idx].numpy()
724+
arr = np.asarray(tensor[band_idx])
700725
else:
701726
# RGB extraction (3 bands)
702727
rgb_arrays = []
703728
for band_idx in self.torch_tensor_bands:
704-
rgb_arrays.append(tensor[band_idx].numpy())
729+
rgb_arrays.append(np.asarray(tensor[band_idx]))
705730
# Stack along new axis to create (H, W, 3) RGB array
706731
arr = np.stack(rgb_arrays, axis=-1)
707732
else:

0 commit comments

Comments
 (0)