
Commit 69e4261

Merge branch 'main' into turboooo
2 parents a0680da + 641fdd9 commit 69e4261

40 files changed: +1330 −1129 lines

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
     used within the autoclass directive.
     """
 
-    if obj.__name__.endswith(("_Weights", "_QuantizedWeights")):
+    if getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")):
 
         if len(obj) == 0:
             lines[:] = ["There are no available pre-trained weights."]
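The new ``getattr`` guard avoids an AttributeError when Sphinx hands this hook an object that has no ``__name__`` attribute. A minimal illustration (not part of the commit; the class name below is made up):

    class NotADocTarget:  # hypothetical object without a __name__ attribute
        pass

    obj = NotADocTarget()
    # obj.__name__.endswith(("_Weights", "_QuantizedWeights"))  # would raise AttributeError
    print(getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")))  # False, no error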

docs/source/datapoints.rst

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     BoundingBoxFormat
     BoundingBoxes
     Mask
+    Datapoint

docs/source/transforms.rst

Lines changed: 12 additions & 0 deletions
@@ -155,6 +155,7 @@ Color
 
     ColorJitter
     v2.ColorJitter
+    v2.RandomChannelPermutation
     v2.RandomPhotometricDistort
     Grayscale
     v2.Grayscale
@@ -375,3 +376,14 @@ you can use a functional transform to build transform classes with custom behavi
     to_pil_image
     to_tensor
     vflip
+
+Developer tools
+---------------
+
+.. currentmodule:: torchvision.transforms.v2.functional
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    register_kernel

gallery/plot_custom_datapoints.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
"""
=====================================
How to write your own Datapoint class
=====================================

This guide is intended for downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
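# %%
# As a rough, purely illustrative sketch (this is *not* the actual
# ``BoundingBoxes`` code, and the class and attribute names below are made up),
# a datapoint that carries extra meta-data could follow the same
# ``__new__``-override pattern:


class TimedImage(datapoints.Datapoint):
    # Hypothetical datapoint that remembers when the image was captured.
    timestamp: float

    def __new__(cls, data, *, timestamp=0.0, dtype=None, device=None):
        tensor = torch.as_tensor(data, dtype=dtype, device=device)
        timed = tensor.as_subclass(cls)  # re-type the tensor as our datapoint
        timed.timestamp = timestamp
        return timed


timed = TimedImage(torch.rand(3, 32, 32), timestamp=1234.5)
print(timed.timestamp)

# %%
# Back to the very simple class announced above:
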
class MyDatapoint(datapoints.Datapoint):
    pass


my_dp = MyDatapoint([1, 2, 3])
my_dp

# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)


# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain them below in :ref:`param_forwarding`.
#
# .. note::
#
#     In our call to ``register_kernel`` above we used a string
#     ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
#     could also have used the functional *itself*, i.e.
#     ``@register_kernel(dispatcher=F.hflip, ...)``.
#
# The functionals that you can hook into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it
# relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
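# %%
# The same recipe works for any other functional. As a quick, purely
# illustrative sketch (not part of the original example), here is a kernel for
# the "vertical flip" operation, hooked into
# :func:`~torchvision.transforms.v2.functional.vflip`:


@F.register_kernel(dispatcher="vflip", datapoint_cls=MyDatapoint)
def vflip_my_datapoint(my_dp, *args, **kwargs):
    # Flip along the height dimension of a CHW tensor
    out = my_dp.flip(-2)
    return MyDatapoint.wrap_like(my_dp, out)


_ = F.vflip(my_dp)
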
# %%
# .. note::
#
#     We cannot register a kernel for a transform class, we can only register a
#     kernel for a **functional**. The reason we can't register a transform
#     class is because one transform may internally rely on more than one
#     functional, so in general we can't register a single kernel for a given
#     class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

def hflip_my_datapoint(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)


# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend always defining your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding ``**kwargs`` alone should be enough.)

gallery/plot_custom_transforms.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
"""
===================================
How to write your own v2 transforms
===================================

This guide explains how to write transforms that are compatible with the
torchvision transforms V2 API.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2


# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
#     class MyCustomTransform(torch.nn.Module):
#         def forward(self, img, label):
#             # Do some transformations
#             return new_img, new_label
#
# .. note::
#
#     This means that if you have a custom transform that is already compatible
#     with the V1 transforms (those in ``torchvision.transforms``), it will
#     still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing through the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
#     While working with datapoint classes in your code, make sure to
#     familiarize yourself with this section:
#     :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limiting.
#
# A key feature of the built-in Torchvision V2 transforms is that they can accept
# arbitrary input structures and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
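# %%
# A rough, unofficial sketch of that idea (it uses the private
# ``torch.utils._pytree`` helpers and is *not* the actual torchvision
# implementation; the class names below are made up):

from torch.utils import _pytree as pytree


class StructureAwareTransform(torch.nn.Module):
    # Flatten the input, transform only the tensor leaves (datapoints are
    # tensor subclasses, so they are leaves too), then repack everything
    # in the original structure.

    def transform_tensor(self, t):
        # Override this in subclasses; the base class leaves tensors untouched.
        return t

    def forward(self, sample):
        flat, spec = pytree.tree_flatten(sample)
        transformed = [
            self.transform_tensor(leaf) if isinstance(leaf, torch.Tensor) else leaf
            for leaf in flat
        ]
        return pytree.tree_unflatten(transformed, spec)


class PrintTensorShapes(StructureAwareTransform):
    def transform_tensor(self, t):
        print(f"Would transform a {type(t).__name__} of shape {tuple(t.shape)}")
        return t


_ = PrintTensorShapes()(structured_input)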
