remove datapoints compatibility for prototype datasets #7154

Merged: 4 commits, Feb 10, 2023
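
Summary: the builtin prototype datasets no longer wrap generic annotation data (contours, landmarks, areas, segmentations) in the private base Datapoint class; they return plain tensors instead, and Datapoint.wrap_like becomes abstract. A minimal sketch of the pattern applied throughout the diffs below, using made-up annotation data:

    import numpy as np
    import torch

    ann = {"obj_contour": np.zeros((2, 10))}  # hypothetical annotation dict

    # before: contour = Datapoint(ann["obj_contour"].T)
    # after: a plain tensor; torch.as_tensor avoids a copy where possible
    contour = torch.as_tensor(ann["obj_contour"].T)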
16 changes: 10 additions & 6 deletions test/test_prototype_datasets_builtin.py
@@ -21,6 +21,7 @@
 from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -136,18 +137,21 @@ def make_msg_and_close(head):
             raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
+        sample = next_consume(iter(dataset))

         simple_tensors = {
-            key
-            for key, value in next_consume(iter(dataset)).items()
-            if torchvision.prototype.transforms.utils.is_simple_tensor(value)
+            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
         }
-        if simple_tensors:
+
+        if simple_tensors and not any(
+            isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
+        ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"but didn't find any (encoded) image or video."
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
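
In effect, the reworked test only flags plain tensors when the sample contains no (encoded) image or video to take the image role. A standalone sketch of that predicate, assuming the sample is a plain dict (hypothetical helper, not part of the test suite):

    import torch

    def has_unaccompanied_simple_tensors(sample, media_types):
        # "simple" means a plain torch.Tensor, i.e. not a Datapoint subclass
        simple = {key for key, value in sample.items() if type(value) is torch.Tensor}
        # a simple tensor is acceptable if an Image, Video, or EncodedImage
        # is present to be picked up as the image by the transforms
        accompanied = any(isinstance(value, media_types) for value in sample.values())
        return bool(simple) and not accompanied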
19 changes: 1 addition & 18 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -29,26 +29,9 @@ def _to_tensor(
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
-    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
-    # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
-    # one public again.
-    def __new__(
-        cls,
-        data: Any,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-    ) -> Datapoint:
-        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return tensor.as_subclass(Datapoint)
-
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
-        # this method should be made abstract
-        # raise NotImplementedError
-        return tensor.as_subclass(cls)
+        raise NotImplementedError

     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
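
For context, the removed __new__ existed only so datasets could wrap non-image data in a no-op base Datapoint, because the prototype transforms treat any plain tensor as an image. A rough sketch of how that distinction is drawn (assumed logic; the real helper lives in torchvision.prototype.transforms.utils):

    import torch
    from torchvision.prototype.datapoints._datapoint import Datapoint

    def is_simple_tensor(inpt):
        # plain tensors are "simple" and get treated as images by the
        # transforms; Datapoint subclasses (Image, Label, BoundingBox, ...)
        # carry their own type and are dispatched accordingly
        return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)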
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -3,9 +3,10 @@
 from typing import Any, BinaryIO, Dict, List, Tuple, Union

 import numpy as np
+
+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
@@ -115,7 +116,7 @@ def _prepare_sample(
                 format="xyxy",
                 spatial_size=image.spatial_size,
             ),
-            contour=Datapoint(ann["obj_contour"].T),
+            contour=torch.as_tensor(ann["obj_contour"].T),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -2,9 +2,9 @@
 import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union

+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -149,7 +149,7 @@ def _prepare_sample(
                 spatial_size=image.spatial_size,
             ),
             landmarks={
-                landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                 for landmark in {key[:-2] for key in landmarks.keys()}
             },
         )
5 changes: 2 additions & 3 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -15,7 +15,6 @@
     UnBatcher,
 )
 from torchvision.prototype.datapoints import BoundingBox, Label, Mask
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -124,8 +123,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                     ]
                 )
             ),
-            areas=Datapoint([ann["area"] for ann in anns]),
-            crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=torch.as_tensor([ann["area"] for ann in anns]),
+            crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/cub200.py
@@ -3,6 +3,7 @@
 import pathlib
 from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union

+import torch
 from torchdata.datapipes.iter import (
     CSVDictParser,
     CSVParser,
@@ -15,7 +16,6 @@
 )
 from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -162,7 +162,7 @@ def _2010_prepare_ann(
                 format="xyxy",
                 spatial_size=spatial_size,
             ),
-            segmentation=Datapoint(content["seg"]),
+            segmentation=torch.as_tensor(content["seg"]),
         )

     def _prepare_sample(
8 changes: 5 additions & 3 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -3,8 +3,8 @@
 from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union

 import numpy as np
+import torch
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -92,8 +92,10 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
             # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-            boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
-            segmentation=Datapoint(anns["Segmentation"].item()),
+            boundaries=torch.as_tensor(
+                np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
+            ),
+            segmentation=torch.as_tensor(anns["Segmentation"].item()),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
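
Since PyTorch cannot consume SciPy's sparse CSC matrices directly, each boundary map is densified before the tensor conversion. A self-contained sketch of that step with stand-in data:

    import numpy as np
    import torch
    from scipy.sparse import csc_matrix

    # stand-in for one boundary map as stored in the .mat annotation file
    raw_boundary = csc_matrix(np.eye(4, dtype=np.uint8))

    # densify with .toarray(), then convert; torch.as_tensor reuses the
    # freshly allocated dense array instead of copying it again
    boundary = torch.as_tensor(raw_boundary.toarray())
    assert boundary.shape == (4, 4)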