diff --git a/references/classification/sampler.py b/references/classification/sampler.py index 3c5e8b014b1..e9dc1735a58 100644 --- a/references/classification/sampler.py +++ b/references/classification/sampler.py @@ -36,10 +36,10 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, self.repetitions = repetitions def __iter__(self): - # Deterministically shuffle based on epoch - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) if self.shuffle: + # Deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = list(range(len(self.dataset)))