diff --git a/test/preprocess-bench.py b/test/preprocess-bench.py index 85599362a73..9acf9e71e54 100644 --- a/test/preprocess-bench.py +++ b/test/preprocess-bench.py @@ -6,6 +6,11 @@ import torch.utils.data import torchvision.transforms as transforms import torchvision.datasets as datasets +try: + import accimage + print("Using accimage.Image") +except ImportError: + print("Using PIL.Image") parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 5eb3126ae96..9fb528f1614 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,6 +1,10 @@ import torch.utils.data as data from PIL import Image +try: + import accimage +except ImportError: + accimage = None import os import os.path @@ -47,7 +51,14 @@ def __init__(self, root, transform=None, target_transform=None): def __getitem__(self, index): path, target = self.imgs[index] - img = Image.open(os.path.join(self.root, path)).convert('RGB') + if accimage is None: + img = Image.open(os.path.join(self.root, path)).convert('RGB') + else: + try: + img = accimage.Image(os.path.join(self.root, path)) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + img = Image.open(os.path.join(self.root, path)).convert('RGB') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 48be812569b..9b42f30bb1d 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -3,6 +3,10 @@ import math import random from PIL import Image, ImageOps +try: + import accimage +except ImportError: + accimage = None import numpy as np import numbers import types @@ -28,7 +32,11 @@ class ToTensor(object): """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ def __call__(self, pic): - if isinstance(pic, np.ndarray): + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.empty([pic.channels, pic.height, pic.width]) + pic.copyto(nppic) + img = torch.from_numpy(nppic) + elif isinstance(pic, np.ndarray): # handle numpy array img = torch.from_numpy(pic) else: