[NOMERGE] Transforms V2 API overview #6486

1 change: 1 addition & 0 deletions torchvision/prototype/transforms/_augment.py
@@ -247,6 +247,7 @@ def _copy_paste(
xyxy_boxes = masks_to_boxes(masks)
# TODO: masks_to_boxes produces bboxes with x2y2 inclusive but x2y2 should be exclusive
# we need to add +1 to x2y2. We need to investigate that.
# datumbox: I had a look at other reference implementations and I see a similar +1 to make it exclusive. I think we can keep it.
xyxy_boxes[:, 2:] += 1
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
5 changes: 5 additions & 0 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -316,6 +316,11 @@ def __init__(
self.num_magnitude_bins = num_magnitude_bins

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: This idiom doesn't work here because we make multiple random() calls within the transform,
# for example `_get_random_item()` or `torch.rand(())`.
# We should go back to overriding `forward()`, reverting the majority of the refactoring done in
# https://github.com/pytorch/vision/pull/6463
# This applies to all AutoAugment methods here, not just RandAugment.
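# For illustration only (a sketch of the issue, not a proposed implementation): `_transform` runs once per
# input in the sample, so a draw like
#
#     magnitude_idx = int(torch.randint(self.num_magnitude_bins, ()))
#
# inside `_transform` produces a fresh value for every image in the same sample, whereas `forward()` runs
# once per sample and can make all random draws up front.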
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt

4 changes: 3 additions & 1 deletion torchvision/prototype/transforms/_container.py
@@ -33,6 +33,7 @@ def extra_repr(self) -> str:


class RandomChoice(Transform):
# datumbox: Shouldn't this be `transforms: Sequence[Callable]` for BC, similar to what we did for `Compose`?
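# A rough sketch of the BC-compatible signature, mirroring `Compose` (illustrative only):
#
#     def __init__(self, transforms: Sequence[Callable], probabilities: Optional[List[float]] = None) -> None:
#         if not isinstance(transforms, Sequence):
#             raise TypeError("Argument transforms should be a sequence of callables")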
def __init__(self, *transforms: Transform, probabilities: Optional[List[float]] = None) -> None:
if probabilities is None:
probabilities = [1] * len(transforms)
@@ -46,7 +47,7 @@ def __init__(self, *transforms: Transform, probabilities: Optional[List[float]]

self.transforms = transforms
for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform)
self.add_module(str(idx), transform) # datumbox: we probably can't use add_module here. See https://github.com/pytorch/vision/pull/6391

total = sum(probabilities)
self.probabilities = [p / total for p in probabilities]
@@ -58,6 +59,7 @@ def forward(self, *inputs: Any) -> Any:


class RandomOrder(Transform):
# datumbox: same as above here
def __init__(self, *transforms: Transform) -> None:
super().__init__()
self.transforms = transforms
8 changes: 8 additions & 0 deletions torchvision/prototype/transforms/_deprecated.py
@@ -28,6 +28,9 @@ def __init__(self) -> None:
super().__init__()

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: why check here again for types? Can't we just define
# `_transformed_types = (PIL.Image.Image, np.ndarray)`
# and then eliminate this if/else?
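# i.e. roughly (a sketch, assuming the base `Transform` only calls `_transform` for `_transformed_types`):
#
#     _transformed_types = (PIL.Image.Image, np.ndarray)
#
#     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
#         return _F.to_tensor(inpt)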
if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(inpt)
else:
@@ -43,6 +46,7 @@ def __init__(self) -> None:
super().__init__()

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: similar here?
if isinstance(inpt, PIL.Image.Image):
return _F.pil_to_tensor(inpt)
else:
@@ -63,6 +67,7 @@ def __init__(self, mode: Optional[str] = None) -> None:
self.mode = mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: similar here?
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
return _F.to_pil_image(inpt, mode=self.mode)
else:
@@ -115,5 +120,8 @@ def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: What's the reason for converting it to RGB only to then convert it to grayscale? Is it to ensure the
# input is RGB, relying on the no-op behaviour to avoid errors? I think this logic should change to do the
# conversion only when needed. That would lead to both clearer and more efficient code.
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
8 changes: 8 additions & 0 deletions torchvision/prototype/transforms/_geometry.py
@@ -92,6 +92,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
# vfdev-5: technically, this op can work on bbox/segmentation-mask-only inputs without an image in the sample
# What if we have multiple images/bboxes/masks of different sizes?
# TODO: let's support bbox or mask in samples without image
# datumbox: One issue with implementing the above proposal is that `query_chw()` is supposed to return channels,
# which bboxes and masks don't have. It could return None, as the majority of ops don't need it. Thoughts?
# The same applies to all similar TODOs in this file.
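# A minimal sketch of that option (illustrative only, nothing below is implemented):
#
#     def query_chw(sample: Any) -> Tuple[Optional[int], int, int]:
#         flat_sample, _ = tree_flatten(sample)
#         for i in flat_sample:
#             if isinstance(i, (features.Image, PIL.Image.Image)) or is_simple_tensor(i):
#                 return get_chw(i)
#             if isinstance(i, features.BoundingBox):
#                 return None, *i.image_size  # no channel information available
#         raise TypeError("No image or bounding box was found in the sample")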
_, height, width = query_chw(sample)
area = height * width

@@ -163,6 +166,8 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
# list here to align it with TenCrop.
# datumbox: The above TODO is no longer valid. Perhaps it should be adapted to say the opposite, i.e.
# shall we break BC and return a list to align with TenCrop?
if isinstance(inpt, features.Image):
output = F.five_crop_image_tensor(inpt, self.size)
return tuple(features.Image.new_like(inpt, o) for o in output)
@@ -476,6 +481,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: This is a prime candidate for speed optimization to avoid repeated pad calls. We should add a TODO here
if self.padding is not None:
inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)

@@ -804,6 +810,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

if needs_crop:
bounding_boxes = query_bounding_box(sample)
# datumbox: This transform should be able to work without bboxes. Hence if there are no bounding_boxes
# we should set `is_valid=None`.
bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
)
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/_meta.py
@@ -18,6 +18,9 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
self.format = format

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: We shouldn't have if/elses here. We could just set:
# `_transformed_types = (features.BoundingBox,)`
# and eliminate them.
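# i.e. roughly (a sketch; the pass-through for other types would then be handled by the base class):
#
#     _transformed_types = (features.BoundingBox,)
#
#     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
#         output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
#         return features.BoundingBox.new_like(inpt, output, format=params["format"])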
if isinstance(inpt, features.BoundingBox):
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
return features.BoundingBox.new_like(inpt, output, format=params["format"])
@@ -42,7 +45,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image) # datumbox: Should we be using `torch.Tensor` or `is_simple_tensor()`?

def __init__(
self,
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/_misc.py
@@ -18,6 +18,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type):
# datumbox: the use of types here is BC breaking. Perhaps we should ensure the default behaviour when types
# is not passed aligns with the old behaviour.
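# One possible sketch (illustrative only): when no types are passed, default to applying `fn` to every
# input, which would match the v1 behaviour, e.g.
#
#     self.types = types or (object,)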
super().__init__()
self.fn = fn
self.types = types
@@ -61,7 +63,8 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso
self.mean_vector = mean_vector

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

# datumbox: This feels like it's in the wrong place. Pass-throughs are usually declared via `_transformed_types`
# and unsupported types are handled in `forward()`. Should we refactor this to align with the rest of the code?
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
return inpt
elif isinstance(inpt, PIL.Image.Image):
@@ -92,11 +95,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

class Normalize(Transform):
def __init__(self, mean: List[float], std: List[float]):
# datumbox: The previous documentation says that these need to be sequences but here we have lists. Should
# we refactor this to ensure BC?
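# A sketch of a BC-friendlier signature (illustrative only):
#
#     def __init__(self, mean: Sequence[float], std: Sequence[float]):
#         super().__init__()
#         self.mean = list(mean)
#         self.std = list(std)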
super().__init__()
self.mean = mean
self.std = std

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: This works because the `F` kernel handles the passthrough. It's not the job of the kernel to do
# this. We should instead declare the supported types in `_transformed_types`.
return F.normalize(inpt, mean=self.mean, std=self.std)


@@ -131,6 +138,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ToDtype(Lambda):
# datumbox: if on Lambda we add a default behaviour for types, we might want to do the same here.
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
10 changes: 10 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
@@ -12,6 +12,9 @@

class DecodeImage(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: We shouldn't have if/elses here. We could just set:
# `_transformed_types = (features.EncodedImage,)`
# and eliminate them.
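# i.e. roughly (sketch):
#
#     _transformed_types = (features.EncodedImage,)
#
#     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
#         return features.Image(F.decode_image_with_pil(inpt))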
if isinstance(inpt, features.EncodedImage):
output = F.decode_image_with_pil(inpt)
return features.Image(output)
@@ -25,6 +28,9 @@ def __init__(self, num_categories: int = -1):
self.num_categories = num_categories

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: We shouldn't have if/elses here. We could just set:
# `_transformed_types = (features.Label,)`
# and eliminate them.
if isinstance(inpt, features.Label):
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
@@ -45,12 +51,14 @@ class ToImageTensor(Transform):

# Updated transformed types for ToImageTensor
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
# datumbox: I don't think Tensor and features._Feature should be here. This will lead to bboxes being converted to Tensors

def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: Similarly we should rely on `_transformed_types` and avoid if/else here
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
output = F.to_image_tensor(inpt, copy=self.copy)
return features.Image(output)
@@ -62,12 +70,14 @@ class ToImagePIL(Transform):

# Updated transformed types for ToImagePIL
_transformed_types = (torch.Tensor, features._Feature, PIL.Image.Image, np.ndarray)
# datumbox: same as above

def __init__(self, *, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# datumbox: same as above
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
return F.to_image_pil(inpt, mode=self.mode)
else:
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/_utils.py
@@ -13,7 +13,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
for i in flat_sample:
if isinstance(i, features.BoundingBox):
return i
return i # datumbox: should we be throwing an exception if we find more than one, similar to `query_chw()`? (see sketch below)

raise TypeError("No bounding box was found in the sample")

@@ -22,7 +22,7 @@ def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tupl
if isinstance(image, features.Image):
channels = image.num_channels
height, width = image.image_size
elif isinstance(image, torch.Tensor):
elif isinstance(image, torch.Tensor): # datumbox: should this have been `is_simple_tensor()` instead?
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, PIL.Image.Image):
channels, height, width = get_dimensions_image_pil(image)
@@ -68,5 +68,6 @@ def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -
# TODO: Given that this is not related to pytree / the Transform object, we should probably move it somewhere else.
# One possibility is `functional._utils` so both the functionals and the transforms have proper access to it. We could
# also move it to `features` since it literally checks for the _Feature type.
# datumbox: let's move it into the `features` module.
def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, features._Feature)
12 changes: 12 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -870,6 +870,16 @@ def elastic(


elastic_transform = elastic
# datumbox: we should decide on the naming of those that differ from the old API. Others like this include vflip,
# hflip, pil_to_tensor, to_pil_image. Some of them might not be worth renaming to avoid changing user code; for those
# that we do need to change, we should offer aliases that throw a deprecation warning. First we need to get a list of
# the functionals that were renamed.

# datumbox: Some kernels don't have mid-level dispatchers. Examples:
# get_dimensions, five_crop, ten_crop
# Also some others don't have their names exposed via the prototype functional at all. Here are some of them:
# to_tensor, get_image_size, get_image_num_channels, convert_image_dtype
# We should check which of them should be deprecated and which should be exposed normally via an alias
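# For the ones we do rename, a deprecation alias could look roughly like this (sketch; `vflip` /
# `vertical_flip` used as a hypothetical pair):
#
#     def vflip(*args: Any, **kwargs: Any) -> Any:
#         warnings.warn("`vflip` is deprecated and will be removed; use `vertical_flip` instead.")
#         return vertical_flip(*args, **kwargs)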


def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
@@ -1043,6 +1053,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
def five_crop_image_tensor(
img: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# datumbox: Add a TODO to consider breaking BC and returning a list to align with ten_crop
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_tensor(img)

@@ -1062,6 +1073,7 @@ def five_crop_image_pil(
def five_crop_image_pil(
img: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
# datumbox: same as above
crop_height, crop_width = _parse_five_crop_size(size)
_, image_height, image_width = get_dimensions_image_pil(img)

2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -64,6 +64,8 @@ def convert_bounding_box_format(
def clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
# datumbox: Perhaps we could speed up clamping if we had different implementations for each bbox format. Not sure
# whether they would yield equivalent results. We should add a TODO for a performance investigation.
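# For instance, the XYXY case could skip the conversion round-trip entirely (sketch, untested):
#
#     if format == BoundingBoxFormat.XYXY:
#         out = bounding_box.clone()
#         out[..., 0::2].clamp_(min=0, max=image_size[1])
#         out[..., 1::2].clamp_(min=0, max=image_size[0])
#         return out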
xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -16,6 +16,7 @@

def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
# datumbox: we should not be doing pass-throughs in the kernels. We should remove this `if`.
return inpt
elif isinstance(inpt, PIL.Image.Image):
raise TypeError("Unsupported input type")