[NOMERGE] Review new Transforms API #5500

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions torchvision/prototype/features/_bounding_box.py
@@ -39,6 +39,7 @@ def __new__(
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# Do we still need this?
Collaborator:

Need, no. Want, yes. I actually want to keep them after this is promoted out of prototype. It is a convenient way to work with the features outside of a training pipeline. Compare for example

from torchvision.prototype import features, transforms

transform = transforms.ConvertBoundingBoxFormat(format="xyxy")

bounding_box = features.BoundingBox(...)
xyxy_bounding_box = transform(bounding_box)

or

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

bounding_box = features.BoundingBox(...)
xyxy_bounding_box = features.BoundingBox.new_like(
    bounding_box,
    F.convert_bounding_box_format(bounding_box, old_format=bounding_box.format, new_format="xyxy"),
    format="xyxy",
)

to

from torchvision.prototype import features

bounding_box = features.BoundingBox(...)
xyxy_bounding_box = bounding_box.to_format("xyxy")

Still, I agree, the lazy import needs to go. My proposal is to move the actual conversion code from transforms.functional to a private function like _convert_bounding_box_format in this module. We could use it here in to_format and simply expose it in transforms.functional.
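For illustration, a rough sketch of that refactor (the exact name and placement are assumptions, not an agreed design; it reuses the imports already present in these modules):

# torchvision/prototype/features/_bounding_box.py (sketch)
def _convert_bounding_box_format(
    data: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
) -> torch.Tensor:
    # the conversion logic currently living in transforms.functional moves here, e.g. xywh -> xyxy:
    if old_format == BoundingBoxFormat.XYWH and new_format == BoundingBoxFormat.XYXY:
        x, y, w, h = data.unbind(-1)
        return torch.stack([x, y, x + w, y + h], dim=-1)
    ...


class BoundingBox(_Feature):
    def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
        if isinstance(format, str):
            format = BoundingBoxFormat[format.upper()]
        data = _convert_bounding_box_format(self, old_format=self.format, new_format=format)
        return BoundingBox.new_like(self, data, format=format)


# torchvision/prototype/transforms/functional/_meta.py (sketch)
from torchvision.prototype.features._bounding_box import _convert_bounding_box_format as convert_bounding_box_format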


This comment applies to all other comments that ask the same.

Contributor Author:

I agree that the verbosity of option 2 is an indication that we need a utility. It doesn't necessarily mean that the method should be part of BoundingBox, so let's explore options. Note that we should avoid the use of strings in favour of their enums in our examples and internal code to get better static analysis.

This comment applies to all other comments that ask the same.

The two visualization methods show() and draw_bounding_box() on Image are a different case. We should remove them, as we don't necessarily want to push PIL on the new API.

Collaborator:

Note that we should avoid the use of strings in favour of their enums in our examples and internal code to get better static analysis.

Internally I agree 100%, but do we actually want to discourage users from using them in application code? "xyxy" is more concise than BoundingBoxFormat.XYXY and potentially needs one less import.

Contributor Author:

We have previously made the decision to go with enums (you can see this in multiple APIs, including io, multi-weights, etc.). For consistency we should continue this here.

Where we agree is that people should be allowed to pass strings if they want to. I want to discourage them from doing so by writing all the examples and internal code with enums, but if they choose to pass strings, so be it. I just don't want to promote this usage. Note that this is true for virtually all places where we have enums.

Collaborator (@pmeier, Mar 1, 2022):

As discussed offline, quietly supporting strings is still a good idea, but for everything else like documentation or examples we should use the enums.
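For illustration only, the "quiet" string support could be a small normalization helper at the API boundary (a sketch, not existing code):

from typing import Union

from torchvision.prototype.features import BoundingBoxFormat


def _parse_bounding_box_format(format: Union[str, BoundingBoxFormat]) -> BoundingBoxFormat:
    # accept "xyxy" as well as BoundingBoxFormat.XYXY, but document and promote only the enum
    if isinstance(format, str):
        format = BoundingBoxFormat[format.upper()]
    return format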


# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import convert_bounding_box_format
1 change: 1 addition & 0 deletions torchvision/prototype/features/_encoded.py
@@ -41,6 +41,7 @@ def image_size(self) -> Tuple[int, int]:
def decode(self) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# Do we still need this?

# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import decode_image_with_pil
5 changes: 5 additions & 0 deletions torchvision/prototype/features/_image.py
@@ -17,6 +17,9 @@ class ColorSpace(StrEnum):
OTHER = 0
GRAYSCALE = 1
RGB = 3
# On io, we currently support Alpha transparency for Grayscale and RGB. We also support palette for PNG.
# Our enum must know about these. We should also allow users to strip out the alpha transparency, which is
# currently not supported by any colour transform.
Comment on lines +20 to +22
Collaborator:

We also support palette for PNG.

Could you add a link for that? Otherwise, we are talking about this enum, correct?

class ImageReadMode(Enum):
"""
Support for various modes while reading images.
Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
``ImageReadMode.GRAY`` for converting to grayscale,
``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
RGB with transparency.
"""
UNCHANGED = 0
GRAY = 1
GRAY_ALPHA = 2
RGB = 3
RGB_ALPHA = 4

I'll send a PR to include RGBA and GRAYSCALE_ALPHA in the enum and also provide conversion functions for them.
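For reference, the extended enum could look roughly like this (a sketch; the exact member names and values are not final):

class ColorSpace(StrEnum):
    OTHER = 0
    GRAYSCALE = 1
    GRAYSCALE_ALPHA = 2
    RGB = 3
    RGBA = 4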

Collaborator:

Looking into this, conversion from RGBA to RGB is more complicated: for a proper conversion we need to know the background color. We could assume white, which will probably work for most cases, but we should probably discuss this first.
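For illustration, alpha compositing over a fixed background would look roughly like this (a sketch; assumes float images in [0, 1], and the helper name is hypothetical):

import torch


def _rgba_to_rgb(image: torch.Tensor, background: float = 1.0) -> torch.Tensor:
    # image has shape (..., 4, H, W); background=1.0 corresponds to white
    rgb, alpha = image[..., :3, :, :], image[..., 3:, :, :]
    return rgb * alpha + background * (1.0 - alpha)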

Contributor Author (@datumbox, Mar 1, 2022):

Palettes are supported if the user reads with IMAGE_READ_MODE_UNCHANGED. The outcome will be an image with a single channel that has integer IDs for every colour. This is ideal for reading masks that have integer IDs for the classes.

Effectively, palettes look like grayscale images, but their values are IDs on the 0-255 scale. Given they are useful only for niche applications, I don't feel strongly about supporting them in this API.

Collaborator:

Paletted images can also be used to store segmentation masks, e.g. in VOC.



class Image(_Feature):
@@ -78,9 +81,11 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# Do we still need this?
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()

def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
# Do we still need this?
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
11 changes: 11 additions & 0 deletions torchvision/prototype/transforms/__init__.py
@@ -12,3 +12,14 @@
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._type_conversion import DecodeImage, LabelToOneHot

# What are the migration plans for classes that are not yet ported to the new API? There are two categories:
# 1. Deprecated methods which have equivalents on the new API (_legacy.py?):
# - Grayscale, RandomGrayscale: use ConvertImageColorSpace
Comment on lines +17 to +18
Collaborator:

👍

# 2. Those without equivalents on the new API:
# - Pad, RandomCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomPerspective, FiveCrop, TenCrop, ColorJitter,
# RandomRotation, RandomAffine, GaussianBlur, RandomInvert, RandomPosterize, RandomSolarize, RandomAdjustSharpness,
# RandomAutocontrast, RandomEqualize, LinearTransformation (must be added)
Comment on lines +20 to +22
Collaborator:

👍 For everything but FiveCrop and TenCrop. They are outliers in the sense that they take one input and produce multiple outputs. This kind of structure would be much better suited to a datapipe than to a transform.

This is also reinforced by the fact that, with the old API, these transforms can only be put at the end of a pipeline, given that their output type is not compatible with all the other transforms. Even then, using such a transform would most likely need a custom collate function.

I couldn't find any usage in our reference scripts. Does someone know how exactly they are used so I can propose something more concrete here?

Contributor Author (@datumbox, Mar 1, 2022):

The problem is that if we want to roll out the API with BC support, we must include implementations for all previous methods and offer alternatives for those that will be placed in a deprecated state. Though I understand the challenge, excluding these methods is not OK, and we should flag if others like these exist.

The specific transforms are used to produce multiple crops for inference in classification. It is a common trick (especially from the early days of the ImageNet competition) to boost the accuracy of the model by averaging the responses over multiple crops. You can see an example of how this is used here.
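For context, the usual inference pattern with the stable API looks roughly like this (a sketch; `model` and the PIL `image` are assumed to exist):

import torch
from torchvision import transforms
from torchvision.transforms import functional as TF

crops = transforms.TenCrop(224)(image)                       # tuple of 10 PIL crops
batch = torch.stack([TF.to_tensor(crop) for crop in crops])  # (10, C, 224, 224)
with torch.no_grad():
    logits = model(batch)                                    # (10, num_classes)
prediction = logits.mean(dim=0).argmax()                     # average the responses over the crops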

# - PILToTensor, ToPILImage (_legacy.py?)
Collaborator:

Shouldn't we also have a more general ConvertImageType transform here that takes one of "pil", "vanilla_tensor", "feature" as input? That would follow the same approach that we have for the other conversions.

Contributor Author:

Here is what I mean. Let's say we are close to releasing the prototype and we have achieved a significant level of BC. On day 0, we will need a PILToTensor to exist. These are likely to be in a deprecated state, with warnings informing the user that this functionality will be removed and that they should use X instead. Your proposal seems like a valid X alternative.

So here I think we need to figure out two things:

  • What we do with these legacy classes: provide their implementations on the new API (or aliases) and make sure that they continue to work along with the other, non-deprecated classes in existing pipelines.
  • Provide an alternative class that offers the new functionality in the new API.
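Concretely, a _legacy.py shim for the first point could look roughly like this (a sketch; the class body and the warning text are illustrative only, not an agreed design):

import warnings
from typing import Any, Dict

import PIL.Image
from torchvision.prototype.transforms import Transform
from torchvision.transforms.functional import pil_to_tensor


class PILToTensor(Transform):
    def __init__(self) -> None:
        super().__init__()
        warnings.warn("PILToTensor is deprecated; use the new conversion transform instead.", DeprecationWarning)

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        # keep working inside existing pipelines, but only touch PIL inputs
        return pil_to_tensor(input) if isinstance(input, PIL.Image.Image) else input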

# - ToTensor (deprecate vfdev-5?)
# We need a plan for both categories implemented on the new API.
10 changes: 10 additions & 0 deletions torchvision/prototype/transforms/_augment.py
@@ -34,6 +34,10 @@ def __init__(
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
# The above approach is very composable but will lead to verbose code. Instead, we could create a base class
# that inherits from Transform (say RandomTransform), receives `p` in its constructor and by default
# implements the `p` random check in forward. This is likely to happen during the final clean-ups, so perhaps
# update the comment accordingly OR create an issue to track this discussion.
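For illustration, such a base class could look roughly like this (a sketch; the name and the exact forward hook of the prototype Transform are assumptions):

from typing import Any

import torch
from torchvision.prototype.transforms import Transform


class RandomTransform(Transform):
    def __init__(self, p: float = 0.5) -> None:
        if not (0.0 <= p <= 1.0):
            raise ValueError("Random transform probability should be between 0 and 1")
        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        if float(torch.rand(())) >= self.p:
            # skip the transform: return the input unchanged
            return inputs if len(inputs) > 1 else inputs[0]
        return super().forward(*inputs)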
self.p = p
self.scale = scale
self.ratio = ratio
@@ -84,6 +88,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, image
# FYI: Now that we are not JIT-scriptable, I can probably avoid copy-pasting the image onto itself in this
# scenario. Perhaps a simple clone would do.
Comment on lines +91 to +92
Collaborator:

Yeah, there is a major cleanup coming up. A lot of the annotations on the transforms are either wrong or too strict, added to appease torchscript. We should revisit them before we promote this out of prototype. There is also a user request for this: #5398

Contributor Author:

Sounds good. There is a non-trivial number of clean-ups we need to do prior to landing this API in the main area. Please create an issue that describes these clean-ups (see this example). I would also recommend creating a Project to keep track of all the related tickets (see this example).


return dict(zip("ijhwv", (i, j, h, w, v)))

@@ -95,6 +101,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.erase_image_tensor(input, **params)
# FYI: we plan to add support for PIL, as part of Batteries Included
else:
return input

@@ -112,6 +119,8 @@ def __init__(self, *, alpha: float) -> None:
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

def _get_params(self, sample: Any) -> Dict[str, Any]:
# Question: Shall we enforce here the existence of Labels in the sample? If yes, this method of validating the
# input won't work if get_params() becomes public and the user is able to provide params in forward.
return dict(lam=float(self._dist.sample(())))

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
@@ -134,6 +143,7 @@ def __init__(self, *, alpha: float) -> None:
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

def _get_params(self, sample: Any) -> Dict[str, Any]:
# Question: Same as above for Labels.
lam = float(self._dist.sample(()))

image = query_image(sample)
11 changes: 11 additions & 0 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -35,6 +35,9 @@ def __init__(
self.fill = fill

def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
# Question: This looks like a utility method that doesn't depend on self and could be moved out of this
# class to simplify its structure (for someone who tries to understand its API in order to implement a new method).
# Thoughts?
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
@@ -46,6 +49,9 @@ def _check_unsupported(self, input: Any) -> None:
def _extract_image(
self, sample: Any
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
# How about providing the unsupported types as a parameter to this method to avoid hardcoding them? This would
# allow us to remove the separate _check_unsupported method. Also, this method could be moved out of the
# class to simplify its structure.
def fn(
id: Tuple[Any, ...], input: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
@@ -67,6 +73,7 @@ def fn(
def _parse_fill(
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Optional[List[float]]:
# Question: How do you feel about turning this into a util as well?
fill = self.fill

if isinstance(image, PIL.Image.Image) or fill is None:
@@ -204,6 +211,8 @@ class AutoAugment(_AutoAugmentBase):
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
# How do you feel about explicitly passing height, width here instead of image_size? I did it in an earlier PR
# but removed it due to the verbosity. Thoughts?
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -222,6 +231,8 @@ class AutoAugment(_AutoAugmentBase):
}

def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None:
# I think using kwargs in the constructor makes it hard for users to understand which parameters to provide.
# I recommend explicitly passing interpolation and fill in all AA constructors.
super().__init__(**kwargs)
self.policy = policy
self._policies = self._get_policies(policy)
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/_container.py
@@ -38,6 +38,7 @@ def extra_repr(self) -> str:

class RandomChoice(Transform):
def __init__(self, *transforms: Transform) -> None:
# This method should optionally receive a list of probabilities and sample the transforms proportionally.
super().__init__()
self.transforms = transforms
for idx, transform in enumerate(transforms):
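For the probabilities suggestion on RandomChoice above, a minimal sketch (the parameter name and the weighting behaviour are assumptions, not an agreed design):

from typing import Any, List, Optional

import torch
from torchvision.prototype.transforms import Transform


class RandomChoice(Transform):
    def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
        super().__init__()
        self.transforms = transforms
        if probabilities is None:
            probabilities = [1.0] * len(transforms)
        self._weights = torch.tensor(probabilities, dtype=torch.float64) / sum(probabilities)

    def forward(self, *inputs: Any) -> Any:
        # sample one transform proportionally to its probability and apply it
        idx = int(torch.multinomial(self._weights, 1))
        return self.transforms[idx](*inputs)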
8 changes: 6 additions & 2 deletions torchvision/prototype/transforms/_meta.py
@@ -4,7 +4,7 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from torchvision.transforms.functional import convert_image_dtype
from torchvision.transforms.functional import convert_image_dtype # We should have our own alias for this on the new API


class ConvertBoundingBoxFormat(Transform):
@@ -23,6 +23,10 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:


class ConvertImageDtype(Transform):
# Question: Why do we have both this and a ToDtype Transform? Is this due to BC? Ideally we could move people away
# from methods that do an implicit normalization of the values (like this one, or to_tensor). cc @vfdev-5
# If that's the case, we should move this to _legacy and add deprecation warnings from day one to push people to use
# the new methods.
Comment on lines +26 to +29
Collaborator:

One idea that @datumbox and I discussed a while back is to abolish the close correspondence between the dtype and the value range for images. Right now we assume the range [0, 1] for floating point images and [0, torch.iinfo(dtype).max] for integral images. A possible solution would be to have a value_range metadata field on the features.Image class. It can be set with the defaults from above, so for a regular use case the user won't feel any difference.

For example, let's say we have a uint8 image

import torch
from torchvision.prototype import features

image = features.Image(torch.randint(0, 256, (3, 128, 128)))

By default, print(image.value_range) would print (0, 255). If we now convert the dtype with image = image.to(torch.float32), printing the value range again would give (0.0, 255.0) rather than the (0.0, 1.0) that ConvertImageDtype produces today.

This means all transformations that implicitly rely on the value range, like Normalize, need to be adapted to use the explicit value instead, but we could get rid of ConvertImageDtype altogether.
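A rough sketch of what "use the explicit value instead" could mean for something like Normalize (value_range is hypothetical metadata, not an existing attribute):

import torch


def normalize_with_range(image: torch.Tensor, value_range, mean, std):
    low, high = value_range                          # e.g. (0.0, 255.0), preserved across .to(torch.float32)
    scaled = (image.float() - low) / (high - low)    # rescale explicitly instead of assuming [0, 1] per dtype
    mean = torch.as_tensor(mean, dtype=scaled.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=scaled.dtype).view(-1, 1, 1)
    return (scaled - mean) / std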

Contributor Author:

I absolutely love this and 100% agree we should do it. I've previously asked @vfdev-5 to create an issue and see which of the low-level kernel methods should be expanded to receive a max_value. We can write this in a BC manner on the low-level kernels: if max_value is not passed, we keep the previous behaviour. Tensors will continue making this assumption, but Images shouldn't.

I feel this is a critical part of the proposal that needs to be investigated prior to implementing all the other transforms. This is probably one of the few remaining uninvestigated bits of the proposal.

Contributor Author:

FYI #5502

def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
@@ -59,7 +63,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return features.Image.new_like(input, output, color_space=self.color_space)
elif isinstance(input, torch.Tensor):
if self.old_color_space is None:
raise RuntimeError("")
raise RuntimeError("") # Add better exception message

return F.convert_image_color_space_tensor(
input, old_color_space=self.old_color_space, new_color_space=self.color_space
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/_transform.py
@@ -12,6 +12,10 @@ def __init__(self) -> None:
super().__init__()
_log_api_usage_once(self)

# FYI: I spoke with @NicolasHug offline and provided a valid use-case where it would be useful to make this public,
# as you originally intended. We don't have to do this now, but if we do end up exposing this publicly, we should
# rename it to avoid conflicts with the previous static get_params(), which worked completely differently.
# We don't have to do anything right away, but we should discuss this again after the feedback.
def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict()

16 changes: 16 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -1,3 +1,9 @@
# We should create an issue that lists the steps that need to be performed for rolling out the API to main TorchVision.
# I created a similar one for the models, see here: https://github.com/pytorch/vision/issues/4679
# One of the key things we would need to take care of is that all the kernels below will need logging. This is because
# there will be no high-level kernel (like `F` on main) and we would instead need to put tracking directly in each
# low-level kernel, which will now be public (currently functional_pil|tensor are private).
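For illustration, each public low-level kernel would then log itself directly, roughly like this (a sketch; horizontal_flip_image_tensor stands in for any kernel):

import torch
from torchvision.utils import _log_api_usage_once


def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    _log_api_usage_once(horizontal_flip_image_tensor)  # tracking lives in the kernel, no high-level `F` wrapper
    return image.flip(-1)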

from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
convert_bounding_box_format,
@@ -63,3 +69,13 @@
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot

# What are the migration plans for public methods that are not yet ported to the new API? There are two categories:
Contributor Author:

@pmeier Reminder for the migration plan on these as well. It might already be on your radar as part of the Transform classes work.

# 1. Deprecated methods which have equivalents on the new API (_legacy.py?):
# - get_image_size, get_image_num_channels: use get_dimensions_image_tensor|pil
# - to_grayscale, rgb_to_grayscale: use convert_image_color_space_tensor|pil
# 2. Those without equivalents on the new API:
# - five_crop, ten_crop (must be added)
# - pil_to_tensor, to_pil_image (_legacy.py?)
# - to_tensor() (deprecate vfdev-5?)
# We need a plan for both categories implemented on the new API.
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -66,6 +66,7 @@ def resize_segmentation_mask(


# TODO: handle max_size
# Where is the issue for this TODO?
def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
old_height, old_width = image_size
new_height, new_width = size
@@ -220,6 +221,8 @@ def perspective_image_tensor(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> torch.Tensor:
# We should go ahead and update the _FT API to accept InterpolationMode. It's currently considered private and
# there are no BC guarantees. This would remove the need to call `.value` in many of these places.
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)


@@ -229,6 +232,7 @@ def perspective_image_pil(
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
fill: Optional[List[float]] = None,
) -> PIL.Image.Image:
# Same thing here. We should move the pil_modes_mapping conversion to the `_FP` side.
return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)


10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/functional/_meta.py
@@ -58,12 +58,13 @@ def convert_bounding_box_format(


def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
return grayscale.expand(3, 1, 1)
return grayscale.expand(3, 1, 1) # This approach assumes a single image, not a batch


def convert_image_color_space_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
# the new color_space should only be RGB or Grayscale
if new_color_space == old_color_space:
return image.clone()

@@ -73,6 +74,12 @@
if new_color_space == ColorSpace.GRAYSCALE:
image = _FT.rgb_to_grayscale(image)

# we need a way to strip off alpha transparencies:
# - RGBA => RGB
# - RGBA => Grayscale
# - Gray-Alpha => RGB
# - Gray-Alpha => Gray

return image


@@ -83,6 +90,7 @@ def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
def convert_image_color_space_pil(
image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> PIL.Image.Image:
# the new color_space should only be RGB or Grayscale
if new_color_space == old_color_space:
return image.copy()

1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -39,4 +39,5 @@ def gaussian_blur_image_tensor(


def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image:
# It would really help to remove to_tensor from here, @vfdev-5
return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma))