diff --git a/data/base_dataset.py b/data/base_dataset.py
index 1bd57d0..1275d60 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -6,6 +6,12 @@
 import numpy as np
 import torch.utils.data as data
 from PIL import Image
+try:
+    from PIL.Image import Resampling
+    RESAMPLING_METHOD = Resampling.BICUBIC
+except ImportError:
+    from PIL.Image import BICUBIC
+    RESAMPLING_METHOD = BICUBIC
 import torchvision.transforms as transforms
 from abc import ABC, abstractmethod
 
@@ -95,8 +101,8 @@ def get_affine_mat(opt, size):
     affine_inv = np.linalg.inv(affine)
     return affine, affine_inv, flip
 
-def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
-    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+def apply_img_affine(img, affine_inv, method=RESAMPLING_METHOD):
+    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=RESAMPLING_METHOD)
 
 def apply_lm_affine(landmark, affine, flip, size):
     _, h = size
diff --git a/util/preprocess.py b/util/preprocess.py
index a6de6ea..c516f45 100644
--- a/util/preprocess.py
+++ b/util/preprocess.py
@@ -3,7 +3,14 @@
 
 import numpy as np
 from scipy.io import loadmat
-from PIL import Image
+
+try:
+    from PIL.Image import Resampling
+    RESAMPLING_METHOD = Resampling.BICUBIC
+except ImportError:
+    from PIL.Image import BICUBIC
+    RESAMPLING_METHOD = BICUBIC
+
 import cv2
 import os
 from skimage import transform as trans
@@ -142,11 +149,11 @@ def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
     up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
     below = up + target_size
 
-    img = img.resize((w, h), resample=Image.BICUBIC)
+    img = img.resize((w, h), resample=RESAMPLING_METHOD)
     img = img.crop((left, up, right, below))
 
     if mask is not None:
-        mask = mask.resize((w, h), resample=Image.BICUBIC)
+        mask = mask.resize((w, h), resample=RESAMPLING_METHOD)
         mask = mask.crop((left, up, right, below))
 
     lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
diff --git a/util/util.py b/util/util.py
index 0d689ca..0db5ec9 100644
--- a/util/util.py
+++ b/util/util.py
@@ -4,6 +4,12 @@
 import numpy as np
 import torch
 from PIL import Image
+try:
+    from PIL.Image import Resampling
+    RESAMPLING_METHOD = Resampling.BICUBIC
+except ImportError:
+    from PIL.Image import BICUBIC
+    RESAMPLING_METHOD = BICUBIC
 import os
 import importlib
 import argparse
@@ -107,9 +113,9 @@ def save_image(image_numpy, image_path, aspect_ratio=1.0):
     if aspect_ratio is None:
         pass
     elif aspect_ratio > 1.0:
-        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+        image_pil = image_pil.resize((h, int(w * aspect_ratio)), RESAMPLING_METHOD)
     elif aspect_ratio < 1.0:
-        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+        image_pil = image_pil.resize((int(h / aspect_ratio), w), RESAMPLING_METHOD)
     image_pil.save(image_path)
 
 
@@ -166,13 +172,13 @@ def correct_resize_label(t, size):
     return torch.stack(resized, dim=0).to(device)
 
 
-def correct_resize(t, size, mode=Image.BICUBIC):
+def correct_resize(t, size, mode=RESAMPLING_METHOD):
     device = t.device
     t = t.detach().cpu()
     resized = []
     for i in range(t.size(0)):
         one_t = t[i:i + 1]
-        one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
+        one_image = Image.fromarray(tensor2im(one_t)).resize(size, RESAMPLING_METHOD)
         resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
         resized.append(resized_t)
     return torch.stack(resized, dim=0).to(device)