
Commit 1120aa9

introduce heuristic for simple tensor handling of transforms v2 (#7170)
1 parent 1222b49 commit 1120aa9
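
The heuristic, in short: transforms v2 only treats a plain `torch.Tensor` as an image if the sample contains no explicit image or video type (`datapoints.Image`, `datapoints.Video`, or `PIL.Image.Image`); even then, only the *first* plain tensor is transformed and all later ones are passed through. The sketch below is illustrative only — it mirrors what the new test asserts rather than quoting the implementation, which lands in a third changed file not shown in this extract, and the helper name `needs_transform_list` is hypothetical:

import PIL.Image
import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import is_simple_tensor


def needs_transform_list(flat_inputs):
    # Illustrative sketch of the heuristic, not the committed code.
    # If an explicit image or video type is present, plain tensors are
    # assumed to be auxiliary data and are passed through untouched.
    transform_simple_tensor = not any(
        isinstance(inpt, (datapoints.Image, datapoints.Video, PIL.Image.Image))
        for inpt in flat_inputs
    )
    needs_transform = []
    for inpt in flat_inputs:
        if is_simple_tensor(inpt):
            # At most the *first* plain tensor acts as the image.
            needs_transform.append(transform_simple_tensor)
            transform_simple_tensor = False
        else:
            needs_transform.append(True)
    return needs_transform

In the real transform, non-tensor items (e.g. `str`, `int`) would additionally be filtered by the transform's supported types; that detail is omitted here.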

File tree: 3 files changed, +250 −101 lines changed

test/test_prototype_transforms.py

Lines changed: 200 additions & 97 deletions
@@ -1,15 +1,16 @@
 import itertools
+import re

 import numpy as np

 import PIL.Image
-
 import pytest
 import torch

 import torchvision.prototype.transforms.utils
-from common_utils import assert_equal, cpu_and_gpu
+from common_utils import cpu_and_gpu
 from prototype_common_utils import (
+    assert_equal,
     DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
@@ -25,7 +26,7 @@
 )
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input):
         transform(input)


+@pytest.mark.parametrize(
+    "flat_inputs",
+    itertools.permutations(
+        [
+            next(make_vanilla_tensor_images()),
+            next(make_vanilla_tensor_images()),
+            next(make_pil_images()),
+            make_image(),
+            next(make_videos()),
+        ],
+        3,
+    ),
+)
+def test_simple_tensor_heuristic(flat_inputs):
+    def split_on_simple_tensor(to_split):
+        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
+        # 1. The first simple tensor. If none is present, this will be `None`
+        # 2. A list of the remaining simple tensors
+        # 3. A list of all other items
+        simple_tensors = []
+        others = []
+        # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
+        # affect the splitting.
+        for item, inpt in zip(to_split, flat_inputs):
+            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
+        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+
+    class CopyCloneTransform(transforms.Transform):
+        def _transform(self, inpt, params):
+            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()
+
+        @staticmethod
+        def was_applied(output, inpt):
+            identity = output is inpt
+            if identity:
+                return False
+
+            # Make sure nothing fishy is going on
+            assert_equal(output, inpt)
+            return True
+
+    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+
+    transform = CopyCloneTransform()
+    transformed_sample = transform(flat_inputs)
+
+    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+
+    if first_simple_tensor_input is not None:
+        if other_inputs:
+            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+        else:
+            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+
+    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+        assert not transform.was_applied(output, inpt)
+
+    for input, output in zip(other_inputs, other_outputs):
+        assert transform.was_applied(output, input)
+
+
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):
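
Concretely, the expectations encoded in the heuristic test above amount to the following behavior. This is a hypothetical usage sketch: `RandomHorizontalFlip` stands in for any v2 transform, and the multi-argument call follows the v2 convention of accepting arbitrary sample structures:

import torch
from torchvision.prototype import datapoints, transforms

plain = torch.rand(3, 32, 32)
image = datapoints.Image(torch.rand(3, 32, 32))

t = transforms.RandomHorizontalFlip(p=1.0)

# Alone, the plain tensor is treated as an image and gets flipped.
flipped = t(plain)

# Next to a datapoints.Image (or Video, or PIL image), the plain tensor
# is passed through unchanged; only the explicit image is flipped.
out_plain, out_image = t(plain, image)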
@@ -1755,117 +1817,158 @@ def test__transform(self, mocker):
        )


-@pytest.mark.parametrize(
-    ("dtype", "expected_dtypes"),
-    [
-        (
-            torch.float64,
-            {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
-        ),
-        (
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-        ),
-    ],
-)
-def test_to_dtype(dtype, expected_dtypes):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
-        image=make_image(dtype=torch.uint8),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
-        str="str",
-        int=0,
+class TestToDtype:
+    @pytest.mark.parametrize(
+        ("dtype", "expected_dtypes"),
+        [
+            (
+                torch.float64,
+                {
+                    datapoints.Video: torch.float64,
+                    datapoints.Image: torch.float64,
+                    datapoints.BoundingBox: torch.float64,
+                },
+            ),
+            (
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+            ),
+        ],
    )
+    def test_call(self, dtype, expected_dtypes):
+        sample = dict(
+            video=make_video(dtype=torch.int64),
+            image=make_image(dtype=torch.uint8),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
+            str="str",
+            int=0,
+        )

-    transform = transforms.ToDtype(dtype)
-    transformed_sample = transform(sample)
+        transform = transforms.ToDtype(dtype)
+        transformed_sample = transform(sample)

-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]

-        # make sure the transformation retains the type
-        assert isinstance(transformed_value, value_type)
+            # make sure the transformation retains the type
+            assert isinstance(transformed_value, value_type)

-        if isinstance(value, torch.Tensor):
-            assert transformed_value.dtype is expected_dtypes[value_type]
-        else:
-            assert transformed_value is value
+            if isinstance(value, torch.Tensor):
+                assert transformed_value.dtype is expected_dtypes[value_type]
+            else:
+                assert transformed_value is value

+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((), dtype=torch.float32)
+        transform = transforms.ToDtype({torch.Tensor: torch.float64})

-@pytest.mark.parametrize(
-    ("dims", "inverse_dims"),
-    [
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-        ),
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
-        ),
-    ],
-)
-def test_permute_dimensions(dims, inverse_dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
-        str="str",
-        int=0,
+        assert transform(tensor).dtype is torch.float64
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
+
+
+class TestPermuteDimensions:
+    @pytest.mark.parametrize(
+        ("dims", "inverse_dims"),
+        [
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+            ),
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
+            ),
+        ],
    )
+    def test_call(self, dims, inverse_dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
+            str="str",
+            int=0,
+        )

-    transform = transforms.PermuteDimensions(dims)
-    transformed_sample = transform(sample)
+        transform = transforms.PermuteDimensions(dims)
+        transformed_sample = transform(sample)

-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]

-        if check_type(
-            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
-        ):
-            if transform.dims.get(value_type) is not None:
-                assert transformed_value.permute(inverse_dims[value_type]).equal(value)
-            assert type(transformed_value) == torch.Tensor
-        else:
-            assert transformed_value is value
+            if check_type(
+                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+            ):
+                if transform.dims.get(value_type) is not None:
+                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
+                assert type(transformed_value) == torch.Tensor
+            else:
+                assert transformed_value is value

+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.PermuteDimensions(dims=(1, 2, 0))

-@pytest.mark.parametrize(
-    "dims",
-    [
-        (-1, -2),
-        {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
-    ],
-)
-def test_transpose_dimensions(dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
-        str="str",
-        int=0,
+        assert transform(tensor).shape == (3, 4, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
+
+
+class TestTransposeDimensions:
+    @pytest.mark.parametrize(
+        "dims",
+        [
+            (-1, -2),
+            {datapoints.Image: (1, 2), datapoints.Video: None},
+        ],
    )
+    def test_call(self, dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
+            str="str",
+            int=0,
+        )

-    transform = transforms.TransposeDimensions(dims)
-    transformed_sample = transform(sample)
+        transform = transforms.TransposeDimensions(dims)
+        transformed_sample = transform(sample)

-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]

-        transposed_dims = transform.dims.get(value_type)
-        if check_type(
-            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
-        ):
-            if transposed_dims is not None:
-                assert transformed_value.transpose(*transposed_dims).equal(value)
-            assert type(transformed_value) == torch.Tensor
-        else:
-            assert transformed_value is value
+            transposed_dims = transform.dims.get(value_type)
+            if check_type(
+                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+            ):
+                if transposed_dims is not None:
+                    assert transformed_value.transpose(*transposed_dims).equal(value)
+                assert type(transformed_value) == torch.Tensor
+            else:
+                assert transformed_value is value
+
+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.TransposeDimensions(dims=(0, 2))
+
+        assert transform(tensor).shape == (4, 3, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


 class TestUniformTemporalSubsample:

torchvision/prototype/transforms/_misc.py

Lines changed: 19 additions & 0 deletions
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

 import PIL.Image
@@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]])
         super().__init__()
         if not isinstance(dtype, dict):
             dtype = _get_defaultdict(dtype)
+        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
@@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
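
Taken together with the heuristic, the warning added above fires at construction time whenever a `torch.Tensor` entry could be shadowed by an explicit image or video entry. A hypothetical illustration of the two cases:

import torch
from torchvision.prototype import datapoints, transforms

# Fine: no torch.Tensor key, so nothing can be shadowed.
transforms.ToDtype({datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64})

# Warns: if the input also contains a datapoints.Image or datapoints.Video,
# the plain-tensor entry is dead weight, because the heuristic skips plain
# tensors whenever an explicit image or video is present.
transforms.ToDtype({torch.Tensor: torch.float64, datapoints.Image: torch.float32})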
