Replace get_image_size/num_channels with get_dimensions #5487
Merged · 10 commits · Feb 28, 2022
Changes from 5 commits
7 changes: 3 additions & 4 deletions torchvision/prototype/transforms/_augment.py
@@ -7,7 +7,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

- from ._utils import query_image
+ from ._utils import query_image, get_image_dimensions


class RandomErasing(Transform):
@@ -41,8 +41,7 @@ def __init__(

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
- img_c = F.get_image_num_channels(image)
- img_w, img_h = F.get_image_size(image)
+ img_c, img_h, img_w = get_image_dimensions(image)

if isinstance(self.value, (int, float)):
value = [self.value]
@@ -138,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))

image = query_image(sample)
- W, H = F.get_image_size(image)
+ _, H, W = get_image_dimensions(image)

r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
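Note (added for clarity, not part of the diff): the replaced helpers were width-first, since F.get_image_size returned (width, height), while get_image_dimensions is channels-first and returns (channels, height, width). Every call site therefore has to reorder its unpacking, as in the img_c, img_h, img_w line above. A minimal sketch of the difference using plain tensor shapes:

import torch

image = torch.rand(3, 32, 64)  # CHW tensor: 3 channels, height 32, width 64

# old helpers: img_w, img_h = get_image_size(image)              -> 64, 32
# new helper:  img_c, img_h, img_w = get_image_dimensions(image) -> 3, 32, 64
img_c, img_h, img_w = image.shape[-3], image.shape[-2], image.shape[-1]
assert (img_c, img_h, img_w) == (3, 32, 64)
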
24 changes: 12 additions & 12 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -7,7 +7,7 @@
from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
from torchvision.prototype.utils._internal import apply_recursively

- from ._utils import query_image
+ from ._utils import query_image, get_image_dimensions

K = TypeVar("K")
V = TypeVar("V")
@@ -47,7 +47,7 @@ def dispatch(
return input

image = query_image(sample)
- num_channels = F.get_image_num_channels(image)
+ num_channels, *_ = get_image_dimensions(image)

fill = self.fill
if isinstance(fill, (int, float)):
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
- image_size = F.get_image_size(image)
+ _, height, width = get_image_dimensions(image)

policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -288,7 +288,7 @@

magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

- magnitudes = magnitudes_fn(10, image_size)
+ magnitudes = magnitudes_fn(10, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
- image_size = F.get_image_size(image)
+ _, height, width = get_image_dimensions(image)

for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

- magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+ magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
- image_size = F.get_image_size(image)
+ _, height, width = get_image_dimensions(image)

transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

- magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+ magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
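Note (added for clarity, not part of the diff): the image_size passed to the magnitude lambdas used to be (width, height) and is now (height, width). That is why TranslateX, a horizontal shift scaled by image width, switches from image_size[0] to image_size[1], and TranslateY does the opposite. A small sketch of the new indexing:

import torch

height, width = 32, 64
image_size = (height, width)  # new height-first convention

# TranslateX magnitudes scale with the width, TranslateY with the height.
translate_x_bins = torch.linspace(0.0, 150.0 / 331.0 * image_size[1], 10)  # uses width = 64
translate_y_bins = torch.linspace(0.0, 150.0 / 331.0 * image_size[0], 10)  # uses height = 32
assert translate_x_bins[-1] > translate_y_bins[-1]
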
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -8,7 +8,7 @@
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

- from ._utils import query_image
+ from ._utils import query_image, get_image_dimensions


class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ def __init__(

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
- width, height = F.get_image_size(image)
+ _, height, width = get_image_dimensions(image)
area = height * width

log_ratio = torch.log(torch.tensor(self.ratio))
16 changes: 15 additions & 1 deletion torchvision/prototype/transforms/_utils.py
@@ -1,9 +1,10 @@
- from typing import Any, Optional, Union
+ from typing import Any, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.utils._internal import query_recursively
+ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP


def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
@@ -17,3 +18,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
return next(query_recursively(fn, sample))
except StopIteration:
raise TypeError("No image was found in the sample")


+ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+     if isinstance(image, features.Image):
+         channels = image.num_channels
+         height, width = image.image_size
+     elif isinstance(image, torch.Tensor):
+         channels, height, width = _FT.get_dimensions(image)
+     elif isinstance(image, PIL.Image.Image):
+         channels, height, width = _FP.get_dimensions(image)
+     else:
+         raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
+     return channels, height, width
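Note (added for clarity, not part of the diff): get_image_dimensions dispatches on the input type: a features.Image carries its metadata directly, while plain tensors and PIL images go through the new kernel-level get_dimensions helpers. A usage sketch, assuming a torchvision checkout that includes this branch:

import PIL.Image
import torch
from torchvision.prototype.transforms._utils import get_image_dimensions

tensor_image = torch.rand(3, 32, 64)        # CHW
pil_image = PIL.Image.new("RGB", (64, 32))  # PIL sizes are (width, height)

# Both paths normalize to (channels, height, width).
assert get_image_dimensions(tensor_image) == (3, 32, 64)
assert get_image_dimensions(pil_image) == (3, 32, 64)
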
1 change: 0 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,4 @@
from torchvision.transforms import InterpolationMode # usort: skip
- from ._utils import get_image_size, get_image_num_channels # usort: skip
from ._meta_conversion import (
convert_bounding_box_format,
convert_image_color_space_tensor,
18 changes: 8 additions & 10 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -5,7 +5,6 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import InterpolationMode
- from torchvision.prototype.transforms.functional import get_image_size
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix

@@ -40,8 +39,7 @@ def resize_image_tensor(
antialias: Optional[bool] = None,
) -> torch.Tensor:
new_height, new_width = size
- old_width, old_height = _FT.get_image_size(image)
- num_channels = _FT.get_image_num_channels(image)
+ num_channels, old_height, old_width = _FT.get_dimensions(image)
batch_shape = image.shape[:-3]
return _FT.resize(
image.reshape((-1, num_channels, old_height, old_width)),
@@ -143,7 +141,7 @@ def affine_image_tensor(

center_f = [0.0, 0.0]
if center is not None:
- width, height = get_image_size(img)
+ _, height, width = _FT.get_dimensions(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -169,7 +167,7 @@ def affine_image_pil(
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
- width, height = get_image_size(img)
+ _, height, width = _FP.get_dimensions(img)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -186,7 +184,7 @@ def rotate_image_tensor(
) -> torch.Tensor:
center_f = [0.0, 0.0]
if center is not None:
- width, height = get_image_size(img)
+ _, height, width = _FT.get_dimensions(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor(

def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
- image_width, image_height = get_image_size(img)
+ _, image_height, image_width = _FT.get_dimensions(img)

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_tensor(img, padding_ltrb, fill=0)

- image_width, image_height = get_image_size(img)
+ _, image_height, image_width = _FT.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img

@@ -278,13 +276,13 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:

def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
- image_width, image_height = get_image_size(img)
+ _, image_height, image_width = _FP.get_dimensions(img)

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_pil(img, padding_ltrb, fill=0)

- image_width, image_height = get_image_size(img)
+ _, image_height, image_width = _FP.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img

5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/functional/_misc.py
@@ -2,10 +2,13 @@

import PIL.Image
import torch
- from torchvision.transforms import functional_tensor as _FT
+ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import to_tensor, to_pil_image


+ get_dimensions_image_tensor = _FT.get_dimensions
+ get_dimensions_image_pil = _FP.get_dimensions

normalize_image_tensor = _FT.normalize


29 changes: 0 additions & 29 deletions torchvision/prototype/transforms/functional/_utils.py

This file was deleted.

11 changes: 10 additions & 1 deletion torchvision/transforms/functional_pil.py
@@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool:
return isinstance(img, Image.Image)


+ @torch.jit.unused
+ def get_dimensions(img: Any) -> List[int]:
+     if _is_pil_image(img):
+         channels = len(img.getbands())
+         width, height = img.size
+         return [channels, height, width]
+     raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
@@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]:
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
- return 1 if img.mode == "L" else 3
+ return len(img.getbands())
raise TypeError(f"Unexpected type {type(img)}")


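Note (added for clarity, not part of the diff): besides adding get_dimensions, this hunk fixes get_image_num_channels. The old 1 if img.mode == "L" else 3 reported 3 channels for any non-grayscale mode, which is wrong for RGBA (4 bands) and for single-band modes like I and F; len(img.getbands()) is exact. Runnable with Pillow alone:

from PIL import Image

# getbands() returns one entry per channel, whatever the mode.
assert len(Image.new("RGBA", (8, 8)).getbands()) == 4  # old logic would report 3
assert len(Image.new("L", (8, 8)).getbands()) == 1
assert len(Image.new("I", (8, 8)).getbands()) == 1     # old logic would report 3
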
7 changes: 7 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None:
raise TypeError("Threshold should be less than bound of img.")


+ def get_dimensions(img: Tensor) -> List[int]:
+     _assert_image_tensor(img)
+     channels = 1 if img.ndim == 2 else img.shape[-3]
+     height, width = img.shape[-2:]
+     return [channels, height, width]


def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
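Note (added for clarity, not part of the diff): the tensor kernel treats a 2D tensor as single-channel and reads only the trailing dimensions, so leading batch dimensions are ignored. A shape-arithmetic sketch that mirrors (rather than imports) the new function:

import torch

def dims(img: torch.Tensor):
    # Same logic as the new get_dimensions: channels default to 1 for HW input.
    channels = 1 if img.ndim == 2 else img.shape[-3]
    height, width = img.shape[-2:]
    return [channels, height, width]

assert dims(torch.rand(32, 64)) == [1, 32, 64]         # HW image
assert dims(torch.rand(3, 32, 64)) == [3, 32, 64]      # CHW image
assert dims(torch.rand(5, 3, 32, 64)) == [3, 32, 64]   # NCHW batch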