Skip to content

Commit c8c1d8c

Browse files
committed
Reducing import time from ~2.7s to ~0.6s
- Many top-of-file imports moved into functions. - Some type annotations removed. - Added a benchmark test to measure our import time.
1 parent 472794e commit c8c1d8c

File tree

9 files changed

+118
-70
lines changed

9 files changed

+118
-70
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ repos:
5050
- id: xcxc-check
5151
name: Check for note-to-self comments (xcxc)
5252
description: Grep all source files for xcxc which signifies a comment that shouldn't be checked in.
53-
entry: bash -c "[[ $(grep -rniI xcxc --exclude .pre-commit-config.yaml --exclude-dir _readthedocs --exclude-dir htmlcov ./* >&2 ; echo $?) == 1 ]]"
53+
entry: bash -c "[[ $(grep -rniI xcxc --exclude .pre-commit-config.yaml --exclude-dir _readthedocs --exclude-dir htmlcov --exclude-dir _results --exclude-dir env ./* >&2 ; echo $?) == 1 ]]"
5454
language: system
5555
pass_filenames: false
5656
always_run: true

benchmarks/benchmarks.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
For more information on writing benchmarks:
44
https://asv.readthedocs.io/en/stable/writing_benchmarks.html."""
55

6+
import subprocess
7+
68
from hyrax import example_benchmarks
79

810

@@ -14,3 +16,14 @@ def time_computation():
1416
def mem_list():
1517
"""Memory computations are prefixed with 'mem' or 'peakmem'."""
1618
return example_benchmarks.memory_computation()
19+
20+
21+
def time_import():
22+
"""
23+
time how long it takes to import our package. This should stay relatively fast.
24+
25+
Note, the actual import time will be slightly lower than this on a comparable system
26+
However, high import times do affect this metric proportionally.
27+
"""
28+
result = subprocess.run(["python", "-c", "import hyrax"])
29+
assert result.returncode == 0

src/hyrax/data_sets/data_set_registry.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
# ruff: noqa: D102, B027
22
import logging
33
from collections.abc import Generator
4-
from typing import Optional
54

65
import numpy.typing as npt
7-
from astropy.table import Table
8-
from torch.utils.data import Dataset, IterableDataset
96

107
from hyrax.config_utils import ConfigDict
118
from hyrax.plugin_utils import get_or_load_class, update_registry
@@ -49,7 +46,7 @@ def __len__ ():
4946
5047
"""
5148

52-
def __init__(self, config: ConfigDict, metadata_table: Optional[Table] = None):
49+
def __init__(self, config: ConfigDict, metadata_table=None):
5350
"""
5451
.. py:method:: __init__
5552
@@ -117,6 +114,8 @@ def is_iterable(self):
117114
bool
118115
True if underlying dataset is iterable
119116
"""
117+
from torch.utils.data import Dataset, IterableDataset
118+
120119
if isinstance(self, (Dataset, IterableDataset)):
121120
return isinstance(self, IterableDataset)
122121
else:
@@ -132,6 +131,8 @@ def is_map(self):
132131
bool
133132
True if underlying dataset is map-style
134133
"""
134+
from torch.utils.data import Dataset, IterableDataset
135+
135136
if isinstance(self, (Dataset, IterableDataset)):
136137
# All torch IterableDatasets are also Datasets
137138
return not isinstance(self, IterableDataset)

src/hyrax/data_sets/fits_image_dataset.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,15 @@
5555

5656
import logging
5757
import time
58-
from collections.abc import Generator, Iterable, Iterator
58+
from collections.abc import Generator, Iterable
5959
from concurrent.futures import Executor
6060
from pathlib import Path
6161
from threading import Thread
6262
from typing import Any, Callable, Optional, Union
6363

6464
import numpy as np
6565
import numpy.typing as npt
66-
from astropy.io import fits
67-
from astropy.table import Table
68-
from torch import Tensor, from_numpy
6966
from torch.utils.data import Dataset
70-
from torchvision.transforms.v2 import CenterCrop, Compose, Lambda, Transform
7167

7268
from hyrax.config_utils import ConfigDict
7369

@@ -99,6 +95,8 @@ def __init__(self, config: ConfigDict):
9995
config : ConfigDict
10096
Nested configuration dictionary for hyrax
10197
"""
98+
from torchvision.transforms.v2 import Lambda
99+
102100
self._config = config
103101

104102
transform_str = config["data_set"]["transform"]
@@ -155,6 +153,9 @@ def _init_from_path(self, path: Union[Path, str]):
155153
Path or string specifying the directory path that is the root of all filenames in the
156154
catalog table
157155
"""
156+
from torch import Tensor
157+
from torchvision.transforms.v2 import Compose
158+
158159
self.path = path
159160

160161
# This is common code
@@ -186,7 +187,7 @@ def _init_from_path(self, path: Union[Path, str]):
186187

187188
logger.info(f"FitsImageDataSet has {len(self)} objects")
188189

189-
def _set_crop_transform(self) -> Transform:
190+
def _set_crop_transform(self):
190191
"""
191192
Returns the crop transform on the image
192193
@@ -196,6 +197,8 @@ def _set_crop_transform(self) -> Transform:
196197
197198
2) Return the crop transform only so it can be added to the transform stack appropriately.
198199
"""
200+
from torchvision.transforms.v2 import CenterCrop
201+
199202
self.cutout_shape = self.config["data_set"]["crop_to"] if self.config["data_set"]["crop_to"] else None
200203

201204
if not isinstance(self.cutout_shape, list) or len(self.cutout_shape) != 2:
@@ -205,7 +208,9 @@ def _set_crop_transform(self) -> Transform:
205208

206209
return CenterCrop(size=self.cutout_shape)
207210

208-
def _read_filter_catalog(self, filter_catalog_path: Optional[Path]) -> Optional[Table]:
211+
def _read_filter_catalog(self, filter_catalog_path: Optional[Path]):
212+
from astropy.table import Table
213+
209214
if filter_catalog_path is None:
210215
msg = "Must provide a filter catalog in config['data_set']['filter_catalog']"
211216
raise RuntimeError(msg)
@@ -250,7 +255,7 @@ def _read_filter_catalog(self, filter_catalog_path: Optional[Path]) -> Optional[
250255

251256
return table
252257

253-
def _parse_filter_catalog(self, table: Optional[Table]) -> None:
258+
def _parse_filter_catalog(self, table) -> None:
254259
"""Sets self.files by parsing the catalog.
255260
256261
Subclasses may override this function to control parsing of the table more directly, but the
@@ -305,7 +310,7 @@ def _before_preload(self) -> None:
305310
# fetching
306311
pass
307312

308-
def _prepare_metadata(self) -> Optional[Table]:
313+
def _prepare_metadata(self):
309314
# This happens when filter_catalog_table is injected in unit tests
310315
if FitsImageDataSet._called_from_test:
311316
return None
@@ -366,7 +371,7 @@ def __len__(self) -> int:
366371
"""
367372
return len(self.files)
368373

369-
def __getitem__(self, idx: int) -> Tensor:
374+
def __getitem__(self, idx: int):
370375
if idx >= len(self.files) or idx < 0:
371376
raise IndexError
372377

@@ -528,7 +533,7 @@ def _preload_tensor_cache(self):
528533
self._log_duration_tensorboard("preload_1k_obj_s", start_time)
529534
start_time = time.monotonic_ns()
530535

531-
def _lazy_map_executor(self, executor: Executor, ids: Iterable[str]) -> Iterator[Tensor]:
536+
def _lazy_map_executor(self, executor: Executor, ids: Iterable[str]):
532537
"""This is a version of concurrent.futures.Executor map() which lazily evaluates the iterator passed
533538
We do this because we do not want all of the tensors to remain in memory during pre-loading. We would
534539
prefer a smaller set of in-flight tensors.
@@ -554,9 +559,10 @@ def _lazy_map_executor(self, executor: Executor, ids: Iterable[str]) -> Iterator
554559
Iterator[torch.Tensor]
555560
An iterator over torch tensors, lazily loaded by running the work_fn as needed.
556561
"""
557-
558562
from concurrent.futures import FIRST_COMPLETED, Future, wait
559563

564+
from torch import Tensor
565+
560566
max_futures = FitsImageDataSet._determine_numprocs_preload()
561567
queue: list[Future[Tensor]] = []
562568
in_progress: set[Future[Tensor]] = set()
@@ -609,15 +615,17 @@ def _log_duration_tensorboard(self, name: str, start_time: int):
609615
duration_s = (now - start_time) / 1.0e9
610616
self.tensorboardx_logger.add_scalar(name, duration_s, since_tensorboard_start_us)
611617

612-
def _check_object_id_to_tensor_cache(self, object_id: str) -> Optional[Tensor]:
618+
def _check_object_id_to_tensor_cache(self, object_id: str):
613619
return self.tensors.get(object_id, None)
614620

615-
def _populate_object_id_to_tensor_cache(self, object_id: str) -> Tensor:
621+
def _populate_object_id_to_tensor_cache(self, object_id: str):
616622
data_torch = self._read_object_id(object_id)
617623
self.tensors[object_id] = data_torch
618624
return data_torch
619625

620-
def _read_object_id(self, object_id: str) -> Tensor:
626+
def _read_object_id(self, object_id: str):
627+
from astropy.io import fits
628+
621629
start_time = time.monotonic_ns()
622630

623631
# Read all the files corresponding to this object
@@ -635,7 +643,9 @@ def _read_object_id(self, object_id: str) -> Tensor:
635643
self._log_duration_tensorboard("object_total_read_time_s", start_time)
636644
return data_torch
637645

638-
def _convert_to_torch(self, data: list[npt.ArrayLike]) -> Tensor:
646+
def _convert_to_torch(self, data: list[npt.ArrayLike]):
647+
from torch import from_numpy
648+
639649
start_time = time.monotonic_ns()
640650

641651
# Push all the filter data into a tensor object
@@ -655,7 +665,7 @@ def _convert_to_torch(self, data: list[npt.ArrayLike]) -> Tensor:
655665
# Do we want to memoize them on first __getitem__ call?
656666
#
657667
# For now we just do it the naive way
658-
def _object_id_to_tensor(self, object_id: str) -> Tensor:
668+
def _object_id_to_tensor(self, object_id: str):
659669
"""Converts an object_id to a pytorch tensor with dimenstions (self.num_filters, self.cutout_shape[0],
660670
self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the
661671
far corners of the image from (0,0).

src/hyrax/data_sets/hsc_data_set.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,9 @@
1010
from typing import Optional
1111

1212
import numpy as np
13-
from astropy.io import fits
14-
from astropy.table import Table
1513
from schwimmbad import MultiPool
16-
from torchvision.transforms.v2 import CenterCrop
1714

1815
from hyrax.config_utils import ConfigDict
19-
from hyrax.download import Downloader
20-
from hyrax.downloadCutout.downloadCutout import (
21-
parse_bool,
22-
parse_degree,
23-
parse_latitude,
24-
parse_longitude,
25-
parse_rerun,
26-
parse_tract_opt,
27-
parse_type,
28-
)
2916

3017
from .fits_image_dataset import FitsImageDataSet, files_dict
3118

@@ -43,6 +30,8 @@ def __init__(self, config: ConfigDict):
4330
.. py:method:: __init__
4431
4532
"""
33+
from hyrax.download import Downloader
34+
4635
# Note "rebuild_manifest" is not a config, its a hack for rebuild_manifest mode
4736
# to ensure we don't use the manifest we believe is corrupt.
4837
rebuild_manifest = config["rebuild_manifest"] if "rebuild_manifest" in config else False # noqa: SIM401
@@ -61,7 +50,9 @@ def __init__(self, config: ConfigDict):
6150

6251
super().__init__(config)
6352

64-
def _read_filter_catalog(self, filter_catalog_path: Optional[Path]) -> Optional[Table]:
53+
def _read_filter_catalog(self, filter_catalog_path: Optional[Path]):
54+
from astropy.table import Table
55+
6556
try:
6657
retval = super()._read_filter_catalog(filter_catalog_path)
6758
except RuntimeError:
@@ -84,7 +75,7 @@ def _read_filter_catalog(self, filter_catalog_path: Optional[Path]) -> Optional[
8475
#
8576
# In the HSC case this will also have to do fallback and call
8677
# _scan_file_dimensions() and/or _scan_file_names() and pass back only the files dict.
87-
def _parse_filter_catalog(self, table: Table) -> None:
78+
def _parse_filter_catalog(self, table) -> None:
8879
object_id_missing = self.object_id_column_name not in table.colnames if table is not None else True
8980
filter_missing = self.filter_column_name not in table.colnames if table is not None else True
9081
filename_missing = self.filename_column_name not in table.colnames if table is not None else True
@@ -137,6 +128,8 @@ def _parse_filter_catalog(self, table: Table) -> None:
137128
return self.files
138129

139130
def _set_crop_transform(self):
131+
from torchvision.transforms.v2 import CenterCrop
132+
140133
cutout_shape = self.config["data_set"]["crop_to"] if self.config["data_set"]["crop_to"] else None
141134
self.cutout_shape = self._check_file_dimensions() if cutout_shape is None else cutout_shape
142135
return CenterCrop(size=self.cutout_shape)
@@ -285,6 +278,8 @@ def _scan_file_dimension(processing_unit: tuple[str, list[str]]) -> tuple[str, l
285278

286279
@staticmethod
287280
def _fits_file_dims(filepath) -> tuple[int, int]:
281+
from astropy.io import fits
282+
288283
try:
289284
with fits.open(filepath) as hdul:
290285
return (hdul[1].shape[0], hdul[1].shape[1])
@@ -439,6 +434,19 @@ def _check_file_dimensions(self) -> tuple[int, int]:
439434
return cutout_width, cutout_height
440435

441436
def _rebuild_manifest(self, config):
437+
from astropy.table import Table
438+
439+
from hyrax.download import Downloader
440+
from hyrax.downloadCutout.downloadCutout import (
441+
parse_bool,
442+
parse_degree,
443+
parse_latitude,
444+
parse_longitude,
445+
parse_rerun,
446+
parse_tract_opt,
447+
parse_type,
448+
)
449+
442450
if self.filter_catalog:
443451
raise RuntimeError("Cannot rebuild manifest. Set the filter_catalog=false and rerun")
444452

src/hyrax/data_sets/hyrax_cifar_data_set.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
import logging
33

44
import numpy as np
5-
import torchvision.transforms as transforms
6-
from astropy.table import Table
75
from torch.utils.data import Dataset, IterableDataset
8-
from torchvision.datasets import CIFAR10
96

107
from hyrax.config_utils import ConfigDict
118

@@ -18,6 +15,10 @@ class HyraxCifarBase:
1815
"""Base class for Hyrax Cifar datasets"""
1916

2017
def __init__(self, config: ConfigDict):
18+
import torchvision.transforms as transforms
19+
from astropy.table import Table
20+
from torchvision.datasets import CIFAR10
21+
2122
transform = transforms.Compose(
2223
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
2324
)

src/hyrax/data_sets/inference_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
import numpy.typing as npt
9-
from torch import Tensor, from_numpy
109
from torch.utils.data import Dataset
1110

1211
from hyrax.config_utils import find_most_recent_results_dir
@@ -98,7 +97,9 @@ def ids(self) -> Generator[str]:
9897
"""
9998
return (str(id) for id in self.batch_index["id"])
10099

101-
def __getitem__(self, idx: Union[int, np.ndarray]) -> Tensor:
100+
def __getitem__(self, idx: Union[int, np.ndarray]):
101+
from torch import from_numpy
102+
102103
try:
103104
_ = (e for e in idx) # type: ignore[union-attr]
104105
except TypeError:

0 commit comments

Comments
 (0)