  1 |   1 | import itertools
    |   2 | +import re
  2 |   3 | 
  3 |   4 | import numpy as np
  4 |   5 | 
  5 |   6 | import PIL.Image
  6 |     | -
  7 |   7 | import pytest
  8 |   8 | import torch
  9 |   9 | 
 10 |  10 | import torchvision.prototype.transforms.utils
 11 |     | -from common_utils import assert_equal, cpu_and_gpu
    |  11 | +from common_utils import cpu_and_gpu
 12 |  12 | from prototype_common_utils import (
    |  13 | +    assert_equal,
 13 |  14 |     DEFAULT_EXTRA_DIMS,
 14 |  15 |     make_bounding_box,
 15 |  16 |     make_bounding_boxes,
 25 |  26 | )
 26 |  27 | from torchvision.ops.boxes import box_iou
 27 |  28 | from torchvision.prototype import datapoints, transforms
 28 |     | -from torchvision.prototype.transforms.utils import check_type
    |  29 | +from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
 29 |  30 | from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 30 |  31 | 
 31 |  32 | BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input):
 222 |  223 |         transform(input)
 223 |  224 | 
 224 |  225 | 
     |  226 | +@pytest.mark.parametrize(
     |  227 | +    "flat_inputs",
     |  228 | +    itertools.permutations(
     |  229 | +        [
     |  230 | +            next(make_vanilla_tensor_images()),
     |  231 | +            next(make_vanilla_tensor_images()),
     |  232 | +            next(make_pil_images()),
     |  233 | +            make_image(),
     |  234 | +            next(make_videos()),
     |  235 | +        ],
     |  236 | +        3,
     |  237 | +    ),
     |  238 | +)
     |  239 | +def test_simple_tensor_heuristic(flat_inputs):
     |  240 | +    def split_on_simple_tensor(to_split):
     |  241 | +        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
     |  242 | +        # 1. The first simple tensor. If none is present, this will be `None`
     |  243 | +        # 2. A list of the remaining simple tensors
     |  244 | +        # 3. A list of all other items
     |  245 | +        simple_tensors = []
     |  246 | +        others = []
     |  247 | +        # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
     |  248 | +        # affect the splitting.
     |  249 | +        for item, inpt in zip(to_split, flat_inputs):
     |  250 | +            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
     |  251 | +        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
     |  252 | +
     |  253 | +    class CopyCloneTransform(transforms.Transform):
     |  254 | +        def _transform(self, inpt, params):
     |  255 | +            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()
     |  256 | +
     |  257 | +        @staticmethod
     |  258 | +        def was_applied(output, inpt):
     |  259 | +            identity = output is inpt
     |  260 | +            if identity:
     |  261 | +                return False
     |  262 | +
     |  263 | +            # Make sure nothing fishy is going on
     |  264 | +            assert_equal(output, inpt)
     |  265 | +            return True
     |  266 | +
     |  267 | +    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
     |  268 | +
     |  269 | +    transform = CopyCloneTransform()
     |  270 | +    transformed_sample = transform(flat_inputs)
     |  271 | +
     |  272 | +    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
     |  273 | +
     |  274 | +    if first_simple_tensor_input is not None:
     |  275 | +        if other_inputs:
     |  276 | +            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
     |  277 | +        else:
     |  278 | +            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
     |  279 | +
     |  280 | +    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
     |  281 | +        assert not transform.was_applied(output, inpt)
     |  282 | +
     |  283 | +    for input, output in zip(other_inputs, other_outputs):
     |  284 | +        assert transform.was_applied(output, input)
     |  285 | +
     |  286 | +
 225 |  287 | @pytest.mark.parametrize("p", [0.0, 1.0])
 226 |  288 | class TestRandomHorizontalFlip:
 227 |  289 |     def input_expected_image_tensor(self, p, dtype=torch.float32):
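
The `test_simple_tensor_heuristic` hunk above pins down the dispatch rule for plain `torch.Tensor` inputs: only the first plain tensor in a sample is a candidate for transformation, and only when no richer image-like input (a datapoint `Image`/`Video` or a PIL image) is present; every other plain tensor is passed through untouched, while all non-plain inputs are always transformed. Below is a minimal standalone sketch of that rule; everything except `torch` is illustrative and not part of the torchvision API.

```python
import torch


def needs_transform_list(flat_inputs):
    # Sketch of the heuristic asserted by test_simple_tensor_heuristic. For
    # simplicity, anything that is not a plain torch.Tensor stands in for an
    # image-like input (datapoint, PIL image, video).
    def is_plain_tensor(obj):
        return type(obj) is torch.Tensor

    has_image_like = any(not is_plain_tensor(obj) for obj in flat_inputs)
    seen_first_plain = False
    needs_transform = []
    for obj in flat_inputs:
        if is_plain_tensor(obj):
            # Only the first plain tensor is a candidate, and only if the
            # sample contains no image-like input at all.
            needs_transform.append(not seen_first_plain and not has_image_like)
            seen_first_plain = True
        else:
            needs_transform.append(True)
    return needs_transform


a, b = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
assert needs_transform_list([a, b]) == [True, False]
assert needs_transform_list([a, b, "image-like stand-in"]) == [False, False, True]
```

The two asserts mirror the test's expectations: with only plain tensors present the first one is transformed, whereas the presence of any other input demotes every plain tensor to pass-through.
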
@@ -1755,117 +1817,158 @@ def test__transform(self, mocker):
1755 | 1817 |         )
1756 | 1818 | 
1757 | 1819 | 
1758 |      | -@pytest.mark.parametrize(
1759 |      | -    ("dtype", "expected_dtypes"),
1760 |      | -    [
1761 |      | -        (
1762 |      | -            torch.float64,
1763 |      | -            {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
1764 |      | -        ),
1765 |      | -        (
1766 |      | -            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
1767 |      | -            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
1768 |      | -        ),
1769 |      | -    ],
1770 |      | -)
1771 |      | -def test_to_dtype(dtype, expected_dtypes):
1772 |      | -    sample = dict(
1773 |      | -        plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
1774 |      | -        image=make_image(dtype=torch.uint8),
1775 |      | -        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
1776 |      | -        str="str",
1777 |      | -        int=0,
     | 1820 | +class TestToDtype:
     | 1821 | +    @pytest.mark.parametrize(
     | 1822 | +        ("dtype", "expected_dtypes"),
     | 1823 | +        [
     | 1824 | +            (
     | 1825 | +                torch.float64,
     | 1826 | +                {
     | 1827 | +                    datapoints.Video: torch.float64,
     | 1828 | +                    datapoints.Image: torch.float64,
     | 1829 | +                    datapoints.BoundingBox: torch.float64,
     | 1830 | +                },
     | 1831 | +            ),
     | 1832 | +            (
     | 1833 | +                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
     | 1834 | +                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
     | 1835 | +            ),
     | 1836 | +        ],
1778 | 1837 |     )
     | 1838 | +    def test_call(self, dtype, expected_dtypes):
     | 1839 | +        sample = dict(
     | 1840 | +            video=make_video(dtype=torch.int64),
     | 1841 | +            image=make_image(dtype=torch.uint8),
     | 1842 | +            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
     | 1843 | +            str="str",
     | 1844 | +            int=0,
     | 1845 | +        )
1779 | 1846 | 
1780 |      | -    transform = transforms.ToDtype(dtype)
1781 |      | -    transformed_sample = transform(sample)
     | 1847 | +        transform = transforms.ToDtype(dtype)
     | 1848 | +        transformed_sample = transform(sample)
1782 | 1849 | 
1783 |      | -    for key, value in sample.items():
1784 |      | -        value_type = type(value)
1785 |      | -        transformed_value = transformed_sample[key]
     | 1850 | +        for key, value in sample.items():
     | 1851 | +            value_type = type(value)
     | 1852 | +            transformed_value = transformed_sample[key]
1786 | 1853 | 
1787 |      | -        # make sure the transformation retains the type
1788 |      | -        assert isinstance(transformed_value, value_type)
     | 1854 | +            # make sure the transformation retains the type
     | 1855 | +            assert isinstance(transformed_value, value_type)
1789 | 1856 | 
1790 |      | -        if isinstance(value, torch.Tensor):
1791 |      | -            assert transformed_value.dtype is expected_dtypes[value_type]
1792 |      | -        else:
1793 |      | -            assert transformed_value is value
     | 1857 | +            if isinstance(value, torch.Tensor):
     | 1858 | +                assert transformed_value.dtype is expected_dtypes[value_type]
     | 1859 | +            else:
     | 1860 | +                assert transformed_value is value
1794 | 1861 | 
     | 1862 | +    @pytest.mark.filterwarnings("error")
     | 1863 | +    def test_plain_tensor_call(self):
     | 1864 | +        tensor = torch.empty((), dtype=torch.float32)
     | 1865 | +        transform = transforms.ToDtype({torch.Tensor: torch.float64})
1795 | 1866 | 
1796 |      | -@pytest.mark.parametrize(
1797 |      | -    ("dims", "inverse_dims"),
1798 |      | -    [
1799 |      | -        (
1800 |      | -            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
1801 |      | -            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
1802 |      | -        ),
1803 |      | -        (
1804 |      | -            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
1805 |      | -            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
1806 |      | -        ),
1807 |      | -    ],
1808 |      | -)
1809 |      | -def test_permute_dimensions(dims, inverse_dims):
1810 |      | -    sample = dict(
1811 |      | -        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
1812 |      | -        image=make_image(),
1813 |      | -        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
1814 |      | -        video=make_video(),
1815 |      | -        str="str",
1816 |      | -        int=0,
     | 1867 | +        assert transform(tensor).dtype is torch.float64
     | 1868 | +
     | 1869 | +    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
     | 1870 | +    def test_plain_tensor_warning(self, other_type):
     | 1871 | +        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
     | 1872 | +            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
     | 1873 | +
     | 1874 | +
     | 1875 | +class TestPermuteDimensions:
     | 1876 | +    @pytest.mark.parametrize(
     | 1877 | +        ("dims", "inverse_dims"),
     | 1878 | +        [
     | 1879 | +            (
     | 1880 | +                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
     | 1881 | +                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
     | 1882 | +            ),
     | 1883 | +            (
     | 1884 | +                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
     | 1885 | +                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
     | 1886 | +            ),
     | 1887 | +        ],
1817 | 1888 |     )
     | 1889 | +    def test_call(self, dims, inverse_dims):
     | 1890 | +        sample = dict(
     | 1891 | +            image=make_image(),
     | 1892 | +            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
     | 1893 | +            video=make_video(),
     | 1894 | +            str="str",
     | 1895 | +            int=0,
     | 1896 | +        )
1818 | 1897 | 
1819 |      | -    transform = transforms.PermuteDimensions(dims)
1820 |      | -    transformed_sample = transform(sample)
     | 1898 | +        transform = transforms.PermuteDimensions(dims)
     | 1899 | +        transformed_sample = transform(sample)
1821 | 1900 | 
1822 |      | -    for key, value in sample.items():
1823 |      | -        value_type = type(value)
1824 |      | -        transformed_value = transformed_sample[key]
     | 1901 | +        for key, value in sample.items():
     | 1902 | +            value_type = type(value)
     | 1903 | +            transformed_value = transformed_sample[key]
1825 | 1904 | 
1826 |      | -        if check_type(
1827 |      | -            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
1828 |      | -        ):
1829 |      | -            if transform.dims.get(value_type) is not None:
1830 |      | -                assert transformed_value.permute(inverse_dims[value_type]).equal(value)
1831 |      | -            assert type(transformed_value) == torch.Tensor
1832 |      | -        else:
1833 |      | -            assert transformed_value is value
     | 1905 | +            if check_type(
     | 1906 | +                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
     | 1907 | +            ):
     | 1908 | +                if transform.dims.get(value_type) is not None:
     | 1909 | +                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
     | 1910 | +                assert type(transformed_value) == torch.Tensor
     | 1911 | +            else:
     | 1912 | +                assert transformed_value is value
1834 | 1913 | 
     | 1914 | +    @pytest.mark.filterwarnings("error")
     | 1915 | +    def test_plain_tensor_call(self):
     | 1916 | +        tensor = torch.empty((2, 3, 4))
     | 1917 | +        transform = transforms.PermuteDimensions(dims=(1, 2, 0))
1835 | 1918 | 
1836 |      | -@pytest.mark.parametrize(
1837 |      | -    "dims",
1838 |      | -    [
1839 |      | -        (-1, -2),
1840 |      | -        {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
1841 |      | -    ],
1842 |      | -)
1843 |      | -def test_transpose_dimensions(dims):
1844 |      | -    sample = dict(
1845 |      | -        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
1846 |      | -        image=make_image(),
1847 |      | -        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
1848 |      | -        video=make_video(),
1849 |      | -        str="str",
1850 |      | -        int=0,
     | 1919 | +        assert transform(tensor).shape == (3, 4, 2)
     | 1920 | +
     | 1921 | +    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
     | 1922 | +    def test_plain_tensor_warning(self, other_type):
     | 1923 | +        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
     | 1924 | +            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
     | 1925 | +
     | 1926 | +
     | 1927 | +class TestTransposeDimensions:
     | 1928 | +    @pytest.mark.parametrize(
     | 1929 | +        "dims",
     | 1930 | +        [
     | 1931 | +            (-1, -2),
     | 1932 | +            {datapoints.Image: (1, 2), datapoints.Video: None},
     | 1933 | +        ],
1851 | 1934 |     )
     | 1935 | +    def test_call(self, dims):
     | 1936 | +        sample = dict(
     | 1937 | +            image=make_image(),
     | 1938 | +            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
     | 1939 | +            video=make_video(),
     | 1940 | +            str="str",
     | 1941 | +            int=0,
     | 1942 | +        )
1852 | 1943 | 
1853 |      | -    transform = transforms.TransposeDimensions(dims)
1854 |      | -    transformed_sample = transform(sample)
     | 1944 | +        transform = transforms.TransposeDimensions(dims)
     | 1945 | +        transformed_sample = transform(sample)
1855 | 1946 | 
1856 |      | -    for key, value in sample.items():
1857 |      | -        value_type = type(value)
1858 |      | -        transformed_value = transformed_sample[key]
     | 1947 | +        for key, value in sample.items():
     | 1948 | +            value_type = type(value)
     | 1949 | +            transformed_value = transformed_sample[key]
1859 | 1950 | 
1860 |      | -        transposed_dims = transform.dims.get(value_type)
1861 |      | -        if check_type(
1862 |      | -            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
1863 |      | -        ):
1864 |      | -            if transposed_dims is not None:
1865 |      | -                assert transformed_value.transpose(*transposed_dims).equal(value)
1866 |      | -            assert type(transformed_value) == torch.Tensor
1867 |      | -        else:
1868 |      | -            assert transformed_value is value
     | 1951 | +            transposed_dims = transform.dims.get(value_type)
     | 1952 | +            if check_type(
     | 1953 | +                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
     | 1954 | +            ):
     | 1955 | +                if transposed_dims is not None:
     | 1956 | +                    assert transformed_value.transpose(*transposed_dims).equal(value)
     | 1957 | +                assert type(transformed_value) == torch.Tensor
     | 1958 | +            else:
     | 1959 | +                assert transformed_value is value
     | 1960 | +
     | 1961 | +    @pytest.mark.filterwarnings("error")
     | 1962 | +    def test_plain_tensor_call(self):
     | 1963 | +        tensor = torch.empty((2, 3, 4))
     | 1964 | +        transform = transforms.TransposeDimensions(dims=(0, 2))
     | 1965 | +
     | 1966 | +        assert transform(tensor).shape == (4, 3, 2)
     | 1967 | +
     | 1968 | +    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
     | 1969 | +    def test_plain_tensor_warning(self, other_type):
     | 1970 | +        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
     | 1971 | +            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
1869 | 1972 | 
1870 | 1973 | 
1871 | 1974 | class TestUniformTemporalSubsample: