Skip to content

Commit a1b674f

Browse files
NicolasHugvfdev-5
authored andcommitted
[fbsync] Promote prototype transforms to beta status (#7261)
Summary: Reviewed By: vmoens Differential Revision: D44416605 fbshipit-source-id: 0fd313c8279f6be9ed488a19e2d053f33c2f77a5 Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: vfdev-5 <[email protected]>
1 parent 1562a89 commit a1b674f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1683
-1669
lines changed

test/datasets_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -584,24 +584,22 @@ def test_transforms(self, config):
584584

585585
@test_all_configs
586586
def test_transforms_v2_wrapper(self, config):
587-
# Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
588-
# to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
589-
# branch, we screwed up the roll-out
590-
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
591-
from torchvision.prototype.datapoints._datapoint import Datapoint
587+
from torchvision.datapoints import wrap_dataset_for_transforms_v2
588+
from torchvision.datapoints._datapoint import Datapoint
592589

593590
try:
594591
with self.create_dataset(config) as (dataset, _):
595592
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
596593
wrapped_sample = wrapped_dataset[0]
597594
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
598595
except TypeError as error:
599-
if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
600-
return
596+
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
597+
if str(error).startswith(msg):
598+
pytest.skip(msg)
601599
raise error
602600
except RuntimeError as error:
603601
if "currently not supported by this wrapper" in str(error):
604-
return
602+
pytest.skip("Config is currently not supported by this wrapper")
605603
raise error
606604

607605

test/prototype_common_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
import pytest
1313
import torch
1414
import torch.testing
15+
import torchvision.prototype.datapoints as proto_datapoints
1516
from datasets_utils import combinations_grid
1617
from torch.nn.functional import one_hot
1718
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
18-
from torchvision.prototype import datapoints
19-
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
19+
from torchvision import datapoints
2020
from torchvision.transforms.functional_tensor import _max_value as get_max_value
21+
from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
2122

2223
__all__ = [
2324
"assert_close",
@@ -457,7 +458,7 @@ def fn(shape, dtype, device):
457458
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
458459
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
459460
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
460-
return datapoints.Label(data, categories=categories)
461+
return proto_datapoints.Label(data, categories=categories)
461462

462463
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
463464

@@ -481,7 +482,7 @@ def fn(shape, dtype, device):
481482
# since `one_hot` only supports int64
482483
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
483484
data = one_hot(label, num_classes=num_categories).to(dtype)
484-
return datapoints.OneHotLabel(data, categories=categories)
485+
return proto_datapoints.OneHotLabel(data, categories=categories)
485486

486487
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
487488

test/prototype_transforms_dispatcher_infos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import collections.abc
22

33
import pytest
4-
import torchvision.prototype.transforms.functional as F
4+
import torchvision.transforms.v2.functional as F
55
from prototype_common_utils import InfoBase, TestMark
66
from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
7-
from torchvision.prototype import datapoints
7+
from torchvision import datapoints
88

99
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
1010

test/prototype_transforms_kernel_infos.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
import torch.testing
1010
import torchvision.ops
11-
import torchvision.prototype.transforms.functional as F
11+
import torchvision.transforms.v2.functional as F
1212
from datasets_utils import combinations_grid
1313
from prototype_common_utils import (
1414
ArgsKwargs,
@@ -28,7 +28,7 @@
2828
TestMark,
2929
)
3030
from torch.utils._pytree import tree_map
31-
from torchvision.prototype import datapoints
31+
from torchvision import datapoints
3232
from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
3333

3434
__all__ = ["KernelInfo", "KERNEL_INFOS"]
@@ -2383,19 +2383,18 @@ def sample_inputs_convert_dtype_video():
23832383

23842384
def sample_inputs_uniform_temporal_subsample_video():
23852385
for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
2386-
for temporal_dim in [-4, len(video_loader.shape) - 4]:
2387-
yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)
2386+
yield ArgsKwargs(video_loader, num_samples=2)
23882387

23892388

2390-
def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
2389+
def reference_uniform_temporal_subsample_video(x, num_samples):
23912390
# Copy-pasted from
23922391
# https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
2393-
t = x.shape[temporal_dim]
2392+
t = x.shape[-4]
23942393
assert num_samples > 0 and t > 0
23952394
# Sample by nearest neighbor interpolation if num_samples > t.
23962395
indices = torch.linspace(0, t - 1, num_samples)
23972396
indices = torch.clamp(indices, 0, t - 1).long()
2398-
return torch.index_select(x, temporal_dim, indices)
2397+
return torch.index_select(x, -4, indices)
23992398

24002399

24012400
def reference_inputs_uniform_temporal_subsample_video():
@@ -2410,12 +2409,5 @@ def reference_inputs_uniform_temporal_subsample_video():
24102409
sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
24112410
reference_fn=reference_uniform_temporal_subsample_video,
24122411
reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
2413-
test_marks=[
2414-
TestMark(
2415-
("TestKernels", "test_batched_vs_single"),
2416-
pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
2417-
condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0,
2418-
),
2419-
],
24202412
)
24212413
)

test/test_prototype_datapoints.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from PIL import Image
77

8-
from torchvision import datasets
9-
from torchvision.prototype import datapoints
8+
from torchvision import datapoints, datasets
9+
from torchvision.prototype import datapoints as proto_datapoints
1010

1111

1212
@pytest.mark.parametrize(
@@ -24,38 +24,38 @@
2424
],
2525
)
2626
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
27-
datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
27+
datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
2828
assert datapoint.requires_grad is expected_requires_grad
2929

3030

3131
def test_isinstance():
3232
assert isinstance(
33-
datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
33+
proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
3434
torch.Tensor,
3535
)
3636

3737

3838
def test_wrapping_no_copy():
3939
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
40-
label = datapoints.Label(tensor, categories=["foo", "bar"])
40+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
4141

4242
assert label.data_ptr() == tensor.data_ptr()
4343

4444

4545
def test_to_wrapping():
4646
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
47-
label = datapoints.Label(tensor, categories=["foo", "bar"])
47+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
4848

4949
label_to = label.to(torch.int32)
5050

51-
assert type(label_to) is datapoints.Label
51+
assert type(label_to) is proto_datapoints.Label
5252
assert label_to.dtype is torch.int32
5353
assert label_to.categories is label.categories
5454

5555

5656
def test_to_datapoint_reference():
5757
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
58-
label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
58+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
5959

6060
tensor_to = tensor.to(label)
6161

@@ -65,31 +65,31 @@ def test_to_datapoint_reference():
6565

6666
def test_clone_wrapping():
6767
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
68-
label = datapoints.Label(tensor, categories=["foo", "bar"])
68+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
6969

7070
label_clone = label.clone()
7171

72-
assert type(label_clone) is datapoints.Label
72+
assert type(label_clone) is proto_datapoints.Label
7373
assert label_clone.data_ptr() != label.data_ptr()
7474
assert label_clone.categories is label.categories
7575

7676

7777
def test_requires_grad__wrapping():
7878
tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
79-
label = datapoints.Label(tensor, categories=["foo", "bar"])
79+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
8080

8181
assert not label.requires_grad
8282

8383
label_requires_grad = label.requires_grad_(True)
8484

85-
assert type(label_requires_grad) is datapoints.Label
85+
assert type(label_requires_grad) is proto_datapoints.Label
8686
assert label.requires_grad
8787
assert label_requires_grad.requires_grad
8888

8989

9090
def test_other_op_no_wrapping():
9191
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
92-
label = datapoints.Label(tensor, categories=["foo", "bar"])
92+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
9393

9494
# any operation besides .to() and .clone() will do here
9595
output = label * 2
@@ -107,33 +107,33 @@ def test_other_op_no_wrapping():
107107
)
108108
def test_no_tensor_output_op_no_wrapping(op):
109109
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
110-
label = datapoints.Label(tensor, categories=["foo", "bar"])
110+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
111111

112112
output = op(label)
113113

114-
assert type(output) is not datapoints.Label
114+
assert type(output) is not proto_datapoints.Label
115115

116116

117117
def test_inplace_op_no_wrapping():
118118
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
119-
label = datapoints.Label(tensor, categories=["foo", "bar"])
119+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
120120

121121
output = label.add_(0)
122122

123123
assert type(output) is torch.Tensor
124-
assert type(label) is datapoints.Label
124+
assert type(label) is proto_datapoints.Label
125125

126126

127127
def test_wrap_like():
128128
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
129-
label = datapoints.Label(tensor, categories=["foo", "bar"])
129+
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
130130

131131
# any operation besides .to() and .clone() will do here
132132
output = label * 2
133133

134-
label_new = datapoints.Label.wrap_like(label, output)
134+
label_new = proto_datapoints.Label.wrap_like(label, output)
135135

136-
assert type(label_new) is datapoints.Label
136+
assert type(label_new) is proto_datapoints.Label
137137
assert label_new.data_ptr() == output.data_ptr()
138138
assert label_new.categories is label.categories
139139

test/test_prototype_datasets_builtin.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import pytest
77
import torch
8+
import torchvision.transforms.v2 as transforms
89

9-
import torchvision.prototype.transforms.utils
1010
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
1111
from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
1212

@@ -19,10 +19,13 @@
1919
from torchdata.dataloader2.graph.utils import traverse_dps
2020
from torchdata.datapipes.iter import ShardingFilter, Shuffler
2121
from torchdata.datapipes.utils import StreamWrapper
22+
from torchvision import datapoints
2223
from torchvision._utils import sequence_to_str
23-
from torchvision.prototype import datapoints, datasets, transforms
24+
from torchvision.prototype import datasets
25+
from torchvision.prototype.datapoints import Label
2426
from torchvision.prototype.datasets.utils import EncodedImage
2527
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
28+
from torchvision.transforms.v2.utils import is_simple_tensor
2629

2730

2831
def assert_samples_equal(*args, msg=None, **kwargs):
@@ -141,9 +144,7 @@ def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
141144
dataset, _ = dataset_mock.load(config)
142145
sample = next_consume(iter(dataset))
143146

144-
simple_tensors = {
145-
key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
146-
}
147+
simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
147148

148149
if simple_tensors and not any(
149150
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
@@ -276,6 +277,6 @@ def test_sample_content(self, dataset_mock, config):
276277
assert "label" in sample
277278

278279
assert isinstance(sample["image"], datapoints.Image)
279-
assert isinstance(sample["label"], datapoints.Label)
280+
assert isinstance(sample["label"], Label)
280281

281282
assert sample["image"].shape == (1, 16, 16)

0 commit comments

Comments
 (0)