@@ -74,7 +74,7 @@ def __getitem__(self, idx):
         # of this class
         sample = self._dataset[idx]

-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)

         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):


 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper


 for dataset_cls in [
@@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):


 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)

@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )

-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample

         video = datapoints.Video(video)
@@ -201,14 +204,17 @@ def segmentation_to_mask(segmentation, *, spatial_size):
         )
         return torch.from_numpy(mask.decode(segmentation))

-    def wrapper(sample):
+    def wrapper(idx, sample):
+        image_id = dataset.ids[idx]
+
         image, target = sample

+        if not target:
+            return image, dict(image_id=image_id)
+
         batched_target = list_of_dicts_to_dict_of_lists(target)

-        image_ids = batched_target.pop("image_id")
-        image_id = batched_target["image_id"] = image_ids.pop()
-        assert all(other_image_id == image_id for other_image_id in image_ids)
+        batched_target["image_id"] = image_id

         spatial_size = tuple(F.get_spatial_size(image))
         batched_target["boxes"] = datapoints.BoundingBox(
@@ -259,7 +265,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
@@ -336,7 +342,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
@@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
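The thread running through every hunk is that wrapper callables now receive the sample index in addition to the raw sample (`wrapper(idx, sample)` instead of `wrapper(sample)`), so factories can perform index-keyed lookups such as `dataset.ids[idx]` in the COCO hunk, even when the raw target is empty. A minimal sketch of the call-site mechanics, paraphrased from the `__getitem__` hunk above (the class below is a simplified stand-in, not the real wrapper class, which also handles transforms):

```python
# Sketch only: a stripped-down stand-in for the dataset wrapper touched in
# the first hunk; transforms handling from the real class is omitted.
class IndexAwareWrapper:
    def __init__(self, dataset, wrapper):
        self._dataset = dataset
        self._wrapper = wrapper  # callable produced by a *_wrapper_factory

    def __getitem__(self, idx):
        sample = self._dataset[idx]
        # The wrapper now sees the index too, so a factory can attach
        # per-index metadata (e.g. an image id) even for empty targets.
        return self._wrapper(idx, sample)
```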