Skip to content

call dataset wrapper with idx and sample #7235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions torchvision/prototype/datapoints/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __getitem__(self, idx):
# of this class
sample = self._dataset[idx]

sample = self._wrapper(sample)
sample = self._wrapper(idx, sample)

# Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
# or joint (`transforms`), we can access the full functionality through `transforms`
Expand Down Expand Up @@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):


def classification_wrapper_factory(dataset):
    """Build the sample wrapper for plain classification datasets.

    Classification samples (image, label) need no datapoint wrapping, so the
    returned callable is a pass-through. It still takes ``(idx, sample)`` to
    match the interface shared by all wrapper factories; ``idx`` is ignored.
    """
    return lambda idx, sample: sample


for dataset_cls in [
Expand All @@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):


def segmentation_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, mask = sample
return image, pil_image_to_mask(mask)

Expand All @@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
)

def wrapper(sample):
def wrapper(idx, sample):
video, audio, label = sample

video = datapoints.Video(video)
Expand Down Expand Up @@ -201,14 +204,17 @@ def segmentation_to_mask(segmentation, *, spatial_size):
)
return torch.from_numpy(mask.decode(segmentation))

def wrapper(sample):
def wrapper(idx, sample):
image_id = dataset.ids[idx]

image, target = sample

if not target:
return image, dict(image_id=image_id)

batched_target = list_of_dicts_to_dict_of_lists(target)

image_ids = batched_target.pop("image_id")
image_id = batched_target["image_id"] = image_ids.pop()
assert all(other_image_id == image_id for other_image_id in image_ids)
batched_target["image_id"] = image_id

spatial_size = tuple(F.get_spatial_size(image))
batched_target["boxes"] = datapoints.BoundingBox(
Expand Down Expand Up @@ -259,7 +265,7 @@ def wrapper(sample):

@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
Expand Down Expand Up @@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

target = wrap_target_by_type(
Expand All @@ -318,7 +324,7 @@ def wrapper(sample):

@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

if target is not None:
Expand All @@ -336,7 +342,7 @@ def wrapper(sample):

@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

if target is not None:
Expand Down Expand Up @@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
labels.append(label)
return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

target = wrap_target_by_type(
Expand All @@ -390,7 +396,7 @@ def wrapper(sample):

@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset):
def wrapper(sample):
def wrapper(idx, sample):
image, target = sample

if target is not None:
Expand Down