Skip to content

Commit e6d4868

Browse files
committed
Allow variable number of repetitions for RA
1 parent c34a914 commit e6d4868

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

references/classification/sampler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
1515
https://github.com/facebookresearch/deit/blob/main/samplers.py
1616
"""
1717

18-
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
18+
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
1919
if num_replicas is None:
2020
if not dist.is_available():
2121
raise RuntimeError("Requires distributed package to be available!")
@@ -28,11 +28,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
2828
self.num_replicas = num_replicas
2929
self.rank = rank
3030
self.epoch = 0
31-
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
31+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
3232
self.total_size = self.num_samples * self.num_replicas
3333
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
3434
self.shuffle = shuffle
3535
self.seed = seed
36+
self.repetitions = repetitions
3637

3738
def __iter__(self):
3839
# Deterministically shuffle based on epoch
@@ -44,7 +45,7 @@ def __iter__(self):
4445
indices = list(range(len(self.dataset)))
4546

4647
# Add extra samples to make it evenly divisible
47-
indices = [ele for ele in indices for i in range(3)]
48+
indices = [ele for ele in indices for i in range(self.repetitions)]
4849
indices += indices[: (self.total_size - len(indices))]
4950
assert len(indices) == self.total_size
5051

references/classification/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def load_data(traindir, valdir, args):
174174
print("Creating data loaders")
175175
if args.distributed:
176176
if args.ra_sampler:
177-
train_sampler = RASampler(dataset, shuffle=True)
177+
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
178178
else:
179179
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
180180
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
@@ -485,7 +485,10 @@ def get_args_parser(add_help=True):
485485
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
486486
)
487487
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
488-
parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")
488+
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
489+
parser.add_argument(
490+
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
491+
)
489492

490493
# Prototype models only
491494
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

0 commit comments

Comments
 (0)