Find and use custom collate functions defined on dataset classes#561
Find and use custom collate functions defined on dataset classes#561
Conversation
…n `DataProvider.collate`. Added and updated some unit tests.
There was a problem hiding this comment.
Pull request overview
This PR adds support for custom collate functions in dataset classes. The DataProvider now discovers and applies custom collate methods defined on dataset classes as static methods, allowing datasets to implement their own data batching logic beyond the default numpy array stacking.
Key changes:
- Added
custom_collate_functionsdictionary to track dataset-specific collate methods - Modified
prepare_datasets()to detect and store custom collate functions from dataset instances - Updated
collate()method to route data samples to custom collate functions when available
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| src/hyrax/data_sets/data_provider.py | Implements custom collate function discovery and application logic in DataProvider |
| tests/hyrax/conftest.py | Adds test fixture with monkey-patched custom collate function for HyraxRandomDataset |
| tests/hyrax/test_data_provider.py | Adds two new tests verifying custom collate function detection and application |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #561 +/- ##
==========================================
+ Coverage 55.26% 55.34% +0.07%
==========================================
Files 53 53
Lines 5155 5175 +20
==========================================
+ Hits 2849 2864 +15
- Misses 2306 2311 +5 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
|
@drewoldag I've opened a new pull request, #562, to work on those changes. Once the pull request is ready, I'll request review from you. |
|
@drewoldag I've opened a new pull request, #563, to work on those changes. Once the pull request is ready, I'll request review from you. |
Click here to view all benchmarks. |
… functions. Removing the monkey patched collate function in unit tests.
In this PR we update
DataProviderto include a dictionary that maps "friendly_name" to a callable custom collate function if one has been defined on the dataset class associated with the friendly name.Additionally, in
DataProvider.collate, as we collate a batch of data, if we discover that some of the data is from a dataset class that implements a custom collation function, we'll apply that function to the correct portion of data. If there is no custom collate function defined, then DataProvider will do the work of collating the data samples into single large numpy arrays.Finally, a couple of unit tests were added as well as a test fixture that will monkey patch the HyraxRandomDataset to include a
collatestatic method.There are a few assumptions being made in the implementation of this PR:
collate@staticmethodin the dataset class.[{'data': {'field_1': <...>, ..., 'field_n': <...>}}, ...]{'data': {'field_1': [<...>, ..., <...>], 'field_n': [<...>, ..., <...>]}, 'object_id': [...]}*.