From e6d48684341e58ea50f6231f715e3b5595d553e3 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 10 Dec 2021 09:31:06 +0100 Subject: [PATCH] Allow variable number of repetitions for RA --- references/classification/sampler.py | 7 ++++--- references/classification/train.py | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/references/classification/sampler.py b/references/classification/sampler.py index a55e25a16b1..3c5e8b014b1 100644 --- a/references/classification/sampler.py +++ b/references/classification/sampler.py @@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler): https://github.com/facebookresearch/deit/blob/main/samplers.py """ - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available!") @@ -28,11 +28,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): self.num_replicas = num_replicas self.rank = rank self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) + self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) self.shuffle = shuffle self.seed = seed + self.repetitions = repetitions def __iter__(self): # Deterministically shuffle based on epoch @@ -44,7 +45,7 @@ def __iter__(self): indices = list(range(len(self.dataset))) # Add extra samples to make it evenly divisible - indices = [ele for ele in indices for i in range(3)] + indices = [ele for ele in indices for i in range(self.repetitions)] indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size diff --git a/references/classification/train.py b/references/classification/train.py index 8a942b99a5f..507356c8048 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -174,7 +174,7 @@ def load_data(traindir, valdir, args): print("Creating data loaders") if args.distributed: if args.ra_sampler: - train_sampler = RASampler(dataset, shuffle=True) + train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) else: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) @@ -485,7 +485,10 @@ def get_args_parser(add_help=True): "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training") + parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training") + parser.add_argument( + "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" + ) # Prototype models only parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")