From 936c951d2f6e26c09818368262762f93bf06c919 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Wed, 16 Mar 2022 15:35:31 -0400 Subject: [PATCH] Rename BucketBatcher argument to avoid name collision [ghstack-poisoned] --- test/test_datapipe.py | 12 ++++++------ .../datapipes/iter/transform/bucketbatcher.py | 16 +++++++--------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index b16c3de48..69c96929f 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -546,13 +546,13 @@ def test_bucket_batcher_iterdatapipe(self) -> None: # Functional Test: drop last reduces length batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=1, in_batch_shuffle=True + batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True ) self.assertEqual(9, len(list(batch_dp.unbatch()))) # Functional Test: drop last is False preserves length batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=False, batch_num=100, bucket_num=1, in_batch_shuffle=False + batch_size=3, drop_last=False, batch_num=100, bucket_num=1, use_in_batch_shuffle=False ) self.assertEqual(10, len(list(batch_dp.unbatch()))) @@ -561,15 +561,15 @@ def _return_self(x): # Functional Test: using sort_key, with in_batch_shuffle batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=1, in_batch_shuffle=True, sort_key=_return_self + batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True, sort_key=_return_self ) # bucket_num = 1 means there will be no shuffling if a sort key is given self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(batch_dp)) self.assertEqual(9, len(list(batch_dp.unbatch()))) - # Functional Test: using sort_key, without in_batch_shuffle + # Functional Test: using sort_key, without use_in_batch_shuffle batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=2, in_batch_shuffle=False, sort_key=_return_self + batch_size=3, drop_last=True, batch_num=100, bucket_num=2, use_in_batch_shuffle=False, sort_key=_return_self ) self.assertEqual(9, len(list(batch_dp.unbatch()))) @@ -580,7 +580,7 @@ def _return_self(x): drop_last=True, batch_num=100, bucket_num=2, - in_batch_shuffle=False, + use_in_batch_shuffle=False, sort_key=_return_self, ) n_elements_before_reset = 2 diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py index e2e04b002..3714d0977 100644 --- a/torchdata/datapipes/iter/transform/bucketbatcher.py +++ b/torchdata/datapipes/iter/transform/bucketbatcher.py @@ -5,9 +5,7 @@ from functools import partial from typing import Callable, Iterator, List, Optional, TypeVar -from torch.utils.data import DataChunk - -from torchdata.datapipes import functional_datapipe +from torchdata.datapipes import DataChunk, functional_datapipe from torchdata.datapipes.iter import IterDataPipe T_co = TypeVar("T_co", covariant=True) @@ -62,7 +60,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]): batch_num: Number of batches within a bucket (i.e. `bucket_size = batch_size * batch_num`) bucket_num: Number of buckets to consist a pool for shuffling (i.e. `pool_size = bucket_size * bucket_num`) sort_key: Callable to sort a bucket (list) - in_batch_shuffle: iF True, do in-batch shuffle; if False, buffer shuffle + use_in_batch_shuffle: if True, do in-batch shuffle; if False, buffer shuffle Example: >>> from torchdata.datapipes.iter import IterableWrapper @@ -74,7 +72,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]): >>> return sorted(bucket) >>> batch_dp = source_dp.bucketbatch( >>> batch_size=3, drop_last=True, batch_num=100, - >>> bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket + >>> bucket_num=1, use_in_batch_shuffle=False, sort_key=sort_bucket >>> ) >>> list(batch_dp) [[3, 4, 5], [6, 7, 8], [0, 1, 2]] @@ -85,7 +83,7 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]): batch_num: int bucket_num: int sort_key: Optional[Callable] - in_batch_shuffle: bool + use_in_batch_shuffle: bool def __new__( cls, @@ -95,7 +93,7 @@ def __new__( batch_num: int = 100, bucket_num: int = 1, sort_key: Optional[Callable] = None, - in_batch_shuffle: bool = True, + use_in_batch_shuffle: bool = True, ): assert batch_size > 0, "Batch size is required to be larger than 0!" assert batch_num > 0, "Number of batches is required to be larger than 0!" @@ -106,7 +104,7 @@ def __new__( # Shuffle by pool_size if bucket_num > 1 or sort_key is None: - if in_batch_shuffle: + if use_in_batch_shuffle: datapipe = datapipe.batch(batch_size=pool_size, drop_last=False).in_batch_shuffle().unbatch() else: datapipe = datapipe.shuffle(buffer_size=pool_size) @@ -118,7 +116,7 @@ def __new__( # Shuffle the batched data if sort_key is not None: # In-batch shuffle each bucket seems not that useful, it seems misleading since .batch is called prior. - if in_batch_shuffle: + if use_in_batch_shuffle: datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).in_batch_shuffle().unbatch() else: datapipe = datapipe.shuffle(buffer_size=bucket_size)