Add drop_last option to DataLoader batching #3448
Open · +83 −99
Introduces a drop_last flag to DataLoaderBuilder and FixBatchStrategy, allowing incomplete batches to be optionally dropped during data loading. Updates the builder API and batch strategy logic to support this feature, improving flexibility for batch processing.
Pull Request Template

Checklist

- [x] `cargo run-checks` command has been executed.

Related Issues/PRs

- Fixes issue: DataLoader yields as many iterations as `num_workers` instead of the correct batch count; no `drop_last` support (see #3316)
Changes

Problem:
The DataLoader previously yielded one batch per worker per epoch, regardless of batch size or dataset size, leading to incorrect iteration counts. There was also no way to drop incomplete batches, unlike PyTorch's DataLoader.

Solution:
- Added a `drop_last` flag to `DataLoaderBuilder` and `FixBatchStrategy`.
- When `drop_last` is true, the final incomplete batch is dropped.
- The number of batches per epoch no longer depends on `num_workers`.
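The `drop_last` semantics can be sketched as follows. This is a minimal illustration only, not Burn's actual `FixBatchStrategy`; the struct layout and the `batch(force)` signature here are assumptions for the sake of the example.

```rust
// Minimal sketch of a fixed-size batch strategy with drop_last semantics.
// (Hypothetical shape — the real FixBatchStrategy in Burn may differ.)
struct FixBatchStrategy<I> {
    batch_size: usize,
    drop_last: bool,
    items: Vec<I>,
}

impl<I> FixBatchStrategy<I> {
    fn add(&mut self, item: I) {
        self.items.push(item);
    }

    // Returns a full batch when enough items are buffered. When `force`
    // is set (end of the dataset), the remaining partial batch is
    // returned unless drop_last is enabled, in which case it is dropped.
    fn batch(&mut self, force: bool) -> Option<Vec<I>> {
        if self.items.len() >= self.batch_size {
            return Some(self.items.drain(..self.batch_size).collect());
        }
        if force && !self.items.is_empty() && !self.drop_last {
            return Some(std::mem::take(&mut self.items));
        }
        None
    }
}
```

With `drop_last: true`, the trailing items that do not fill a complete batch are silently discarded at the end of the epoch; with `drop_last: false`, they are emitted as one smaller final batch.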
Testing

- Ran `cargo test --workspace --all-features` to ensure all tests pass.
- Verified iteration count equals `ceil(dataset_size / batch_size)` for various `num_workers` values.
- Verified the `drop_last` flag correctly drops incomplete batches when enabled.
- Confirmed `num_workers` only affects parallelism, not batch count.
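The iteration counts verified above follow from simple arithmetic; a small helper (hypothetical, not part of this PR) makes the relationship explicit:

```rust
// Expected number of batches per epoch, assuming PyTorch-style semantics:
// without drop_last the last partial batch counts (ceiling division);
// with drop_last only complete batches count (floor division).
fn num_batches(dataset_size: usize, batch_size: usize, drop_last: bool) -> usize {
    if drop_last {
        dataset_size / batch_size
    } else {
        (dataset_size + batch_size - 1) / batch_size
    }
}
```

For example, a dataset of 10 items with `batch_size = 3` yields 4 iterations normally and 3 with `drop_last` enabled — independent of `num_workers` in both cases.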