@@ -20,6 +20,13 @@ class Visualize(Verb):
2020 cli_name = "visualize"
2121 add_parser_kwargs = {}
2222
23+ # Dataset groups that the Visualize verb knows about.
24+ # REQUIRED_SPLITS must be present in the data request configuration.
25+ # OPTIONAL_SPLITS are used when present but do not cause an error if absent.
26+ # Note: "umap" results may also be required in the future.
27+ REQUIRED_SPLITS = ("infer" ,)
28+ OPTIONAL_SPLITS = ()
29+
2330 @staticmethod
2431 def setup_parser (parser : ArgumentParser ):
2532 """CLI not implemented for this verb"""
@@ -78,7 +85,8 @@ def run(
7885 from holoviews .streams import Lasso , Params , RangeXY , SelectionXY , Tap
7986 from scipy .spatial import KDTree
8087
81- from hyrax .data_sets .inference_dataset import InferenceDataSet
88+ from hyrax .data_sets .result_factories import load_results_dataset
89+ from hyrax .pytorch_ignite import setup_dataset
8290
8391 if self .config ["data_set" ]["object_id_column_name" ]:
8492 self .object_id_column_name = self .config ["data_set" ]["object_id_column_name" ]
@@ -106,10 +114,25 @@ def run(
106114 )
107115
108116 # Get the umap data and put it in a kdtree for indexing.
109- self .umap_results = InferenceDataSet (self .config , results_dir = input_dir , verb = "umap" )
110- logger .info (f"Rendering UMAP from the following directory: { self .umap_results .results_dir } " )
117+ self .umap_results = load_results_dataset (self .config , results_dir = input_dir , verb = "umap" )
118+
119+ # Build a DataProvider from the live config for metadata access.
120+ # This avoids implicit coupling between result datasets and their original data sources.
121+ datasets = setup_dataset (self .config )
122+ if not set (Visualize .REQUIRED_SPLITS ).intersection (set (datasets .keys ())):
123+ required_keys = ", " .join (sorted (Visualize .REQUIRED_SPLITS ))
124+ available_keys = ", " .join (sorted (datasets .keys ())) or "<none>"
125+ msg = (
126+ f"Visualize requires dataset entries { required_keys } in the data request configuration "
127+ f"Available dataset keys: { available_keys } "
128+ )
129+ raise RuntimeError (msg )
130+ # NOTE: this presently depends on only a single required
131+ # split, but this is where the data provider logic would be
132+ # extended if needed.
133+ self .metadata_provider = datasets [Visualize .REQUIRED_SPLITS [0 ]]
111134
112- available_fields = self .umap_results .metadata_fields ()
135+ available_fields = self .metadata_provider .metadata_fields ()
113136 for field in fields .copy ():
114137 if field not in available_fields :
115138 logger .warning (f"Field { field } is unavailable for this dataset" )
@@ -142,7 +165,7 @@ def run(
142165 if self .color_column :
143166 try :
144167 # Check if column exists
145- available_fields = self .umap_results .metadata_fields ()
168+ available_fields = self .metadata_provider .metadata_fields ()
146169 if self .color_column not in available_fields :
147170 logger .warning (
148171 f"Column '{ self .color_column } ' not found in dataset."
@@ -154,7 +177,7 @@ def run(
154177 all_indices = list (range (len (self .umap_results )))
155178
156179 # Extract metadata for the specified column
157- metadata = self .umap_results .metadata (all_indices , [self .color_column ])
180+ metadata = self .metadata_provider .metadata (all_indices , [self .color_column ])
158181 self .color_values = metadata [self .color_column ]
159182 logger .info (f"Successfully loaded color values from column '{ self .color_column } '" )
160183 import numpy as np
@@ -308,7 +331,7 @@ def visible_points(self, x_range: Union[tuple, list], y_range: Union[tuple, list
308331
309332 if np .any (np .isinf ([x_range , y_range ])):
310333 # Show all points without filtering
311- points = np .array ([point . numpy ( ) for point in self .umap_results ])
334+ points = np .array ([np . asarray ( point ) for point in self .umap_results ])
312335 point_indices = list (range (len (self .umap_results )))
313336 else :
314337 # Use existing filtering logic
@@ -358,7 +381,7 @@ def update_points(self, **kwargs) -> None:
358381 self .points , self .points_id , self .points_idx = self .poly_select_points (kwargs ["geometry" ])
359382 elif self ._called_tap (kwargs ):
360383 _ , idx = self .tree .query ([kwargs ["x" ], kwargs ["y" ]])
361- self .points = np .array ([self .umap_results [idx ]. numpy ( )])
384+ self .points = np .array ([np . asarray ( self .umap_results [idx ])])
362385 self .points_id = np .array ([list (self .umap_results .ids ())[idx ]])
363386 self .points_idx = np .array ([idx ])
364387 elif self ._called_box_select (kwargs ):
@@ -419,7 +442,7 @@ def poly_select_points(self, geometry) -> tuple[npt.ArrayLike, npt.ArrayLike, np
419442 # Coarse grain the points within the axis-aligned bounding box of the geometry
420443 (xmin , xmax , ymin , ymax ) = Visualize ._bounding_box (geometry )
421444 point_indexes_coarse = self .box_select_indexes ([xmin , xmax ], [ymin , ymax ])
422- points_coarse = self .umap_results [point_indexes_coarse ]. numpy ( )
445+ points_coarse = np . asarray ( self .umap_results [point_indexes_coarse ])
423446
424447 tri = Delaunay (geometry )
425448 mask = tri .find_simplex (points_coarse ) != - 1
@@ -456,7 +479,7 @@ def box_select_points(
456479
457480 indexes = self .box_select_indexes (x_range , y_range )
458481 ids = np .array (list (self .umap_results .ids ()))[indexes ]
459- points = self .umap_results [indexes ]. numpy ( )
482+ points = np . asarray ( self .umap_results [indexes ])
460483 return points , ids , indexes
461484
462485 def box_select_indexes (self , x_range : Union [tuple , list ], y_range : Union [tuple , list ]):
@@ -495,7 +518,7 @@ def _inside_box(pt):
495518 return x > xmin and x < xmax and y > ymin and y < ymax
496519
497520 # Filter for points properly inside the box
498- return [i for i in indexes if _inside_box (self .umap_results [i ]. numpy ( ))]
521+ return [i for i in indexes if _inside_box (np . asarray ( self .umap_results [i ]))]
499522
500523 def selected_objects (self , ** kwargs ):
501524 """
@@ -527,7 +550,7 @@ def _table_from_points(self):
527550
528551 # These are the rest of the columns, pulled from metadata
529552 try :
530- metadata = self .umap_results .metadata (self .points_idx , self .data_fields )
553+ metadata = self .metadata_provider .metadata (self .points_idx , self .data_fields )
531554 except Exception as e :
532555 # Leave in this try/catch because some notebook implementations don't
533556 # allow us to return an exception to the console.
@@ -551,8 +574,10 @@ def _bounding_box(points):
551574 return (xmin , xmax , ymin , ymax )
552575
553576 def _even_aspect_bounding_box (self ):
577+ import numpy as np
578+
554579 # Bring aspect ratio to 1:1 by expanding the smaller axis range
555- (xmin , xmax , ymin , ymax ) = Visualize ._bounding_box (point . numpy ( ) for point in self .umap_results )
580+ (xmin , xmax , ymin , ymax ) = Visualize ._bounding_box (np . asarray ( point ) for point in self .umap_results )
556581
557582 x_dim = xmax - xmin
558583 x_center = (xmax + xmin ) / 2.0
@@ -585,7 +610,7 @@ def get_selected_df(self):
585610
586611 df = pd .DataFrame (self .points , columns = ["x" , "y" ])
587612 df [self .object_id_column_name ] = self .points_id
588- meta = self .umap_results .metadata (self .points_idx , self .data_fields )
613+ meta = self .metadata_provider .metadata (self .points_idx , self .data_fields )
589614 meta_df = pd .DataFrame (meta , columns = self .data_fields )
590615
591616 cols = [self .object_id_column_name , "x" , "y" ] + self .data_fields
@@ -656,7 +681,7 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
656681 sampled_ids = [id_map [idx ] for idx in chosen_idx ]
657682
658683 # Get metadata - this is in the same order as chosen_idx
659- meta = self .umap_results .metadata (
684+ meta = self .metadata_provider .metadata (
660685 chosen_idx , [self .object_id_column_name , self .filename_column_name ]
661686 )
662687
@@ -696,12 +721,12 @@ def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
696721 if len (self .torch_tensor_bands ) == 1 :
697722 # Single-band extraction
698723 band_idx = self .torch_tensor_bands [0 ]
699- arr = tensor [band_idx ]. numpy ( )
724+ arr = np . asarray ( tensor [band_idx ])
700725 else :
701726 # RGB extraction (3 bands)
702727 rgb_arrays = []
703728 for band_idx in self .torch_tensor_bands :
704- rgb_arrays .append (tensor [band_idx ]. numpy ( ))
729+ rgb_arrays .append (np . asarray ( tensor [band_idx ]))
705730 # Stack along new axis to create (H, W, 3) RGB array
706731 arr = np .stack (rgb_arrays , axis = - 1 )
707732 else :
0 commit comments