
Commit d5a8d7f

Iterable dataset support (#273)

- Adds an example CIFAR iterable dataset
- Changes pytorch_ignite.py to work around an ignite bug with iterable datasets (pytorch/ignite#3372)
- New is_iterable() and is_map() interface on the dataset base class to unify discernment logic
- Abstract base classes that derive from HyraxDataset are no longer themselves checked for required methods
- Documentation added for iterable external datasets
- End-to-end tests for iterable datasets

1 parent fee6355 commit d5a8d7f

File tree: 9 files changed (+175, -38 lines)


docs/external_libraries.rst

Lines changed: 23 additions & 4 deletions

@@ -75,13 +75,19 @@ items in the batch. This loss is logged to MLflow and tensorboard.
 Defining a dataset class
 ------------------------
 
-Dataset classes are written as subclasses of both ``hyrax.data_sets.HyraxDataset`` and
-``torch.utils.data.Dataset``. Datasets must minimally define the methods below. These are similar in form to
-Torch's `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_
+Dataset classes are written as subclasses of ``hyrax.data_sets.HyraxDataset``. Datasets must choose to be
+either "map style" and also inherit from ``torch.utils.data.Dataset``, or "iterable" and inherit from
+``torch.utils.data.IterableDataset``. `Look here <https://pytorch.org/docs/stable/data.html#dataset-types>`_
+for an overview of the difference between map-style and iterable datasets.
 
-A fully worked example of creating a custom dataset class is in the example notebook
+A fully worked example of creating a custom map-style dataset class is in the example notebook
 :doc:`/pre_executed/custom_dataset`
 
+The required methods are detailed by category below.
+
+All datasets
+............
+
 ``__init__(self, config)``
 .................................
 On creation of your dataset Hyrax passes the entire Hyrax config as a nested dictionary in the ``config``
@@ -92,6 +98,9 @@ dataset will be done by Hyrax, when running the relevant verb. Further detail on
 You must call ``super().__init__(config)`` or ``super().__init__(config, metadata_table)`` in your
 ``__init__`` function
 
+Map style datasets
+..................
+
 ``__getitem__(self, idx: int)``
 ............................
 Return a single item in your dataset given a zero-based index.
@@ -100,6 +109,16 @@ Return a single item in your dataset given a zero-based index.
 .................
 Return the length of your dataset.
 
+Iterable datasets
+.................
+
+``__iter__(self)``
+.................
+Yield a single item in your dataset, or supply a generator function which does the same.
+If your dataset has an end, raise ``StopIteration`` at the end (a generator does this automatically).
+
+Warning: Iterable datasets which never raise ``StopIteration`` are not currently supported in Hyrax.
+
 Optional Overrides
 ..................
 

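The documentation hunk above describes the two dataset flavors and the methods each must supply. A minimal sketch of the iterable shape, with ``hyrax.data_sets.HyraxDataset`` replaced by a stand-in stub so the example is self-contained (the class and item values here are illustrative, not real Hyrax code):

```python
# Illustration only: _StubBase stands in for hyrax.data_sets.HyraxDataset,
# which in real code you would inherit alongside torch.utils.data.IterableDataset.

class _StubBase:
    """Stand-in for HyraxDataset: accepts the Hyrax config, as required."""

    def __init__(self, config, metadata_table=None):
        self._config = config


class MyIterableDataset(_StubBase):
    """Iterable-style dataset: items are yielded one at a time via __iter__."""

    def __init__(self, config):
        super().__init__(config)  # HyraxDataset requires this super() call
        self._items = [10, 20, 30]

    def __iter__(self):
        # A generator raises StopIteration automatically when it returns,
        # which is how a finite iterable dataset signals its end.
        yield from self._items


ds = MyIterableDataset(config={})
print(list(ds))  # [10, 20, 30]
```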
pyproject.toml

Lines changed: 6 additions & 1 deletion

@@ -17,7 +17,12 @@ dynamic = ["version"]
 requires-python = ">=3.9"
 dependencies = [
     "astropy", # Used to load fits files of sources to query HSC cutout server
-    "pytorch-ignite", # Used for distributed training, logging, etc.
+    # Pin to the current version of pytorch-ignite so workarounds to
+    # https://github.com/pytorch/ignite/issues/3372 function correctly
+    # while allowing us to release packages that don't depend on dev versions
+    # of pytorch-ignite.
+    "pytorch-ignite <= 0.5.2", # Used for distributed training, logging, etc.
+    "more-itertools", # Used to work around the issue in pytorch-ignite above
     "toml", # Used to load configuration files as dictionaries
     "tomlkit", # Used to load configuration files as dictionaries and retain comments
     "torch", # Used for CNN model and in train.py

src/hyrax/data_sets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
     "DATA_SET_REGISTRY",
     "HyraxCifarDataSet",
     "FitsImageDataSet",
+    "HyraxCifarIterableDataSet",
     "HSCDataSet",
     "InferenceDataSet",
     "Dataset",

src/hyrax/data_sets/data_set_registry.py

Lines changed: 39 additions & 7 deletions

@@ -5,6 +5,7 @@
 
 import numpy.typing as npt
 from astropy.table import Table
+from torch.utils.data import Dataset, IterableDataset
 
 from hyrax.config_utils import ConfigDict
 from hyrax.plugin_utils import get_or_load_class, update_registry
@@ -86,15 +87,46 @@ def __init__(config):
         self._metadata_table = metadata_table
         self.tensorboardx_logger = None
 
+    def is_iterable(self):
+        """
+        Returns True if the underlying dataset is iterable style, supporting __iter__, vs map style
+        where __getitem__/__len__ are the preferred access methods.
+
+        Returns
+        -------
+        bool
+            True if the underlying dataset is iterable
+        """
+        if isinstance(self, (Dataset, IterableDataset)):
+            return isinstance(self, IterableDataset)
+        else:
+            return hasattr(self, "__iter__")
+
+    def is_map(self):
+        """
+        Returns True if the underlying dataset is map style, supporting __getitem__/__len__, vs iterable
+        where __iter__ is the preferred access method.
+
+        Returns
+        -------
+        bool
+            True if the underlying dataset is map-style
+        """
+        if isinstance(self, (Dataset, IterableDataset)):
+            # All torch IterableDatasets are also Datasets
+            return not isinstance(self, IterableDataset)
+        else:
+            return hasattr(self, "__getitem__")
+
     @property
     def config(self):
         return self._config
 
     def __init_subclass__(cls):
-        from torch.utils.data import IterableDataset
+        from abc import ABC
 
-        if IterableDataset in cls.__bases__ or hasattr(cls, "__iter__"):
-            logger.error("Hyrax does not fully support iterable data sets yet. Proceed at your own risk.")
+        if ABC in cls.__bases__:
+            return
 
         # Paranoia. Deriving from a torch dataset class should ensure this, but if an external dataset author
         # forgets to do that, we tell them.
@@ -126,10 +158,10 @@ def ids(self) -> Generator[str]:
         A generator yielding all the string IDs of the dataset.
 
         """
-        if hasattr(self, "__len__"):
+        if self.is_map():
            for x in range(len(self)):
                yield str(x)
-        elif hasattr(self, "__iter__"):
+        elif self.is_iterable():
            for index, _ in enumerate(iter(self)):
                yield (str(index))
        else:
@@ -145,10 +177,10 @@ def shape(self) -> tuple:
         tuple
             Shape tuple of the tensor that will be returned from the dataset.
         """
-        if hasattr(self, "__getitem__"):
+        if self.is_map():
            data_sample = self[0]
            return data_sample[0].shape if isinstance(data_sample, tuple) else data_sample.shape
-        elif hasattr(self, "__iter__"):
+        elif self.is_iterable():
            data_sample = next(iter(self))
            return data_sample[0].shape if isinstance(data_sample, tuple) else data_sample.shape
        else:

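The new ``is_iterable()``/``is_map()`` helpers above fall back to ``hasattr`` checks when a dataset does not derive from the torch dataset classes. A dependency-free sketch of just that fallback branch (the ``MapStyle``/``IterStyle`` class names are illustrative):

```python
# Sketch of the non-torch fallback branch of is_iterable()/is_map():
# the torch isinstance() branch is skipped here to avoid the dependency.

class MapStyle:
    """Map-style: indexed access via __getitem__ plus __len__."""

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 3


class IterStyle:
    """Iterable-style: sequential access via __iter__ only."""

    def __iter__(self):
        yield from range(3)


def is_iterable(ds):
    # Iterable style prefers __iter__ as its access method.
    return hasattr(ds, "__iter__")


def is_map(ds):
    # Map style prefers __getitem__/__len__ as its access methods.
    return hasattr(ds, "__getitem__")


print(is_map(MapStyle()), is_iterable(IterStyle()))  # True True
```

Note that ``hasattr`` is checked on the instance, so a map-style class without an explicit ``__iter__`` attribute reports ``is_iterable() == False`` even though Python can still iterate it through ``__getitem__``.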
src/hyrax/data_sets/hsc_data_set.py

Lines changed: 1 addition & 1 deletion

@@ -321,7 +321,7 @@ def _prune_objects(self, filters_ref: list[str], cutout_shape: Optional[tuple[in
         # Drop objects that can't meet the cutout size provided
         for shape in self.dims[object_id]:
             if shape[0] < cutout_shape[0] or shape[1] < cutout_shape[1]:
-                msg = f"A file for object {object_id} has shape ({shape[1]}px, {shape[1]}px)"
+                msg = f"A file for object {object_id} has shape ({shape[0]}px, {shape[1]}px)"
                 msg += " this is too small for the given cutout size of "
                 msg += f"({cutout_shape[0]}px, {cutout_shape[1]}px)"
                 self._mark_for_prune(object_id, msg)

src/hyrax/data_sets/hyrax_cifar_data_set.py

Lines changed: 28 additions & 2 deletions

@@ -4,6 +4,7 @@
 import numpy as np
 import torchvision.transforms as transforms
 from astropy.table import Table
+from torch.utils.data import IterableDataset
 from torchvision.datasets import CIFAR10
 
 from hyrax.config_utils import ConfigDict
@@ -14,8 +15,8 @@
 
 
 class HyraxCifarDataSet(HyraxDataset, CIFAR10):
-    """This is simply a version of CIFAR10 that has our needed shape method, and is initialized using
-    Hyrax config with a transformation that works well for example code.
+    """This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation
+    that works well for example code.
 
     We only use the training split in the data, because it is larger (50k images). Hyrax will then divide that
     into Train/test/Validate according to configuration.
@@ -30,3 +31,28 @@ def __init__(self, config: ConfigDict):
         )
         metadata_table = Table({"label": np.array([self[index][1] for index in range(len(self))])})
         super().__init__(config, metadata_table)
+
+
+class HyraxCifarIterableDataSet(HyraxDataset, IterableDataset):
+    """This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation
+    that works well for example code. This version only supports iteration, not map-style access.
+
+    We only use the training split in the data, because it is larger (50k images). Hyrax will then divide that
+    into Train/test/Validate according to configuration.
+    """
+
+    def __init__(self, config: ConfigDict):
+        transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+        )
+        self.cifar = CIFAR10(
+            root=config["general"]["data_dir"], train=True, download=True, transform=transform
+        )
+        metadata_table = Table(
+            {"label": np.array([self.cifar[index][1] for index in range(len(self.cifar))])}
+        )
+        super().__init__(config, metadata_table)
+
+    def __iter__(self):
+        for item in self.cifar:
+            yield item

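``HyraxCifarIterableDataSet`` above wraps a map-style ``CIFAR10`` in ``self.cifar`` and exposes only ``__iter__``. The same composition pattern, sketched with the torch pieces swapped for plain lists so it runs standalone (``WrappedIterable`` and the sample items are illustrative names, not Hyrax API):

```python
# Composition pattern: own an inner indexable collection, expose only iteration.

class WrappedIterable:
    """Holds an inner dataset (playing the role of self.cifar) but only
    supports sequential access through __iter__, not indexing."""

    def __init__(self, inner):
        self.inner = inner

    def __iter__(self):
        # Delegate: yield each (image, label) pair of the inner dataset in order.
        for item in self.inner:
            yield item


ds = WrappedIterable([("img0", 0), ("img1", 1)])
labels = [label for _, label in ds]
print(labels)  # [0, 1]
```

Delegation keeps the metadata table construction in ``__init__`` able to index ``self.cifar`` directly, while consumers of the Hyrax dataset see only the iterable interface.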
src/hyrax/infer.py

Lines changed: 2 additions & 1 deletion

@@ -43,7 +43,8 @@ def run(config: ConfigDict):
     data_set = setup_dataset(config, tensorboardx_logger)
 
     model = setup_model(config, data_set)
-    logger.info(f"data set has length {len(data_set)}")  # type: ignore[arg-type]
+    if data_set.is_map():
+        logger.info(f"data set has length {len(data_set)}")  # type: ignore[arg-type]
     data_loader = dist_data_loader(data_set, config, split=config["infer"]["split"])
 
     log_runtime_config(config, results_dir)

src/hyrax/pytorch_ignite.py

Lines changed: 69 additions & 21 deletions

@@ -12,7 +12,7 @@
 import mlflow
 
 import torch
-from ignite.engine import Engine, Events
+from ignite.engine import Engine, EventEnum, Events
 from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
 from ignite.handlers.tqdm_logger import ProgressBar
 from tensorboardX import SummaryWriter
@@ -103,6 +103,9 @@ def dist_data_loader(
     For multiple splits, we return a dictionary where the keys are the names of the splits
     and the value is either a Dataloader as described above or the value None if the split
     was not configured.
+
+    If an iterable dataset is passed, we cannot create multiple splits with a PyTorch sampler object,
+    so we return the same thing for all splits: a dataloader representing the entire iterable.
     """
     # Handle case where no split is needed.
     if isinstance(split, bool):
@@ -118,18 +121,25 @@ def dist_data_loader(
     if seed is not None:
         torch_rng.manual_seed(seed)
 
-    # Create the indexes for all splits based on config.
-    indexes = create_splits(data_set, config)
-
-    # Create samplers and dataloaders for each split we are interested in
-    samplers = {
-        s: SubsetRandomSampler(indexes[s], generator=torch_rng) if indexes.get(s) else None for s in split
-    }
-
-    dataloaders = {
-        split: idist.auto_dataloader(data_set, sampler=sampler, **config["data_loader"]) if sampler else None
-        for split, sampler in samplers.items()
-    }
+    if data_set.is_iterable():
+        dataloaders = {
+            s: idist.auto_dataloader(data_set, pin_memory=True, **config["data_loader"]) for s in split
+        }
+    else:
+        # Create the indexes for all splits based on config.
+        indexes = create_splits(data_set, config)
+
+        # Create samplers and dataloaders for each split we are interested in
+        samplers = {
+            s: SubsetRandomSampler(indexes[s], generator=torch_rng) if indexes.get(s) else None for s in split
+        }
+
+        dataloaders = {
+            split: idist.auto_dataloader(data_set, sampler=sampler, **config["data_loader"])
+            if sampler
+            else None
+            for split, sampler in samplers.items()
+        }
 
     # Return only one if we were only passed one split in, return the dictionary otherwise.
     return dataloaders[split[0]] if len(split) == 1 else dataloaders
@@ -363,6 +373,7 @@ def create_validator(
     model = idist.auto_model(model)
 
     validator = create_engine("train_step", device, model)
+    fixup_engine(validator)
 
     @validator.on(Events.STARTED)
     def set_model_to_eval_mode():
@@ -372,12 +383,12 @@ def set_model_to_eval_mode():
     def set_model_to_train_mode():
         model.train()
 
-    @validator.on(Events.EPOCH_COMPLETED)
+    @validator.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def log_training_loss():
         logger.debug(f"Validation run time: {validator.state.times['EPOCH_COMPLETED']:.2f}[s]")
         logger.debug(f"Validation metrics: {validator.state.output}")
 
-    @trainer.on(Events.EPOCH_COMPLETED)
+    @trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def run_validation():
         validator.run(validation_data_loader)
 
@@ -386,7 +397,7 @@ def log_validation_loss(validator, trainer):
         tensorboardx_logger.add_scalar("training/validation/loss", validator.state.output["loss"], step)
         mlflow.log_metrics({"validation/loss": validator.state.output["loss"]}, step=step)
 
-    validator.add_event_handler(Events.EPOCH_COMPLETED, log_validation_loss, trainer)
+    validator.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, log_validation_loss, trainer)
 
     return validator
 
@@ -419,6 +430,7 @@ def create_trainer(
     model.train()
     model = idist.auto_model(model)
     trainer = create_engine("train_step", device, model)
+    fixup_engine(trainer)
 
     optimizer = extract_model_method(model, "optimizer")
 
@@ -435,18 +447,19 @@ def create_trainer(
         to_save,
         DiskSaver(results_directory, require_empty=False),
         n_saved=1,
-        global_step_transform=global_step_from_engine(trainer),
+        global_step_transform=global_step_from_engine(trainer, Events.EPOCH_COMPLETED),
         filename_pattern="{name}_epoch_{global_step}.{ext}",
     )
 
     def neg_loss_score(engine):
+        print(engine.state)
         return -engine.state.output["loss"]
 
     best_checkpoint = Checkpoint(
         to_save,
         DiskSaver(results_directory, require_empty=False),
         n_saved=1,
-        global_step_transform=global_step_from_engine(trainer),
+        global_step_transform=global_step_from_engine(trainer, Events.EPOCH_COMPLETED),
         score_name="loss",
         score_function=neg_loss_score,
         greater_or_equal=True,
@@ -473,13 +486,13 @@ def log_training_loss_tensorboard(trainer):
         tensorboardx_logger.add_scalar("training/training/loss", trainer.state.output["loss"], step)
         mlflow.log_metrics({"training/loss": trainer.state.output["loss"]}, step=step)
 
-    @trainer.on(Events.EPOCH_COMPLETED)
+    @trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def log_training_loss(trainer):
         logger.debug(f"Epoch {trainer.state.epoch} run time: {trainer.state.times['EPOCH_COMPLETED']:.2f}[s]")
         logger.debug(f"Epoch {trainer.state.epoch} metrics: {trainer.state.output}")
 
-    trainer.add_event_handler(Events.EPOCH_COMPLETED, latest_checkpoint)
-    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint)
+    trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, latest_checkpoint)
+    trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, best_checkpoint)
 
     @trainer.on(Events.COMPLETED)
     def log_total_time(trainer):
@@ -498,3 +511,38 @@ def log_best_checkpoint_location(_, best_checkpoint):
     pbar.attach(trainer)
 
     return trainer
+
+
+class HyraxEvents(EventEnum):
+    """
+    Workaround event for a pytorch ignite bug. See fixup_engine for details.
+    """
+
+    HYRAX_EPOCH_COMPLETED = "HyraxEpochCompleted"
+
+
+def fixup_engine(engine: Engine) -> Engine:
+    """
+    Workaround for this pytorch ignite bug (https://github.com/pytorch/ignite/issues/3372) where
+    engine.state.output is not available at EPOCH_COMPLETED or later times (COMPLETED, etc.)
+
+    We create a new event HYRAX_EPOCH_COMPLETED which triggers at ITERATION_COMPLETED, but only on the final
+    iteration. This is just before the erroneous state reset.
+
+    This hack relies on pytorch ignite internal state, but can be removed as soon as our fix is mainlined
+    (https://github.com/pytorch/ignite/pull/3373) in version 0.6.0, estimated August 2025.
+    """
+    from more_itertools import peekable
+
+    engine.register_events(*HyraxEvents)
+
+    @engine.on(Events.ITERATION_COMPLETED)
+    def maintain_event_handler(engine):
+        # Ensure we have a peekable iterator in the engine.
+        if not hasattr(engine._dataloader_iter, "peek"):
+            # Replace with a pass-through peekable iterator
+            engine._dataloader_iter = peekable(engine._dataloader_iter)
+
+        # After the final iteration the exhausted peekable iterator evaluates as falsy
+        if not engine._dataloader_iter:
+            engine.fire_event(HyraxEvents.HYRAX_EPOCH_COMPLETED)

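The ``fixup_engine`` workaround above hinges on one behavior of ``more_itertools.peekable``: the wrapper is falsy once the underlying iterator is exhausted, so the handler can detect the final iteration before ignite resets its state. A tiny stand-in ``Peekable`` (illustration only, not the real ``more_itertools`` implementation) demonstrates the last-iteration detection trick:

```python
# Minimal sketch of the peekable truthiness trick used by fixup_engine.

_SENTINEL = object()


class Peekable:
    """Pass-through iterator wrapper that can report whether items remain."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._cache = _SENTINEL  # holds a peeked-at item, if any

    def __iter__(self):
        return self

    def __next__(self):
        # Hand back the cached (peeked) item first, then the underlying iterator.
        if self._cache is not _SENTINEL:
            value, self._cache = self._cache, _SENTINEL
            return value
        return next(self._it)

    def __bool__(self):
        # Peek ahead: falsy only when nothing is left to yield.
        if self._cache is _SENTINEL:
            try:
                self._cache = next(self._it)
            except StopIteration:
                return False
        return True


batches = Peekable(["batch0", "batch1"])
last_seen = []
for batch in batches:
    if not batches:              # true only after consuming the final batch
        last_seen.append(batch)  # this is where HYRAX_EPOCH_COMPLETED would fire

print(last_seen)  # ['batch1']
```

Because the truthiness check peeks rather than consumes, wrapping the engine's dataloader iterator this way is transparent to the training loop itself.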