"""
===================================
How to write your own v2 transforms
===================================

This guide explains how to write transforms that are compatible with the
torchvision transforms V2 API.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2


# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example, if
# you're just doing image classification, your transform will typically accept a
# single image as input, or an ``(img, label)`` input. So you can simply hard-code
# your ``forward`` method to accept just that, e.g.:
#
# .. code:: python
#
#     class MyCustomTransform(torch.nn.Module):
#         def forward(self, img, label):
#             # Do some transformations
#             return new_img, new_label
#
# .. note::
#
#     This means that if you have a custom transform that is already compatible
#     with the V1 transforms (those in ``torchvision.transforms``), it will
#     still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
#     While working with datapoint classes in your code, make sure to
#     familiarize yourself with this section:
#     :ref:`datapoint_unwrapping_behaviour`
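
# %%
# As a quick illustration of the unwrapping behaviour that section describes
# (this is only a sketch, see the linked section for the full rules): most
# operations on a datapoint return a plain tensor, so inside your transform you
# may need to re-wrap the result into the appropriate datapoint class yourself.

shifted = bboxes + 3  # shift all box coordinates by 3 pixels
print(type(shifted))  # expected to be a plain torch.Tensor, not BoundingBoxes

# %%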
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the built-in Torchvision V2 transforms is that they can accept
# arbitrary input structures and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses), plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input. A minimal sketch of this pattern
# is shown below.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
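
# %%
# To make that flatten / transform / unflatten pattern more concrete, here is a
# minimal sketch. It relies on the private ``torch.utils._pytree`` utilities and
# on a made-up ``MyFlexibleTransform`` class, and it skips the extra dispatch
# logic (e.g. for plain tensors and PIL images) that the built-in transforms
# implement. Treat it as an illustration, not as the actual implementation.

import torch.utils._pytree as pytree


class MyFlexibleTransform(torch.nn.Module):
    # Only entries of these classes are transformed; everything else is passed
    # through untouched.
    _transformed_types = (datapoints.Image, datapoints.BoundingBoxes)

    def forward(self, *inputs):
        # Accept both transform(sample) and transform(img, bboxes, ...)
        sample = inputs if len(inputs) > 1 else inputs[0]

        # Unpack the arbitrarily-nested input into a flat list of entries
        flat_inputs, spec = pytree.tree_flatten(sample)

        # Transform only the entries whose class we recognize
        flat_outputs = [
            self._transform(inpt) if isinstance(inpt, self._transformed_types) else inpt
            for inpt in flat_inputs
        ]

        # Repack the (potentially transformed) entries into the same structure
        # as the input
        return pytree.tree_unflatten(flat_outputs, spec)

    def _transform(self, inpt):
        # This is where the actual transformation logic would go; here we just
        # pass the entry through
        return inpt


# The bounding boxes in ``structured_input`` are recognized and "transformed",
# while the rest of the structure is preserved as-is.
out = MyFlexibleTransform()(structured_input)
assert isinstance(out, dict)
assert out["something_that_will_be_ignored"] == (1, "hello")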