
Commit e250db3

Adding Repeated Augment Sampler (#5051)
* Adding repeated data-augmentation sampler
* rebase on top of latest main
* fix formatting
* rename file
* adding code source
1 parent 47ae092 commit e250db3

2 files changed: +66 -1 lines changed

references/classification/sampler.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Each sample is repeated 3 times, so every rank holds ~3x the usual share
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Per-rank epoch length: dataset size rounded down to a multiple of 256,
        # divided across the replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index 3 times, one per augmented copy
        indices = [ele for ele in indices for i in range(3)]
        # Add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: each rank takes every num_replicas-th index, so consecutive
        # repeats of a sample are spread across different processes
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        # Truncate so an epoch stays close to the usual per-rank length
        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
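To see the mechanics outside of the reference script, here is a minimal usage sketch; it is not part of the commit. The FakeData dataset, batch size, and the explicit num_replicas/rank values are illustrative stand-ins so the snippet runs without an initialized process group, and the import assumes you run from references/classification.

# Minimal usage sketch (not from the commit); FakeData and the explicit
# num_replicas/rank values are illustrative, so this runs without
# dist.init_process_group.
import torch
import torchvision
from sampler import RASampler  # the file added by this commit

dataset = torchvision.datasets.FakeData(size=1024, transform=torchvision.transforms.ToTensor())
sampler = RASampler(dataset, num_replicas=8, rank=0, shuffle=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    # As with DistributedSampler, set_epoch reseeds the shuffle so that all
    # ranks draw the same permutation before striding through it.
    sampler.set_epoch(epoch)
    for images, targets in loader:
        pass  # forward/backward as usual

With 1024 samples and 8 replicas, each rank holds ceil(1024 * 3 / 8) = 384 repeated indices but yields only floor(1024 // 256 * 256 / 8) = 128 per epoch, so the effective epoch length matches a plain DistributedSampler despite the threefold repetition.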

references/classification/train.py

Lines changed: 6 additions & 1 deletion
@@ -9,6 +9,7 @@
 import torchvision
 import transforms
 import utils
+from references.classification.sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
@@ -172,7 +173,10 @@ def load_data(traindir, valdir, args):

     print("Creating data loaders")
     if args.distributed:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        if args.ra_sampler:
+            train_sampler = RASampler(dataset, shuffle=True)
+        else:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
@@ -481,6 +485,7 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
