Hi @seba-1511, I was trying to generate a taskset for the FGVC Aircraft dataset using RandomClassRotation. Here's a minimal reproducible example:
def fgvcaircraft_tasksets(train_ways=5, train_samples=10, test_ways=5, test_samples=10, root='~/data', device=None, **kwargs):
    """Build the FGVC Aircraft training MetaDataset and its list of task transforms.

    Returns:
        A ``(train_dataset, train_transforms)`` pair: the training split wrapped in
        ``l2l.data.MetaDataset`` and the task-transform list to hand to a TaskDataset.

    NOTE(review): only the training split is built. test_ways, test_samples, device
    and kwargs are not used in the visible code. train_ways / train_samples are
    presumably consumed by the elided transforms on the list line. Confirm.
    """
    # Resize every image to 84x84 with LANCZOS resampling, then convert it to a tensor.
    data_transform = tv.transforms.Compose([tv.transforms.Resize((84, 84), interpolation=LANCZOS), tv.transforms.ToTensor()])
    train_dataset = l2l.vision.datasets.FGVCAircraft(root=root, transform=data_transform, mode='train', download=True)
    # MetaDataset wrapper so task transforms can be applied per class.
    train_dataset = l2l.data.MetaDataset(train_dataset)
    # NOTE(review): "Nways ... LoadData ..." is an elided placeholder from the issue, not valid
    # Python. The full transform list must be filled in before this line can run.
    # RandomClassRotation (last entry) is the transform that raises the ValueError shown below.
    train_transforms = [Nways ... LoadData ... ConsecutiveLabels(train_dataset), RandomClassRotation(train_dataset, [0, 90, 180, 270])]
    return (train_dataset), (train_transforms)
def fgvcaircraft_benchmark(train_ways=5, train_samples=10, test_ways=5, test_samples=10, num_tasks_train=20000, num_tasks_test=600, root='~/data'):
    """Wrap the FGVC Aircraft training tasks in a ``BenchmarkTasksets``.

    Args:
        train_ways, train_samples, test_ways, test_samples, root: forwarded to
            ``fgvcaircraft_tasksets`` (previously they were silently ignored).
        num_tasks_train: number of tasks in the training TaskDataset.
        num_tasks_test: currently unused, because no test taskset is built.

    Returns:
        ``BenchmarkTasksets(train, validation, test)`` where only ``train`` is
        populated; ``validation`` and ``test`` are ``None``.
    """
    train_dataset, train_transforms = fgvcaircraft_tasksets(
        train_ways=train_ways,
        train_samples=train_samples,
        test_ways=test_ways,
        test_samples=test_samples,
        root=root,
    )
    train_tasks = l2l.data.TaskDataset(
        dataset=train_dataset,
        task_transforms=train_transforms,
        num_tasks=num_tasks_train,
    )
    # The original referenced undefined `validation_tasks` / `test_tasks` (NameError).
    # This repro only builds the training split, so the other slots are left empty.
    # NOTE(review): assumes BenchmarkTasksets is learn2learn's plain namedtuple, which accepts None.
    return BenchmarkTasksets(train_tasks, None, None)
tasksets = fgvcaircraft_benchmark()
# Draw one training task; per the report, this is where the ValueError below is raised.
train_taskset = tasksets.train
X, y = train_taskset.sample()
This results in the following stack trace:
<ipython-input-3-9196a1451978> in <lambda>(x)
26 ])
27 rotation = rotations[c]
---> 28 data_description.transforms.append(lambda x: (rotation(x[0]), x[1]))
29 return task_description
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
58 def __call__(self, img):
59 for t in self.transforms:
---> 60 img = t(img)
61 return img
62
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py in forward(self, img)
1284 angle = self.get_params(self.degrees)
1285
-> 1286 return F.rotate(img, angle, self.resample, self.expand, self.center, fill)
1287
1288 def __repr__(self):
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, interpolation, expand, center, fill, resample)
988 if not isinstance(img, torch.Tensor):
989 pil_interpolation = pil_modes_mapping[interpolation]
--> 990 return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
991
992 center_f = [0.0, 0.0]
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional_pil.py in rotate(img, angle, interpolation, expand, center, fill)
276 raise TypeError("img should be PIL Image. Got {}".format(type(img)))
277
--> 278 opts = _parse_fill(fill, img)
279 return img.rotate(angle, interpolation, expand, center, **opts)
280
/usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional_pil.py in _parse_fill(fill, img, name)
254 msg = ("The number of elements in 'fill' does not match the number of "
255 "bands of the image ({} != {})")
--> 256 raise ValueError(msg.format(len(fill), num_bands))
257
258 fill = tuple(fill)
ValueError: The number of elements in 'fill' does not match the number of bands of the image (1 != 3)
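As the error suggests, the problem lies in the `fill` argument of the RandomClassRotation task transform (https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/transforms.py#L14). The transform passes a one-element fill, `(0, )`, but the images being rotated have 3 bands, so torchvision raises the ValueError above.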
- Should the line below be:
transforms.RandomRotation((rot, rot), fill=0),
# by default fill is 0, so we can even remove that
instead of this:
transforms.RandomRotation((rot, rot), fill=(0, )),
# https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/transforms.py#L48
- In that case, is the try/except block (https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/transforms.py#L45) still needed?
Thanks, and do correct me if I am wrong.
Hi @seba-1511, I was trying to generate a taskset for the
FGVC Aircraft dataset using RandomClassRotation. Here's the minimum reproducible code. This results in the following stacktrace:
As the error mentions, there's something with the
fill in the RandomClassRotation class transform https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/transforms.py#L14. Should the line use fill=0 instead of this:
Thanks, and do correct me if I am wrong.