Skip to content

only return small set of targets by default from dataset wrapper #7488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions gallery/plot_transforms_v2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,16 @@ def load_example_coco_detection_dataset(**kwargs):
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
# also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
# ``torchvision.datapoints``.
# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
# items down the line, but you can pass the ``target_type`` parameter for fine-grained control.

dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

sample = dataset[0]
image, target = sample
print(type(image))
print(type(target), list(target.keys()))
print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
print(type(target["boxes"]), type(target["labels"]))

########################################################################################################################
# As baseline, let's have a look at a sample without transformations:
Expand Down
18 changes: 15 additions & 3 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,21 @@ def test_transforms_v2_wrapper(self, config):

try:
with self.create_dataset(config) as (dataset, _):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
torchvision.datasets.VOCDetection,
torchvision.datasets.Kitti,
torchvision.datasets.WIDERFace,
}:
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
continue

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
wrapped_sample = wrapped_dataset[0]

assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
Expand Down
4 changes: 3 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,8 @@ def _create_annotations(self, image_ids, num_annotations_per_image):
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
category_id=int(torch.randint(91, ())),
area=float(torch.rand(1)),
iscrowd=int(torch.randint(2, size=(1,))),
Comment on lines +774 to +775
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the full sample now for target_keys="all".

)
)
annotion_id += 1
Expand Down Expand Up @@ -3336,7 +3338,7 @@ def test_subclass(self, mocker):
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
)

class MyFakeData(datasets.FakeData):
Expand Down
Loading