@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
15
15
https://github.com/facebookresearch/deit/blob/main/samplers.py
16
16
"""
17
17
18
- def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = True , seed = 0 ):
18
+ def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = True , seed = 0 , repetitions = 3 ):
19
19
if num_replicas is None :
20
20
if not dist .is_available ():
21
21
raise RuntimeError ("Requires distributed package to be available!" )
@@ -28,11 +28,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
28
28
self .num_replicas = num_replicas
29
29
self .rank = rank
30
30
self .epoch = 0
31
- self .num_samples = int (math .ceil (len (self .dataset ) * 3.0 / self .num_replicas ))
31
+ self .num_samples = int (math .ceil (len (self .dataset ) * float ( repetitions ) / self .num_replicas ))
32
32
self .total_size = self .num_samples * self .num_replicas
33
33
self .num_selected_samples = int (math .floor (len (self .dataset ) // 256 * 256 / self .num_replicas ))
34
34
self .shuffle = shuffle
35
35
self .seed = seed
36
+ self .repetitions = repetitions
36
37
37
38
def __iter__ (self ):
38
39
# Deterministically shuffle based on epoch
@@ -44,7 +45,7 @@ def __iter__(self):
44
45
indices = list (range (len (self .dataset )))
45
46
46
47
# Add extra samples to make it evenly divisible
47
- indices = [ele for ele in indices for i in range (3 )]
48
+ indices = [ele for ele in indices for i in range (self . repetitions )]
48
49
indices += indices [: (self .total_size - len (indices ))]
49
50
assert len (indices ) == self .total_size
50
51
0 commit comments