
Commit 6325494

Merge branch 'main' into issue-6538-support-cmyk
2 parents 54da7b2 + 9040793


42 files changed: 1148 additions & 845 deletions

docs/source/datapoints.rst

Lines changed: 2 additions & 0 deletions

@@ -18,3 +18,5 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     BoundingBoxes
     Mask
     Datapoint
+    set_return_type
+    wrap
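
The two entries added here, :func:`~torchvision.datapoints.wrap` and :func:`~torchvision.datapoints.set_return_type`, are exercised by the gallery changes below. A minimal sketch of how they fit together, with illustrative box values (the ``format``/``canvas_size`` constructor arguments follow the v2 gallery):

from torchvision import datapoints

bboxes = datapoints.BoundingBoxes(
    [[10, 10, 50, 50]], format="XYXY", canvas_size=(100, 100)
)

# wrap(): re-wrap a plain tensor, copying metadata from an existing datapoint.
shifted = datapoints.wrap(bboxes + 3, like=bboxes)
assert isinstance(shifted, datapoints.BoundingBoxes)

# set_return_type(): make native tensor ops return datapoints, here as a context manager.
with datapoints.set_return_type("datapoint"):
    assert isinstance(bboxes + 3, datapoints.BoundingBoxes)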

docs/source/transforms.rst

Lines changed: 1 addition & 2 deletions

@@ -228,12 +228,11 @@ Conversion
 
     ToPILImage
     v2.ToPILImage
-    v2.ToImagePIL
     ToTensor
     v2.ToTensor
     PILToTensor
     v2.PILToTensor
-    v2.ToImageTensor
+    v2.ToImage
     ConvertImageDtype
     v2.ConvertImageDtype
     v2.ToDtype
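
This hunk renames ``v2.ToImageTensor`` to ``v2.ToImage`` and drops ``v2.ToImagePIL`` in favor of the already-listed ``v2.ToPILImage``. A sketch of a conversion pipeline under the new name, mirroring the pattern the reference scripts in this commit switch to:

import torch
from torchvision.transforms import v2

transform = v2.Compose(
    [
        v2.ToImage(),  # formerly v2.ToImageTensor(): PIL image / ndarray / tensor -> datapoints.Image
        v2.ConvertImageDtype(torch.float32),  # e.g. uint8 [0, 255] -> float32 [0.0, 1.0]
    ]
)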

gallery/plot_custom_datapoints.py

Lines changed: 7 additions & 11 deletions

@@ -3,7 +3,7 @@
 How to write your own Datapoint class
 =====================================
 
-This guide is intended for downstream library maintainers. We explain how to
+This guide is intended for advanced users and downstream library maintainers. We explain how to
 write your own datapoint class, and how to make it compatible with the built-in
 Torchvision v2 transforms. Before continuing, make sure you have read
 :ref:`sphx_glr_auto_examples_plot_datapoints.py`.

@@ -49,28 +49,24 @@ class MyDatapoint(datapoints.Datapoint):
 from torchvision.transforms.v2 import functional as F
 
 
-@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
 def hflip_my_datapoint(my_dp, *args, **kwargs):
     print("Flipping!")
     out = my_dp.flip(-1)
-    return MyDatapoint.wrap_like(my_dp, out)
+    return datapoints.wrap(out, like=my_dp)
 
 
 # %%
-# To understand why ``wrap_like`` is used, see
+# To understand why :func:`~torchvision.datapoints.wrap` is used, see
 # :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
 # we will explain it below in :ref:`param_forwarding`.
 #
 # .. note::
 #
 #    In our call to ``register_kernel`` above we used a string
-#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    ``functional="hflip"`` to refer to the functional we want to hook into. We
 #    could also have used the functional *itself*, i.e.
-#    ``@register_kernel(dispatcher=F.hflip, ...)``.
-#
-#    The functionals that you can be hooked into are the ones in
-#    ``torchvision.transforms.v2.functional`` and they are documented in
-#    :ref:`functional_transforms`.
+#    ``@register_kernel(functional=F.hflip, ...)``.
 #
 # Now that we have registered our kernel, we can call the functional API on a
 # ``MyDatapoint`` instance:

@@ -111,7 +107,7 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
 def hflip_my_datapoint(my_dp):  # noqa
     print("Flipping!")
     out = my_dp.flip(-1)
-    return MyDatapoint.wrap_like(my_dp, out)
+    return datapoints.wrap(out, like=my_dp)
 
 
 # %%
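
Assembled from the pieces this file touches, the registration flow reads as follows. A sketch assuming the surrounding gallery code, where ``MyDatapoint`` is a bare ``datapoints.Datapoint`` subclass:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints.Datapoint):
    pass


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return datapoints.wrap(out, like=my_dp)


my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)  # dispatches to the kernel registered above and prints "Flipping!"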

gallery/plot_datapoints.py

Lines changed: 64 additions & 35 deletions

@@ -48,26 +48,22 @@
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
+# :mod:`torchvision.datapoints` supports four types of datapoints:
+#
+# * :class:`~torchvision.datapoints.Image`
+# * :class:`~torchvision.datapoints.Video`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
+# * :class:`~torchvision.datapoints.Mask`
+#
 # What can I do with a datapoint?
 # -------------------------------
 #
 # Datapoints look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also works on datapoints. See
+# any ``torch.*`` operator will also work on datapoints. See
 # :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
 
 # %%
-#
-# What datapoints are supported?
-# ------------------------------
-#
-# So far :mod:`torchvision.datapoints` supports four types of datapoints:
-#
-# * :class:`~torchvision.datapoints.Image`
-# * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBoxes`
-# * :class:`~torchvision.datapoints.Mask`
-#
 # .. _datapoint_creation:
 #
 # How do I construct a datapoint?

@@ -111,26 +107,23 @@
 print(bboxes)
 
 # %%
-# Using the ``wrap_like()`` class method
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Using ``datapoints.wrap()``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can also use the ``wrap_like()`` class method to wrap a tensor object
+# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
 # into a datapoint. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
-# to wrap the output like the input. This API is inspired by utils like
-# :func:`torch.zeros_like`:
+# to wrap the output like the input.
 
 new_bboxes = torch.tensor([0, 20, 30, 40])
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
 assert isinstance(new_bboxes, datapoints.BoundingBoxes)
 assert new_bboxes.canvas_size == bboxes.canvas_size
 
 
 # %%
 # The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
-# it as a parameter to override it. Check the
-# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
-# more details.
+# it as a parameter to override it.
 #
 # Do I have to wrap the output of the datasets myself?
 # ----------------------------------------------------
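
The shortened comment still promises that metadata can be overridden through a parameter. A sketch of that override; the keyword name (``canvas_size``) is an assumption based on the attribute asserted above:

import torch
from torchvision import datapoints

bboxes = datapoints.BoundingBoxes(
    [[0, 10, 30, 40]], format="XYXY", canvas_size=(100, 100)
)

# Default: wrap() copies the metadata of `like`.
same = datapoints.wrap(torch.tensor([[0, 20, 30, 40]]), like=bboxes)
assert same.canvas_size == (100, 100)

# Assumed override: pass the metadata field as a keyword to replace it.
bigger = datapoints.wrap(
    torch.tensor([[0, 40, 60, 80]]), like=bboxes, canvas_size=(200, 200)
)
assert bigger.canvas_size == (200, 200)
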
@@ -209,42 +202,78 @@ def get_transform(train):
 # I had a Datapoint but now I have a Tensor. Help!
 # ------------------------------------------------
 #
-# For a lot of operations involving datapoints, we cannot safely infer whether
-# the result should retain the datapoint type, so we choose to return a plain
-# tensor instead of a datapoint (this might change, see note below):
+# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
+# will return a pure Tensor:
 
 
 assert isinstance(bboxes, datapoints.BoundingBoxes)
 
 # Shift bboxes by 3 pixels in both H and W
 new_bboxes = bboxes + 3
 
-assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
+assert isinstance(new_bboxes, torch.Tensor)
+assert not isinstance(new_bboxes, datapoints.BoundingBoxes)
+
+# %%
+# .. note::
+#
+#    This behavior only affects native ``torch`` operations. If you are using
+#    the built-in ``torchvision`` transforms or functionals, you will always get
+#    as output the same type that you passed as input (pure ``Tensor`` or
+#    ``Datapoint``).
 
 # %%
-# If you're writing your own custom transforms or code involving datapoints, you
-# can re-wrap the output into a datapoint by just calling their constructor, or
-# by using the ``.wrap_like()`` class method:
+# But I want a Datapoint back!
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
+# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
+# (see more details above in :ref:`datapoint_creation`):
 
 new_bboxes = bboxes + 3
-new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
 assert isinstance(new_bboxes, datapoints.BoundingBoxes)
 
 # %%
-# See more details above in :ref:`datapoint_creation`.
+# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
+# as a global config setting for the whole program, or as a context manager:
+
+with datapoints.set_return_type("datapoint"):
+    new_bboxes = bboxes + 3
+assert isinstance(new_bboxes, datapoints.BoundingBoxes)
+
+# %%
+# Why is this happening?
+# ^^^^^^^^^^^^^^^^^^^^^^
 #
-# .. note::
+# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
+# classes are Tensor subclasses, so any operation involving a
+# :class:`~torchvision.datapoints.Datapoint` object will go through the
+# `__torch_function__
+# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
+# protocol. This induces a small overhead, which we want to avoid when possible.
+# This doesn't matter for built-in ``torchvision`` transforms because we can
+# avoid the overhead there, but it could be a problem in your model's
+# ``forward``.
 #
-#    You never need to re-wrap manually if you're using the built-in transforms
-#    or their functional equivalents: this is automatically taken care of for
-#    you.
+# **The alternative isn't much better anyway.** For every operation where
+# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
+# sense, there are just as many operations where returning a pure Tensor is
+# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
+# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
+# the way, even model's logits or the output of the loss function would end up
+# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
+# desirable.
 #
 # .. note::
 #
-#    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
+#    This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
 #    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
 #    https://github.com/pytorch/vision/issues/7319
 #
+# Exceptions
+# ^^^^^^^^^^
+#
 # There are a few exceptions to this "unwrapping" rule:
 #
 # 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
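
The new ``Exceptions`` heading introduces the list that continues past this hunk. A sketch of the first exception it names (``clone`` and ``to`` preserve the datapoint type), with an illustrative image:

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 32, 32))

# Native arithmetic unwraps to a pure Tensor by default...
assert not isinstance(img + 1, datapoints.Image)

# ...but clone() and to() are among the documented exceptions and keep the type.
assert isinstance(img.clone(), datapoints.Image)
assert isinstance(img.to(torch.float64), datapoints.Image)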

gallery/plot_transforms_v2_e2e.py

Lines changed: 4 additions & 4 deletions

@@ -5,8 +5,8 @@
 
 Object detection is not supported out of the box by ``torchvision.transforms`` v1, since it only supports images.
 ``torchvision.transforms.v2`` enables jointly transforming images, videos, bounding boxes, and masks. This example
-showcases an end-to-end object detection training using the stable ``torchvisio.datasets`` and ``torchvision.models`` as
-well as the new ``torchvision.transforms.v2`` v2 API.
+showcases an end-to-end object detection training using the stable ``torchvision.datasets`` and ``torchvision.models``
+as well as the new ``torchvision.transforms.v2`` v2 API.
 """
 
 import pathlib

@@ -27,7 +27,7 @@ def show(sample):
 
     image, target = sample
     if isinstance(image, PIL.Image.Image):
-        image = F.to_image_tensor(image)
+        image = F.to_image(image)
     image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
 

@@ -101,7 +101,7 @@ def load_example_coco_detection_dataset(**kwargs):
         transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
         transforms.RandomIoUCrop(),
         transforms.RandomHorizontalFlip(),
-        transforms.ToImageTensor(),
+        transforms.ToImage(),
         transforms.ConvertImageDtype(torch.float32),
         transforms.SanitizeBoundingBoxes(),
     ]
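
The renamed functionals in this file compose as in the gallery's ``show`` helper. A sketch with a synthetic PIL image and an assumed XYXY box:

import PIL.Image
import torch
from torchvision.transforms.v2 import functional as F
from torchvision.utils import draw_bounding_boxes

pil_image = PIL.Image.new("RGB", (64, 64))
boxes = torch.tensor([[4, 4, 40, 40]])

image = F.to_image(pil_image)  # formerly F.to_image_tensor
image = F.to_dtype(image, torch.uint8, scale=True)
annotated = draw_bounding_boxes(image, boxes, colors="yellow", width=3)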

packaging/pre_build_script.sh

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ if [[ "$(uname)" == Darwin ]]; then
     conda install -yq wget
 fi
 
-if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
+if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" || "$ARCH" == "aarch64" ]]; then
     # Install libpng from Anaconda (defaults)
     conda install libpng -yq
     conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch

references/detection/presets.py

Lines changed: 4 additions & 4 deletions

@@ -33,7 +33,7 @@ def __init__(
         transforms = []
         backend = backend.lower()
         if backend == "datapoint":
-            transforms.append(T.ToImageTensor())
+            transforms.append(T.ToImage())
         elif backend == "tensor":
             transforms.append(T.PILToTensor())
         elif backend != "pil":

@@ -71,7 +71,7 @@ def __init__(
 
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2.
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 
         transforms += [T.ConvertImageDtype(torch.float)]
 

@@ -94,11 +94,11 @@ def __init__(self, backend="pil", use_v2=False):
         backend = backend.lower()
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2?
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
         elif backend == "tensor":
             transforms += [T.PILToTensor()]
         elif backend == "datapoint":
-            transforms += [T.ToImageTensor()]
+            transforms += [T.ToImage()]
         else:
             raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
 
references/segmentation/presets.py

Lines changed: 3 additions & 3 deletions

@@ -32,7 +32,7 @@ def __init__(
         transforms = []
         backend = backend.lower()
         if backend == "datapoint":
-            transforms.append(T.ToImageTensor())
+            transforms.append(T.ToImage())
         elif backend == "tensor":
             transforms.append(T.PILToTensor())
         elif backend != "pil":

@@ -81,7 +81,7 @@ def __init__(
         if backend == "tensor":
             transforms += [T.PILToTensor()]
         elif backend == "datapoint":
-            transforms += [T.ToImageTensor()]
+            transforms += [T.ToImage()]
         elif backend != "pil":
             raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
 

@@ -92,7 +92,7 @@ def __init__(
 
         if backend == "pil":
             # Note: we could just convert to pure tensors even in v2?
-            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 
         transforms += [
             T.ConvertImageDtype(torch.float),
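
The same three-way backend switch is updated in both the detection and segmentation presets. Condensed, the post-rename pattern looks like the sketch below (the helper name is hypothetical):

from torchvision.transforms import v2 as T


def initial_conversion(backend: str):
    # Hypothetical helper condensing the shared preset logic after the rename.
    backend = backend.lower()
    if backend == "datapoint":
        return T.ToImage()  # formerly T.ToImageTensor()
    elif backend == "tensor":
        return T.PILToTensor()
    elif backend == "pil":
        return None  # keep PIL inputs; conversion happens later in the pipeline
    raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")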

test/common_utils.py

Lines changed: 5 additions & 5 deletions

@@ -27,7 +27,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints, io
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
 
 
 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])

@@ -293,7 +293,7 @@ def __init__(
         **other_parameters,
     ):
         if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-            actual, expected = [to_image_tensor(input) for input in [actual, expected]]
+            actual, expected = [to_image(input) for input in [actual, expected]]
 
         super().__init__(actual, expected, **other_parameters)
         self.mae = mae

@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
 
 
 def make_image_pil(*args, **kwargs):
-    return to_image_pil(make_image(*args, **kwargs))
+    return to_pil_image(make_image(*args, **kwargs))
 
 
 def make_image_loader(

@@ -609,12 +609,12 @@ def fn(shape, dtype, device, memory_format):
             )
         )
 
-        image_tensor = to_image_tensor(image_pil)
+        image_tensor = to_image(image_pil)
         if memory_format == torch.contiguous_format:
             image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
         else:
             image_tensor = image_tensor.to(device=device)
-        image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
+        image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
 
         return datapoints.Image(image_tensor)
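
The three renamed helpers imported at the top of this file round-trip an image as follows; a sketch with an assumed 3x16x16 uint8 input:

import torch
from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image

image = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)

pil_image = to_pil_image(image)  # formerly to_image_pil
tensor_image = to_image(pil_image)  # formerly to_image_tensor; returns a datapoints.Image
float_image = to_dtype_image(tensor_image, dtype=torch.float32, scale=True)  # formerly to_dtype_image_tensor
assert float_image.max() <= 1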

test/smoke_test.py

Lines changed: 7 additions & 2 deletions

@@ -78,8 +78,13 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
 def main() -> None:
     print(f"torchvision: {torchvision.__version__}")
     print(f"torch.cuda.is_available: {torch.cuda.is_available()}")
-    print(f"{torch.ops.image._jpeg_version() = }")
-    assert torch.ops.image._is_compiled_against_turbo()
+
+    # Turn 1.11.0aHASH into 1.11 (major.minor only)
+    version = ".".join(torchvision.__version__.split(".")[:2])
+    if version >= "0.16":
+        print(f"{torch.ops.image._jpeg_version() = }")
+        assert torch.ops.image._is_compiled_against_turbo()
+
     smoke_test_torchvision()
     smoke_test_torchvision_read_decode()
     smoke_test_torchvision_resnet50_classify()
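
The new gate compares ``major.minor`` strings lexicographically, which holds for the versions in play here ("0.15" < "0.16"). A hedged sketch of a numeric comparison that would also survive two-digit minors such as a hypothetical "0.100":

def at_least(version_string: str, major: int, minor: int) -> bool:
    # Turn "0.16.0aHASH" into (0, 16) and compare numerically rather than
    # lexicographically, which would misorder e.g. "0.100" vs "0.16".
    first, second = version_string.split(".")[:2]
    return (int(first), int(second)) >= (major, minor)


assert at_least("0.16.0a0", 0, 16)
assert not at_least("0.15.2", 0, 16)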
