Skip to content

Commit 6fb095e

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] remove datapoints compatibility for prototype datasets (#7154)
Reviewed By: vmoens
Differential Revision: D44416260
fbshipit-source-id: 9437edffc8a7ccf08f381c5147d9a8f3e18530a3
1 parent 049e7e2 commit 6fb095e

File tree

7 files changed: 25 additions and 36 deletions

test/test_prototype_datasets_builtin.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from torchdata.datapipes.utils import StreamWrapper
2222
from torchvision._utils import sequence_to_str
2323
from torchvision.prototype import datapoints, datasets, transforms
24+
from torchvision.prototype.datasets.utils import EncodedImage
2425
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
2526

2627

@@ -136,18 +137,21 @@ def make_msg_and_close(head):
136137
raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
137138

138139
@parametrize_dataset_mocks(DATASET_MOCKS)
139-
def test_no_simple_tensors(self, dataset_mock, config):
140+
def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
140141
dataset, _ = dataset_mock.load(config)
142+
sample = next_consume(iter(dataset))
141143

142144
simple_tensors = {
143-
key
144-
for key, value in next_consume(iter(dataset)).items()
145-
if torchvision.prototype.transforms.utils.is_simple_tensor(value)
145+
key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
146146
}
147-
if simple_tensors:
147+
148+
if simple_tensors and not any(
149+
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
150+
):
148151
raise AssertionError(
149152
f"The values of key(s) "
150-
f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
153+
f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
154+
f"but didn't find any (encoded) image or video."
151155
)
152156

153157
@parametrize_dataset_mocks(DATASET_MOCKS)

torchvision/prototype/datapoints/_datapoint.py

Lines changed: 1 addition & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -29,26 +29,9 @@ def _to_tensor(
2929
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
3030
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
3131

32-
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
33-
# a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
34-
# interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
35-
# one public again.
36-
def __new__(
37-
cls,
38-
data: Any,
39-
dtype: Optional[torch.dtype] = None,
40-
device: Optional[Union[torch.device, str, int]] = None,
41-
requires_grad: Optional[bool] = None,
42-
) -> Datapoint:
43-
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
44-
return tensor.as_subclass(Datapoint)
45-
4632
@classmethod
4733
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
48-
# FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
49-
# this method should be made abstract
50-
# raise NotImplementedError
51-
return tensor.as_subclass(cls)
34+
raise NotImplementedError
5235

5336
_NO_WRAPPING_EXCEPTIONS = {
5437
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,10 @@
33
from typing import Any, BinaryIO, Dict, List, Tuple, Union
44

55
import numpy as np
6+
7+
import torch
68
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
79
from torchvision.prototype.datapoints import BoundingBox, Label
8-
from torchvision.prototype.datapoints._datapoint import Datapoint
910
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
1011
from torchvision.prototype.datasets.utils._internal import (
1112
hint_sharding,
@@ -115,7 +116,7 @@ def _prepare_sample(
115116
format="xyxy",
116117
spatial_size=image.spatial_size,
117118
),
118-
contour=Datapoint(ann["obj_contour"].T),
119+
contour=torch.as_tensor(ann["obj_contour"].T),
119120
)
120121

121122
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,9 @@
22
import pathlib
33
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
44

5+
import torch
56
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
67
from torchvision.prototype.datapoints import BoundingBox, Label
7-
from torchvision.prototype.datapoints._datapoint import Datapoint
88
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
99
from torchvision.prototype.datasets.utils._internal import (
1010
getitem,
@@ -149,7 +149,7 @@ def _prepare_sample(
149149
spatial_size=image.spatial_size,
150150
),
151151
landmarks={
152-
landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
152+
landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
153153
for landmark in {key[:-2] for key in landmarks.keys()}
154154
},
155155
)

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
UnBatcher,
1616
)
1717
from torchvision.prototype.datapoints import BoundingBox, Label, Mask
18-
from torchvision.prototype.datapoints._datapoint import Datapoint
1918
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
2019
from torchvision.prototype.datasets.utils._internal import (
2120
getitem,
@@ -124,8 +123,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
124123
]
125124
)
126125
),
127-
areas=Datapoint([ann["area"] for ann in anns]),
128-
crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool),
126+
areas=torch.as_tensor([ann["area"] for ann in anns]),
127+
crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
129128
bounding_boxes=BoundingBox(
130129
[ann["bbox"] for ann in anns],
131130
format="xywh",

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import pathlib
44
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union
55

6+
import torch
67
from torchdata.datapipes.iter import (
78
CSVDictParser,
89
CSVParser,
@@ -15,7 +16,6 @@
1516
)
1617
from torchdata.datapipes.map import IterToMapConverter
1718
from torchvision.prototype.datapoints import BoundingBox, Label
18-
from torchvision.prototype.datapoints._datapoint import Datapoint
1919
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
2020
from torchvision.prototype.datasets.utils._internal import (
2121
getitem,
@@ -162,7 +162,7 @@ def _2010_prepare_ann(
162162
format="xyxy",
163163
spatial_size=spatial_size,
164164
),
165-
segmentation=Datapoint(content["seg"]),
165+
segmentation=torch.as_tensor(content["seg"]),
166166
)
167167

168168
def _prepare_sample(

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,8 @@
33
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
44

55
import numpy as np
6+
import torch
67
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
7-
from torchvision.prototype.datapoints._datapoint import Datapoint
88
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
99
from torchvision.prototype.datasets.utils._internal import (
1010
getitem,
@@ -92,8 +92,10 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
9292
image=EncodedImage.from_file(image_buffer),
9393
ann_path=ann_path,
9494
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
95-
boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
96-
segmentation=Datapoint(anns["Segmentation"].item()),
95+
boundaries=torch.as_tensor(
96+
np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
97+
),
98+
segmentation=torch.as_tensor(anns["Segmentation"].item()),
9799
)
98100

99101
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

0 commit comments

Comments
 (0)