Promote prototype transforms to beta status #7261


Merged · 25 commits · Feb 16, 2023

Commits (25)
36ae12e  Copy paste prototype.datapoints and prototype.transforms out of proto… (NicolasHug, Feb 15, 2023)
a144417  make test/test_prototype_transforms_consistency.py use torchvision area (NicolasHug, Feb 15, 2023)
46086b8  Migrate more tests (NicolasHug, Feb 15, 2023)
73d0e0f  Updated imports (vfdev-5, Feb 15, 2023)
1958bcd  Merge branch 'migration_transforms_v2' of github.com:pytorch/vision i… (NicolasHug, Feb 15, 2023)
6fd6614  Update geometry files (NicolasHug, Feb 15, 2023)
fd751c6  SNEAKY (NicolasHug, Feb 15, 2023)
2a41357  Some updates, might have broken more stuff (NicolasHug, Feb 15, 2023)
22c7499  Missed some (NicolasHug, Feb 15, 2023)
55b3772  Fixed functional tests (vfdev-5, Feb 15, 2023)
ad87f5b  some removals (NicolasHug, Feb 15, 2023)
ced6992  Fixed issue with temporal test, test_batched_vs_single (vfdev-5, Feb 15, 2023)
b5a67ab  Merge branch 'migration_transforms_v2' of github.com:pytorch/vision i… (NicolasHug, Feb 15, 2023)
8511b6f  All test_prototype_transforms* should be passing now (NicolasHug, Feb 15, 2023)
481b939  Merge branch 'main' into migration_transforms_v2 (pmeier, Feb 15, 2023)
f49beb3  cherry-pick d010e82fec10422f79c69564de7ff2721d93d278 (pmeier, Feb 15, 2023)
d8ec4f1  migrate video utility transforms back to prototype (pmeier, Feb 15, 2023)
99590c0  migrate FixedSizeCrop back to prototype (pmeier, Feb 15, 2023)
8f49358  remove all mentions of Label and OneHotLabel from transforms.v2 (pmeier, Feb 15, 2023)
ec805da  Fixed code formatting and failing tests (vfdev-5, Feb 15, 2023)
86537d7  fix tests (pmeier, Feb 15, 2023)
3691c8d  Merge branch 'migration_transforms_v2' of https://github.com/pytorch/… (pmeier, Feb 15, 2023)
1f86853  remove dead code from prototype (pmeier, Feb 16, 2023)
fe00914  lint and prototype datasets (pmeier, Feb 16, 2023)
64ecaf1  fix dataset wrapper tests (pmeier, Feb 16, 2023)
14 changes: 6 additions & 8 deletions test/datasets_utils.py

@@ -584,24 +584,22 @@ def test_transforms(self, config):
 
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
-        # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
-        # branch, we screwed up the roll-out
-        from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
-        from torchvision.prototype.datapoints._datapoint import Datapoint
+        from torchvision.datapoints import wrap_dataset_for_transforms_v2
+        from torchvision.datapoints._datapoint import Datapoint
 
         try:
             with self.create_dataset(config) as (dataset, _):
                 wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
                 wrapped_sample = wrapped_dataset[0]
                 assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
-            if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
-                return
+            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
+            if str(error).startswith(msg):
+                pytest.skip(msg)
             raise error
         except RuntimeError as error:
             if "currently not supported by this wrapper" in str(error):
-                return
+                pytest.skip("Config is currently not supported by this wrapper")
             raise error
 
 
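
For reference, the wrapper under test converts a stable torchvision dataset so that its samples come back as v2 datapoints. A usage sketch under the assumption of a supported dataset class (the paths are placeholders):

```python
import torchvision
from torchvision.datapoints import wrap_dataset_for_transforms_v2

# Placeholder paths; any dataset class with a registered wrapper works the same way.
dataset = torchvision.datasets.CocoDetection(root="path/to/images", annFile="path/to/annotations.json")
wrapped = wrap_dataset_for_transforms_v2(dataset)

# Samples now contain datapoints instead of raw structures; unsupported classes
# raise TypeError("No wrapper exists for dataset class ..."), which the test
# above now turns into a pytest.skip instead of silently returning.
image, target = wrapped[0]
```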
9 changes: 5 additions & 4 deletions test/prototype_common_utils.py

@@ -12,12 +12,13 @@
 import pytest
 import torch
 import torch.testing
+import torchvision.prototype.datapoints as proto_datapoints
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision import datapoints
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
+from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
 
 __all__ = [
     "assert_close",
@@ -457,7 +458,7 @@ def fn(shape, dtype, device):
         # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
         # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
-        return datapoints.Label(data, categories=categories)
+        return proto_datapoints.Label(data, categories=categories)
 
     return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
 
@@ -481,7 +482,7 @@ def fn(shape, dtype, device):
         # since `one_hot` only supports int64
         label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
         data = one_hot(label, num_classes=num_categories).to(dtype)
-        return datapoints.OneHotLabel(data, categories=categories)
+        return proto_datapoints.OneHotLabel(data, categories=categories)
 
     return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
 
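
The loaders above stay on the prototype `Label`/`OneHotLabel` types, which this PR deliberately keeps out of `transforms.v2`. As a side note on the `one_hot` call they build on, a tiny self-contained example (int64 input is required, matching the comment in the diff):

```python
import torch
from torch.nn.functional import one_hot

label = torch.tensor([0, 2, 1], dtype=torch.int64)  # one_hot only supports int64
print(one_hot(label, num_classes=3))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])
```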
4 changes: 2 additions & 2 deletions test/prototype_transforms_dispatcher_infos.py

@@ -1,10 +1,10 @@
 import collections.abc
 
 import pytest
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from prototype_common_utils import InfoBase, TestMark
 from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 
 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
 
20 changes: 6 additions & 14 deletions test/prototype_transforms_kernel_infos.py

@@ -8,7 +8,7 @@
 import pytest
 import torch.testing
 import torchvision.ops
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
@@ -28,7 +28,7 @@
     TestMark,
 )
 from torch.utils._pytree import tree_map
-from torchvision.prototype import datapoints
+from torchvision import datapoints
 from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
 
 __all__ = ["KernelInfo", "KERNEL_INFOS"]
@@ -2383,19 +2383,18 @@ def sample_inputs_convert_dtype_video():
 
 def sample_inputs_uniform_temporal_subsample_video():
     for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
-        for temporal_dim in [-4, len(video_loader.shape) - 4]:
-            yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)
+        yield ArgsKwargs(video_loader, num_samples=2)
 
 
-def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
+def reference_uniform_temporal_subsample_video(x, num_samples):
     # Copy-pasted from
     # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
-    t = x.shape[temporal_dim]
+    t = x.shape[-4]
     assert num_samples > 0 and t > 0
     # Sample by nearest neighbor interpolation if num_samples > t.
     indices = torch.linspace(0, t - 1, num_samples)
     indices = torch.clamp(indices, 0, t - 1).long()
-    return torch.index_select(x, temporal_dim, indices)
+    return torch.index_select(x, -4, indices)
 
 
 def reference_inputs_uniform_temporal_subsample_video():
@@ -2410,12 +2409,5 @@ def reference_inputs_uniform_temporal_subsample_video():
         sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
         reference_fn=reference_uniform_temporal_subsample_video,
         reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
-        test_marks=[
-            TestMark(
-                ("TestKernels", "test_batched_vs_single"),
-                pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
-                condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0,
-            ),
-        ],
     )
 )
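
With `temporal_dim` removed, the kernel always samples along dimension -4, which also makes the batched-vs-single skip mark above unnecessary. A standalone sketch of the nearest-neighbor index selection the reference implements:

```python
import torch

video = torch.arange(8 * 3 * 2 * 2).reshape(8, 3, 2, 2)  # [T, C, H, W] with T=8

num_samples = 4
t = video.shape[-4]
indices = torch.clamp(torch.linspace(0, t - 1, num_samples), 0, t - 1).long()
print(indices)  # tensor([0, 2, 4, 7]) -- linspace(0, 7, 4) truncated by .long()

subsampled = torch.index_select(video, -4, indices)
print(subsampled.shape)  # torch.Size([4, 3, 2, 2])
```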
40 changes: 20 additions & 20 deletions test/test_prototype_datapoints.py

@@ -5,8 +5,8 @@
 
 from PIL import Image
 
-from torchvision import datasets
-from torchvision.prototype import datapoints
+from torchvision import datapoints, datasets
+from torchvision.prototype import datapoints as proto_datapoints
 
 
 @pytest.mark.parametrize(
@@ -24,38 +24,38 @@
     ],
 )
 def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
-    datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
+    datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
     assert datapoint.requires_grad is expected_requires_grad
 
 
 def test_isinstance():
     assert isinstance(
-        datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
+        proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
         torch.Tensor,
     )
 
 
 def test_wrapping_no_copy():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     assert label.data_ptr() == tensor.data_ptr()
 
 
 def test_to_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     label_to = label.to(torch.int32)
 
-    assert type(label_to) is datapoints.Label
+    assert type(label_to) is proto_datapoints.Label
     assert label_to.dtype is torch.int32
     assert label_to.categories is label.categories
 
 
 def test_to_datapoint_reference():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
 
     tensor_to = tensor.to(label)
 
@@ -65,31 +65,31 @@ def test_to_datapoint_reference():
 
 
 def test_clone_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     label_clone = label.clone()
 
-    assert type(label_clone) is datapoints.Label
+    assert type(label_clone) is proto_datapoints.Label
     assert label_clone.data_ptr() != label.data_ptr()
     assert label_clone.categories is label.categories
 
 
 def test_requires_grad__wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     assert not label.requires_grad
 
     label_requires_grad = label.requires_grad_(True)
 
-    assert type(label_requires_grad) is datapoints.Label
+    assert type(label_requires_grad) is proto_datapoints.Label
     assert label.requires_grad
     assert label_requires_grad.requires_grad
 
 
 def test_other_op_no_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     # any operation besides .to() and .clone() will do here
     output = label * 2
@@ -107,33 +107,33 @@
 )
 def test_no_tensor_output_op_no_wrapping(op):
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     output = op(label)
 
-    assert type(output) is not datapoints.Label
+    assert type(output) is not proto_datapoints.Label
 
 
 def test_inplace_op_no_wrapping():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     output = label.add_(0)
 
     assert type(output) is torch.Tensor
-    assert type(label) is datapoints.Label
+    assert type(label) is proto_datapoints.Label
 
 
 def test_wrap_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
 
     # any operation besides .to() and .clone() will do here
     output = label * 2
 
-    label_new = datapoints.Label.wrap_like(label, output)
+    label_new = proto_datapoints.Label.wrap_like(label, output)
 
-    assert type(label_new) is datapoints.Label
+    assert type(label_new) is proto_datapoints.Label
     assert label_new.data_ptr() == output.data_ptr()
     assert label_new.categories is label.categories
 
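
These renames keep the tests pinning down the `Datapoint` wrapping contract: `.to()`, `.clone()`, `requires_grad_()`, and `wrap_like` preserve the subclass and its metadata, while in-place and other ops unwrap to plain tensors. A condensed sketch of that contract, using the prototype `Label` exactly as the tests do:

```python
import torch
from torchvision.prototype import datapoints as proto_datapoints

label = proto_datapoints.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])

assert type(label.to(torch.int32)) is proto_datapoints.Label  # .to() keeps the subclass
assert type(label.clone()) is proto_datapoints.Label          # .clone() keeps the subclass
assert type(label * 2) is not proto_datapoints.Label          # other ops unwrap to Tensor

# wrap_like re-attaches the subclass and metadata to an unwrapped result.
relabeled = proto_datapoints.Label.wrap_like(label, label * 2)
assert relabeled.categories is label.categories
```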
13 changes: 7 additions & 6 deletions test/test_prototype_datasets_builtin.py

@@ -5,8 +5,8 @@
 
 import pytest
 import torch
+import torchvision.transforms.v2 as transforms
 
-import torchvision.prototype.transforms.utils
 from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
 from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
 
@@ -19,10 +19,13 @@
 from torchdata.dataloader2.graph.utils import traverse_dps
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
+from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype import datasets
+from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.transforms.v2.utils import is_simple_tensor
 
 
 def assert_samples_equal(*args, msg=None, **kwargs):
@@ -141,9 +144,7 @@ def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
         sample = next_consume(iter(dataset))
 
-        simple_tensors = {
-            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
-        }
+        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
 
         if simple_tensors and not any(
             isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
@@ -276,6 +277,6 @@ def test_sample_content(self, dataset_mock, config):
         assert "label" in sample
 
         assert isinstance(sample["image"], datapoints.Image)
-        assert isinstance(sample["label"], datapoints.Label)
+        assert isinstance(sample["label"], Label)
 
         assert sample["image"].shape == (1, 16, 16)
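
A note on the relocated helper: `is_simple_tensor` (now in `torchvision.transforms.v2.utils`) is what lets this test separate plain tensors from datapoints. A sketch of its behavior, assuming the usual `isinstance`-based check:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import is_simple_tensor

assert is_simple_tensor(torch.rand(3, 16, 16))  # plain tensor
assert not is_simple_tensor(datapoints.Image(torch.rand(3, 16, 16)))  # datapoint subclass
```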