diff --git a/test/test_transforms.py b/test/test_transforms.py index 392978d988b..0b71bae788b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1991,6 +1991,9 @@ def test_random_erasing(self): p_value = stats.binom_test(count_bigger_then_ones, trial, p=0.5) self.assertGreater(p_value, 0.0001) + # Checking if RandomErasing can be printed as string + t.__repr__() + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 916956e29fd..4eb0ab23c92 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1630,6 +1630,14 @@ def forward(self, img): return F.erase(img, x, y, h, w, v, self.inplace) return img + def __repr__(self): + s = '(p={}, '.format(self.p) + s += 'scale={}, '.format(self.scale) + s += 'ratio={}, '.format(self.ratio) + s += 'value={}, '.format(self.value) + s += 'inplace={})'.format(self.inplace) + return self.__class__.__name__ + s + class GaussianBlur(torch.nn.Module): """Blurs image with randomly chosen Gaussian blur.