Skip to content

Rename BucketBatcher argument to avoid name collision #304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,13 +546,13 @@ def test_bucket_batcher_iterdatapipe(self) -> None:

# Functional Test: drop last reduces length
batch_dp = source_dp.bucketbatch(
batch_size=3, drop_last=True, batch_num=100, bucket_num=1, in_batch_shuffle=True
batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True
)
self.assertEqual(9, len(list(batch_dp.unbatch())))

# Functional Test: drop last is False preserves length
batch_dp = source_dp.bucketbatch(
batch_size=3, drop_last=False, batch_num=100, bucket_num=1, in_batch_shuffle=False
batch_size=3, drop_last=False, batch_num=100, bucket_num=1, use_in_batch_shuffle=False
)
self.assertEqual(10, len(list(batch_dp.unbatch())))

Expand All @@ -561,15 +561,15 @@ def _return_self(x):

# Functional Test: using sort_key, with in_batch_shuffle
batch_dp = source_dp.bucketbatch(
batch_size=3, drop_last=True, batch_num=100, bucket_num=1, in_batch_shuffle=True, sort_key=_return_self
batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True, sort_key=_return_self
)
# bucket_num = 1 means there will be no shuffling if a sort key is given
self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(batch_dp))
self.assertEqual(9, len(list(batch_dp.unbatch())))

# Functional Test: using sort_key, without in_batch_shuffle
# Functional Test: using sort_key, without use_in_batch_shuffle
batch_dp = source_dp.bucketbatch(
batch_size=3, drop_last=True, batch_num=100, bucket_num=2, in_batch_shuffle=False, sort_key=_return_self
batch_size=3, drop_last=True, batch_num=100, bucket_num=2, use_in_batch_shuffle=False, sort_key=_return_self
)
self.assertEqual(9, len(list(batch_dp.unbatch())))

Expand All @@ -580,7 +580,7 @@ def _return_self(x):
drop_last=True,
batch_num=100,
bucket_num=2,
in_batch_shuffle=False,
use_in_batch_shuffle=False,
sort_key=_return_self,
)
n_elements_before_reset = 2
Expand Down
16 changes: 7 additions & 9 deletions torchdata/datapipes/iter/transform/bucketbatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from functools import partial
from typing import Callable, Iterator, List, Optional, TypeVar

from torch.utils.data import DataChunk

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes import DataChunk, functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

T_co = TypeVar("T_co", covariant=True)
Expand Down Expand Up @@ -62,7 +60,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
batch_num: Number of batches within a bucket (i.e. `bucket_size = batch_size * batch_num`)
bucket_num: Number of buckets to consist a pool for shuffling (i.e. `pool_size = bucket_size * bucket_num`)
sort_key: Callable to sort a bucket (list)
in_batch_shuffle: iF True, do in-batch shuffle; if False, buffer shuffle
use_in_batch_shuffle: if True, do in-batch shuffle; if False, buffer shuffle

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
Expand All @@ -74,7 +72,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
>>> return sorted(bucket)
>>> batch_dp = source_dp.bucketbatch(
>>> batch_size=3, drop_last=True, batch_num=100,
>>> bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket
>>> bucket_num=1, use_in_batch_shuffle=False, sort_key=sort_bucket
>>> )
>>> list(batch_dp)
[[3, 4, 5], [6, 7, 8], [0, 1, 2]]
Expand All @@ -85,7 +83,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
batch_num: int
bucket_num: int
sort_key: Optional[Callable]
in_batch_shuffle: bool
use_in_batch_shuffle: bool

def __new__(
cls,
Expand All @@ -95,7 +93,7 @@ def __new__(
batch_num: int = 100,
bucket_num: int = 1,
sort_key: Optional[Callable] = None,
in_batch_shuffle: bool = True,
use_in_batch_shuffle: bool = True,
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert batch_num > 0, "Number of batches is required to be larger than 0!"
Expand All @@ -106,7 +104,7 @@ def __new__(

# Shuffle by pool_size
if bucket_num > 1 or sort_key is None:
if in_batch_shuffle:
if use_in_batch_shuffle:
datapipe = datapipe.batch(batch_size=pool_size, drop_last=False).in_batch_shuffle().unbatch()
else:
datapipe = datapipe.shuffle(buffer_size=pool_size)
Expand All @@ -118,7 +116,7 @@ def __new__(
# Shuffle the batched data
if sort_key is not None:
# In-batch shuffle each bucket seems not that useful, it seems misleading since .batch is called prior.
if in_batch_shuffle:
if use_in_batch_shuffle:
datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).in_batch_shuffle().unbatch()
else:
datapipe = datapipe.shuffle(buffer_size=bucket_size)
Expand Down