5555
5656import logging
5757import time
58- from collections .abc import Generator , Iterable , Iterator
58+ from collections .abc import Generator , Iterable
5959from concurrent .futures import Executor
6060from pathlib import Path
6161from threading import Thread
6262from typing import Any , Callable , Optional , Union
6363
6464import numpy as np
6565import numpy .typing as npt
66- from astropy .io import fits
67- from astropy .table import Table
68- from torch import Tensor , from_numpy
6966from torch .utils .data import Dataset
70- from torchvision .transforms .v2 import CenterCrop , Compose , Lambda , Transform
7167
7268from hyrax .config_utils import ConfigDict
7369
@@ -99,6 +95,8 @@ def __init__(self, config: ConfigDict):
9995 config : ConfigDict
10096 Nested configuration dictionary for hyrax
10197 """
98+ from torchvision .transforms .v2 import Lambda
99+
102100 self ._config = config
103101
104102 transform_str = config ["data_set" ]["transform" ]
@@ -155,6 +153,9 @@ def _init_from_path(self, path: Union[Path, str]):
155153 Path or string specifying the directory path that is the root of all filenames in the
156154 catalog table
157155 """
156+ from torch import Tensor
157+ from torchvision .transforms .v2 import Compose
158+
158159 self .path = path
159160
160161 # This is common code
@@ -186,7 +187,7 @@ def _init_from_path(self, path: Union[Path, str]):
186187
187188 logger .info (f"FitsImageDataSet has { len (self )} objects" )
188189
189- def _set_crop_transform (self ) -> Transform :
190+ def _set_crop_transform (self ):
190191 """
191192 Returns the crop transform on the image
192193
@@ -196,6 +197,8 @@ def _set_crop_transform(self) -> Transform:
196197
197198 2) Return the crop transform only so it can be added to the transform stack appropriately.
198199 """
200+ from torchvision .transforms .v2 import CenterCrop
201+
199202 self .cutout_shape = self .config ["data_set" ]["crop_to" ] if self .config ["data_set" ]["crop_to" ] else None
200203
201204 if not isinstance (self .cutout_shape , list ) or len (self .cutout_shape ) != 2 :
@@ -205,7 +208,9 @@ def _set_crop_transform(self) -> Transform:
205208
206209 return CenterCrop (size = self .cutout_shape )
207210
208- def _read_filter_catalog (self , filter_catalog_path : Optional [Path ]) -> Optional [Table ]:
211+ def _read_filter_catalog (self , filter_catalog_path : Optional [Path ]):
212+ from astropy .table import Table
213+
209214 if filter_catalog_path is None :
210215 msg = "Must provide a filter catalog in config['data_set']['filter_catalog']"
211216 raise RuntimeError (msg )
@@ -250,7 +255,7 @@ def _read_filter_catalog(self, filter_catalog_path: Optional[Path]) -> Optional[
250255
251256 return table
252257
253- def _parse_filter_catalog (self , table : Optional [ Table ] ) -> None :
258+ def _parse_filter_catalog (self , table ) -> None :
254259 """Sets self.files by parsing the catalog.
255260
256261 Subclasses may override this function to control parsing of the table more directly, but the
@@ -305,7 +310,7 @@ def _before_preload(self) -> None:
305310 # fetching
306311 pass
307312
308- def _prepare_metadata (self ) -> Optional [ Table ] :
313+ def _prepare_metadata (self ):
309314 # This happens when filter_catalog_table is injected in unit tests
310315 if FitsImageDataSet ._called_from_test :
311316 return None
@@ -366,7 +371,7 @@ def __len__(self) -> int:
366371 """
367372 return len (self .files )
368373
369- def __getitem__ (self , idx : int ) -> Tensor :
374+ def __getitem__ (self , idx : int ):
370375 if idx >= len (self .files ) or idx < 0 :
371376 raise IndexError
372377
@@ -528,7 +533,7 @@ def _preload_tensor_cache(self):
528533 self ._log_duration_tensorboard ("preload_1k_obj_s" , start_time )
529534 start_time = time .monotonic_ns ()
530535
531- def _lazy_map_executor (self , executor : Executor , ids : Iterable [str ]) -> Iterator [ Tensor ] :
536+ def _lazy_map_executor (self , executor : Executor , ids : Iterable [str ]):
532537 """This is a version of concurrent.futures.Executor map() which lazily evaluates the iterator passed
533538 We do this because we do not want all of the tensors to remain in memory during pre-loading. We would
534539 prefer a smaller set of in-flight tensors.
@@ -554,9 +559,10 @@ def _lazy_map_executor(self, executor: Executor, ids: Iterable[str]) -> Iterator
554559 Iterator[torch.Tensor]
555560 An iterator over torch tensors, lazily loaded by running the work_fn as needed.
556561 """
557-
558562 from concurrent .futures import FIRST_COMPLETED , Future , wait
559563
564+ from torch import Tensor
565+
560566 max_futures = FitsImageDataSet ._determine_numprocs_preload ()
561567 queue : list [Future [Tensor ]] = []
562568 in_progress : set [Future [Tensor ]] = set ()
@@ -609,15 +615,17 @@ def _log_duration_tensorboard(self, name: str, start_time: int):
609615 duration_s = (now - start_time ) / 1.0e9
610616 self .tensorboardx_logger .add_scalar (name , duration_s , since_tensorboard_start_us )
611617
612- def _check_object_id_to_tensor_cache (self , object_id : str ) -> Optional [ Tensor ] :
618+ def _check_object_id_to_tensor_cache (self , object_id : str ):
613619 return self .tensors .get (object_id , None )
614620
615- def _populate_object_id_to_tensor_cache (self , object_id : str ) -> Tensor :
621+ def _populate_object_id_to_tensor_cache (self , object_id : str ):
616622 data_torch = self ._read_object_id (object_id )
617623 self .tensors [object_id ] = data_torch
618624 return data_torch
619625
620- def _read_object_id (self , object_id : str ) -> Tensor :
626+ def _read_object_id (self , object_id : str ):
627+ from astropy .io import fits
628+
621629 start_time = time .monotonic_ns ()
622630
623631 # Read all the files corresponding to this object
@@ -635,7 +643,9 @@ def _read_object_id(self, object_id: str) -> Tensor:
635643 self ._log_duration_tensorboard ("object_total_read_time_s" , start_time )
636644 return data_torch
637645
638- def _convert_to_torch (self , data : list [npt .ArrayLike ]) -> Tensor :
646+ def _convert_to_torch (self , data : list [npt .ArrayLike ]):
647+ from torch import from_numpy
648+
639649 start_time = time .monotonic_ns ()
640650
641651 # Push all the filter data into a tensor object
@@ -655,7 +665,7 @@ def _convert_to_torch(self, data: list[npt.ArrayLike]) -> Tensor:
655665 # Do we want to memoize them on first __getitem__ call?
656666 #
657667 # For now we just do it the naive way
658- def _object_id_to_tensor (self , object_id : str ) -> Tensor :
668+ def _object_id_to_tensor (self , object_id : str ):
659669 """Converts an object_id to a pytorch tensor with dimensions (self.num_filters, self.cutout_shape[0],
660670 self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the
661671 far corners of the image from (0,0).
0 commit comments