Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/hyrax/data_sets/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def __init__(self, config: dict, request: dict):
self.all_metadata_fields = {}
self.requested_fields = {}

# This dictionary maintains a mapping of friendly name to callable collate
# functions defined on the requested dataset class.
self.custom_collate_functions = {}
Comment thread
drewoldag marked this conversation as resolved.

self.primary_dataset = None
self.primary_dataset_id_field_name = None

Expand Down Expand Up @@ -242,6 +246,11 @@ def prepare_datasets(self):
dataset_cls = fetch_dataset_class(dataset_class)
dataset_instance = dataset_cls(config=dataset_specific_config, data_location=data_location)

# If the dataset instance has a `collate` method, store it for use in
# the DataLoader.collate function.
if hasattr(dataset_instance, "collate") and callable(dataset_instance.collate):
self.custom_collate_functions[friendly_name] = dataset_instance.collate
Comment thread
drewoldag marked this conversation as resolved.

# Store the prepared dataset instance in the `self.prepped_datasets`
self.prepped_datasets[friendly_name] = dataset_instance

Expand Down Expand Up @@ -576,16 +585,11 @@ def collate(self, batch: list[dict]) -> dict:
"""

batch_dict: dict[str, dict[str, list], list] = {}
custom_collate: dict[str, list] = {}

# Aggregate values per friendly_name -> field -> list(values)
for sample in batch:
for friendly_name, fields in sample.items():
# Here we should check the self.custom_collate_function dictionary
# If we discover that friendly_name maps to a particular custom
# collation function (i.e. one defined on the dataset), we should
# include just the samples for that dataset in the batch passed to the custom
# collate function. For now, we will skip that functionality.

# Special handling for "object_id" for the time being. "object_id"
# hangs on the edge of the data dictionary so that it can be consumed
# during `infer`, specifically `_save_batch`. Originally it was
Expand All @@ -597,6 +601,15 @@ def collate(self, batch: list[dict]) -> dict:
batch_dict.setdefault("object_id", []).append(str(val))
continue

# If we find that `friendly_name` is in self.custom_collate_functions
# we accumulate the samples from that dataset and hand off to
# the appropriate custom collate function after the for loop.
if friendly_name in self.custom_collate_functions:
# ! By convention, the dataset's custom collate function will
# ! expect the friendly name to be "data".
custom_collate.setdefault(friendly_name, []).append({"data": fields})
Comment thread
drewoldag marked this conversation as resolved.
continue

if friendly_name not in batch_dict:
batch_dict[friendly_name] = {}

Expand All @@ -607,10 +620,34 @@ def collate(self, batch: list[dict]) -> dict:
if "object_id" in batch_dict:
batch_dict["object_id"] = np.asarray(batch_dict["object_id"], dtype=str)

# Try to convert lists of values into numpy arrays (stack when possible)
# Handle custom collate functions for datasets that define them
for friendly_name, samples in custom_collate.items():
# Get the collate function from the mapping dictionary
custom_collate_fn = self.custom_collate_functions[friendly_name]

# Pass the list of data samples to the collation
custom_collated_data = custom_collate_fn(samples)

# Add the collated data to the batch dictionary
# ! By convention, the returned dictionary will contain two keys,
# ! "data" (the default friendly name) and "object_id". Only keep
Comment thread
drewoldag marked this conversation as resolved.
Outdated
# ! "data", but we assign it to the friendly name by `model_inputs`.
Comment thread
drewoldag marked this conversation as resolved.
Outdated
batch_dict[friendly_name] = custom_collated_data["data"]
Comment thread
drewoldag marked this conversation as resolved.

# Try to convert lists of values into numpy arrays. We skip the "object_id"
# key since it's already been handled, as well as any keys that are in the
# self.custom_collate_functions dictionary because those should have been
# handled by the corresponding dataset class custom collate function.
for friendly_name, fields in batch_dict.items():
if friendly_name == "object_id":
continue

# ! Assuming what is returned from custom_collate is already correctly
# ! numpy formatted. This is a big assumption. We should provide some
# ! pre-packaged tests for users developing custom collate functions.
Comment thread
drewoldag marked this conversation as resolved.
if friendly_name in self.custom_collate_functions:
continue

for field, values in list(fields.items()):
# If all values are numpy arrays and have identical shapes -> stack
if all(isinstance(v, np.ndarray) for v in values):
Expand Down
2 changes: 1 addition & 1 deletion src/hyrax/verbs/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run(self, model_directory: str = None):
[x] Read in the user config
[x] Prepare all the datasets requested
[x] Implement a simple strategy for reading in batches of data samples
[ ] Process the samples with any custom collate functions as well as a default collate function
[x] Process the samples with any custom collate functions as well as a default collate function
[x] Pass the collated batch to the appropriate to_tensor function
[ ] Send that output to the ONNX-ified model
[x] Persist the results of inference.
Expand Down
37 changes: 37 additions & 0 deletions tests/hyrax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,40 @@ def data_provider(multimodal_config):
h.config["model_inputs"] = multimodal_config
dp = DataProvider(h.config, multimodal_config["train"])
return dp


@pytest.fixture(scope="function")
def custom_collate_data_provider(multimodal_config):
    """Use the multimodal_config fixture to create a DataProvider instance
    with custom collate functions for each dataset.

    The fixture monkeypatches a ``collate`` staticmethod onto
    ``HyraxRandomDataset`` so that ``DataProvider.prepare_datasets`` discovers
    it, and removes the patch again during teardown so it cannot leak into
    other tests that use ``HyraxRandomDataset``.
    """

    from hyrax.data_sets.random.hyrax_random_dataset import HyraxRandomDataset

    @staticmethod
    def collate(batch):
        """Contrived custom collate function that will return collated image
        data as well as a boolean 'mask' of the same shape.
        """
        returned_data = {"data": {}}
        if "image" in batch[0]["data"]:
            batch_array = np.stack([item["data"]["image"] for item in batch], axis=0)
            returned_data["data"]["image"] = batch_array
            # Extra field not present in the raw samples -- lets tests verify
            # that the custom collate function actually ran.
            returned_data["data"]["image_mask"] = np.ones_like(batch_array, dtype=bool)

        if "object_id" in batch[0]["data"]:
            returned_data["data"]["object_id"] = np.stack(
                [item["data"]["object_id"] for item in batch], axis=0
            )
            # By convention the collated dict also exposes "object_id" at the
            # top level for consumption during infer.
            returned_data["object_id"] = returned_data["data"]["object_id"]

        if "label" in batch[0]["data"]:
            returned_data["data"]["label"] = np.stack([item["data"]["label"] for item in batch], axis=0)

        return returned_data

    # Patch the collate function onto the dataset class so that DataProvider
    # registers it in self.custom_collate_functions during prepare_datasets.
    HyraxRandomDataset.collate = collate

    h = hyrax.Hyrax()
    h.config["model_inputs"] = multimodal_config
    dp = DataProvider(h.config, multimodal_config["train"])
    yield dp

    # Teardown: remove the monkeypatched attribute so subsequent tests see
    # HyraxRandomDataset without a custom collate function.
    del HyraxRandomDataset.collate
Comment thread
drewoldag marked this conversation as resolved.
Outdated
50 changes: 50 additions & 0 deletions tests/hyrax/test_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,53 @@ def test_collate_function(data_provider):

# assert that the object_id key is a numpy array
assert isinstance(collated_batch["object_id"], np.ndarray)


def test_finds_custom_collate_function(custom_collate_data_provider):
    """Test that DataProvider correctly identifies datasets
    that have custom collate functions defined.
    """

    dp = custom_collate_data_provider

    # Both random datasets define a collate method, so both friendly names
    # should map to a callable in the provider's registry.
    for friendly_name in ("random_0", "random_1"):
        assert friendly_name in dp.custom_collate_functions
        assert callable(dp.custom_collate_functions[friendly_name])


def test_custom_collate_function_applied(custom_collate_data_provider):
    """Test that DataProvider correctly applies custom collate functions
    for datasets that define them in the DataProvider.collate method.
    """

    import numpy as np

    dp = custom_collate_data_provider

    # Build a batch containing every sample the provider exposes.
    num_samples = len(dp)
    samples = [dp[idx] for idx in range(num_samples)]

    # Run the batch through the provider's collate entry point.
    collated = dp.collate(samples)

    assert isinstance(collated, dict)

    # "image_mask" is not part of the raw samples -- it is injected by the
    # custom collate function, so its presence proves the function ran.
    for field in ("object_id", "image", "label", "image_mask"):
        assert field in collated["random_0"]
        assert len(collated["random_0"][field]) == num_samples

    # random_1 likewise gains "image_mask" from its custom collate function.
    for field in ("image", "image_mask"):
        assert field in collated["random_1"]
        assert len(collated["random_1"][field]) == num_samples

    # The top-level object_id key should have been collated to a numpy array.
    assert isinstance(collated["object_id"], np.ndarray)
Loading