Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions src/hyrax/data_sets/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def __init__(self, config: dict, request: dict):
self.all_metadata_fields = {}
self.requested_fields = {}

# This dictionary maintains a mapping of friendly name to callable collate
# functions defined on the requested dataset class.
self.custom_collate_functions = {}

self.primary_dataset = None
self.primary_dataset_id_field_name = None

Expand Down Expand Up @@ -242,6 +246,11 @@ def prepare_datasets(self):
dataset_cls = fetch_dataset_class(dataset_class)
dataset_instance = dataset_cls(config=dataset_specific_config, data_location=data_location)

# If the dataset instance has a `collate` method, store it for use in
# the DataLoader.collate function.
if hasattr(dataset_instance, "collate") and callable(dataset_instance.collate):
self.custom_collate_functions[friendly_name] = dataset_instance.collate

# Store the prepared dataset instance in the `self.prepped_datasets`
self.prepped_datasets[friendly_name] = dataset_instance

Expand Down Expand Up @@ -576,16 +585,11 @@ def collate(self, batch: list[dict]) -> dict:
"""

batch_dict: dict[str, dict[str, list], list] = {}
custom_collate: dict[str, list] = {}

# Aggregate values per friendly_name -> field -> list(values)
for sample in batch:
for friendly_name, fields in sample.items():
# Here we should check the self.custom_collate_function dictionary
# If we discover that friendly_name maps to a particular custom
# collation function (i.e. one defined on the dataset), we should
# include just the samples for that dataset in the batch passed to the custom
# collate function. For now, we will skip that functionality.

# Special handling for "object_id" for the time being. "object_id"
# hangs on the edge of the data dictionary so that it can be consumed
# during `infer`, specifically `_save_batch`. Originally it was
Expand All @@ -597,6 +601,15 @@ def collate(self, batch: list[dict]) -> dict:
batch_dict.setdefault("object_id", []).append(str(val))
continue

# If we find that `friendly_name` is in self.custom_collate_functions
# we accumulate the samples from that dataset and hand off to
# the appropriate custom collate function after the for loop.
if friendly_name in self.custom_collate_functions:
# ! By convention, the dataset's custom collate function will
# ! expect the friendly name to be "data".
custom_collate.setdefault(friendly_name, []).append({"data": fields})
continue

if friendly_name not in batch_dict:
batch_dict[friendly_name] = {}

Expand All @@ -607,10 +620,54 @@ def collate(self, batch: list[dict]) -> dict:
if "object_id" in batch_dict:
batch_dict["object_id"] = np.asarray(batch_dict["object_id"], dtype=str)

# Try to convert lists of values into numpy arrays (stack when possible)
# Handle custom collate functions for datasets that define them
for friendly_name, samples in custom_collate.items():
# Get the collate function from the mapping dictionary
custom_collate_fn = self.custom_collate_functions[friendly_name]

# Pass the list of data samples to the collation
try:
custom_collated_data = custom_collate_fn(samples)
except Exception as err:
logger.error(
f"Error occurred while collating batch for dataset '{friendly_name}' "
"using its custom collate function."
)
raise RuntimeError(
f"Error occurred while collating batch for dataset '{friendly_name}' "
"using its custom collate function."
) from err

# ! By convention, the returned dictionary from a custom collate function
# ! should contain a "data" key (the default friendly name). Only "data"
# ! is used here; any other keys in the returned dictionary are ignored.
if "data" not in custom_collated_data:
logger.error(
f"Custom collate function for dataset '{friendly_name}' did not return "
"a 'data' key in the result."
)
raise RuntimeError(
f"Custom collate function for dataset '{friendly_name}' did not return "
"a 'data' key in the result."
)

# Add the collated data to the batch dictionary
batch_dict[friendly_name] = custom_collated_data["data"]

        # Try to convert lists of values into numpy arrays. We skip the "object_id"
        # key since it's already been handled, as well as any keys that are in the
        # self.custom_collate_functions dictionary because those should have been
        # handled by the corresponding dataset class custom collate function.
for friendly_name, fields in batch_dict.items():
if friendly_name == "object_id":
continue

# ! Assuming what is returned from custom_collate is already correctly
# ! numpy formatted. This is a big assumption. We should provide some
# ! pre-packaged tests for users developing custom collate functions.
if friendly_name in self.custom_collate_functions:
continue

for field, values in list(fields.items()):
# If all values are numpy arrays and have identical shapes -> stack
if all(isinstance(v, np.ndarray) for v in values):
Expand Down
2 changes: 1 addition & 1 deletion src/hyrax/verbs/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run(self, model_directory: str = None):
[x] Read in the user config
[x] Prepare all the datasets requested
[x] Implement a simple strategy for reading in batches of data samples
[ ] Process the samples with any custom collate functions as well as a default collate function
[x] Process the samples with any custom collate functions as well as a default collate function
[x] Pass the collated batch to the appropriate to_tensor function
[ ] Send that output to the ONNX-ified model
[x] Persist the results of inference.
Expand Down
39 changes: 39 additions & 0 deletions tests/hyrax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,42 @@ def data_provider(multimodal_config):
h.config["model_inputs"] = multimodal_config
dp = DataProvider(h.config, multimodal_config["train"])
return dp


@pytest.fixture(scope="function")
def custom_collate_data_provider(multimodal_config):
    """Use the multimodal_config fixture to create a DataProvider instance
    with custom collate functions for each dataset.

    Monkeypatches a ``collate`` staticmethod onto ``HyraxRandomDataset`` and
    guarantees it is removed during teardown so the patched class never leaks
    into other tests.
    """

    from hyrax.data_sets.random.hyrax_random_dataset import HyraxRandomDataset

    @staticmethod
    def collate(batch):
        """Contrived custom collate function that will return collated image
        data as well as a boolean 'mask' of the same shape.
        """
        returned_data = {"data": {}}
        if "image" in batch[0]["data"]:
            batch_array = np.stack([item["data"]["image"] for item in batch], axis=0)
            returned_data["data"]["image"] = batch_array
            returned_data["data"]["image_mask"] = np.ones_like(batch_array, dtype=bool)

        if "object_id" in batch[0]["data"]:
            returned_data["data"]["object_id"] = np.stack(
                [item["data"]["object_id"] for item in batch], axis=0
            )
            returned_data["object_id"] = returned_data["data"]["object_id"]

        if "label" in batch[0]["data"]:
            returned_data["data"]["label"] = np.stack([item["data"]["label"] for item in batch], axis=0)

        return returned_data

    HyraxRandomDataset.collate = collate

    # try/finally ensures the monkeypatched attribute is removed even if
    # Hyrax/DataProvider construction raises before the yield; otherwise the
    # patched class attribute would leak into every subsequent test.
    try:
        h = hyrax.Hyrax()
        h.config["model_inputs"] = multimodal_config
        dp = DataProvider(h.config, multimodal_config["train"])
        yield dp
    finally:
        delattr(HyraxRandomDataset, "collate")
50 changes: 50 additions & 0 deletions tests/hyrax/test_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,53 @@ def test_collate_function(data_provider):

# assert that the object_id key is a numpy array
assert isinstance(collated_batch["object_id"], np.ndarray)


def test_finds_custom_collate_function(custom_collate_data_provider):
    """Test that DataProvider correctly identifies datasets
    that have custom collate functions defined.
    """

    dp = custom_collate_data_provider

    # Both random datasets were patched with a collate method, so both
    # should have been registered as callables in the mapping.
    for dataset_name in ("random_0", "random_1"):
        assert dataset_name in dp.custom_collate_functions
        assert callable(dp.custom_collate_functions[dataset_name])


def test_custom_collate_function_applied(custom_collate_data_provider):
    """Test that DataProvider correctly applies custom collate functions
    for datasets that define them in the DataProvider.collate method.
    """

    import numpy as np

    dp = custom_collate_data_provider

    # Assemble a batch containing every sample the provider can produce
    num_samples = len(dp)
    samples = [dp[idx] for idx in range(num_samples)]

    collated = dp.collate(samples)
    assert isinstance(collated, dict)

    # The custom collate function injects an "image_mask" field alongside the
    # original fields, so both datasets expose it after collation.
    expected_per_dataset = {
        "random_0": ("object_id", "image", "label", "image_mask"),
        "random_1": ("image", "image_mask"),
    }
    for dataset_name, field_names in expected_per_dataset.items():
        for field in field_names:
            assert field in collated[dataset_name]
            assert len(collated[dataset_name][field]) == num_samples

    # assert that the object_id key is a numpy array
    assert isinstance(collated["object_id"], np.ndarray)