
Concatenate torchvision.datasets.FakeData with another dataset -> cannot load it #3517


Closed
zhifengkong opened this issue Mar 6, 2021 · 5 comments


@zhifengkong

zhifengkong commented Mar 6, 2021

🐛 Bug

If you concatenate a dataset such as CIFAR10 with FakeData and iterate over the result with a DataLoader, you get this error:

  • AttributeError: 'int' object has no attribute 'numel'

To Reproduce

Steps to reproduce the behavior:

  1. cifar_dataset = torchvision.datasets.CIFAR10(...)
  2. fake_dataset = torchvision.datasets.FakeData(...)
  3. train_data = torch.utils.data.ConcatDataset([cifar_dataset, fake_dataset])
  4. train_loader = DataLoader(train_data, ...)
  5. Iterate over train_loader (for data in train_loader: ...); the error is raised.
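A minimal, self-contained sketch of the failure, assuming two hypothetical stand-in datasets so nothing has to be downloaded: IntLabelDataset plays the role of CIFAR10 (int labels) and TensorLabelDataset mimics the affected FakeData by returning a 0-dim long tensor from the same randint expression. The tensor-labelled samples come first so the first label in the batch is a tensor, which routes default collation to torch.stack. With num_workers=0 the batch should fail inside torch.stack with a TypeError; the AttributeError in the report most likely comes from the multi-worker path, which calls numel() on every label, so the sketch catches both.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for CIFAR10: labels are Python ints."""

    def __len__(self):
        return 2

    def __getitem__(self, index):
        return torch.zeros(3, 8, 8), index


class TensorLabelDataset(Dataset):
    """Hypothetical stand-in for the affected FakeData: 0-dim long tensor labels."""

    def __len__(self):
        return 2

    def __getitem__(self, index):
        # Same expression as FakeData.__getitem__ in torchvision 0.8,
        # with num_classes fixed at 10 for illustration
        target = torch.randint(0, 10, size=(1,), dtype=torch.long)[0]
        return torch.zeros(3, 8, 8), target


# Tensor-labelled samples first, so the first label in the batch is a tensor
train_data = ConcatDataset([TensorLabelDataset(), IntLabelDataset()])
train_loader = DataLoader(train_data, batch_size=4, shuffle=False, num_workers=0)

failed = False
try:
    for images, labels in train_loader:
        pass
except (TypeError, AttributeError) as exc:
    # Labels [tensor, tensor, int, int] cannot be stacked into one tensor
    failed = True
    print(type(exc).__name__, exc)
```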

Additional context

This happens because the labels in CIFAR10 are ints, while the labels in FakeData are tensors. When samples from both datasets are collated into one batch, the batch labels look like [0, 1, 2, 3, tensor(0), 3, 4, 5, 6, tensor(2), ...].

I can work around this by passing target_transform=int when I create fake_dataset. However, the cause is very hard to debug. I think the default target type in the FakeData source code should be int instead of a long tensor.

The relevant line is in FakeData.__getitem__ (https://pytorch.org/vision/0.8/_modules/torchvision/datasets/fakedata.html#FakeData):

target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]

Indexing the one-element result yields a 0-dim long tensor, not an int. It should be an int.
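A quick check of what that expression returns, assuming num_classes=10 for illustration:

```python
import torch

num_classes = 10  # illustrative value
target = torch.randint(0, num_classes, size=(1,), dtype=torch.long)[0]

print(type(target))   # <class 'torch.Tensor'>
print(target.dim())   # 0, i.e. a scalar (0-dim) tensor
print(target.dtype)   # torch.int64
```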

cc @pmeier @fmassa @vfdev-5

@zhifengkong zhifengkong changed the title torchvision.datasets.FakeData target type error torchvision.datasets.FakeData concatenated with another dataset -> cannot load it Mar 6, 2021
@zhifengkong zhifengkong changed the title torchvision.datasets.FakeData concatenated with another dataset -> cannot load it Concatenate torchvision.datasets.FakeData with another dataset -> cannot load it Mar 6, 2021
@vfdev-5 vfdev-5 transferred this issue from pytorch/pytorch Mar 6, 2021
@fmassa
Member

fmassa commented Mar 9, 2021

Yes, the target should be an int. It was an int in the first versions of torchvision, but since PyTorch introduced scalar (0-dim) tensors, that snippet returns a tensor instead.

I'm happy to accept a PR that adds a .item() call to the aforementioned line to fix the issue.
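A sketch of the proposed change, using a hypothetical helper fake_target that mirrors the __getitem__ line with .item() appended. The result is a plain Python int, so it collates the same way as CIFAR10's labels (default_collate is imported from torch.utils.data.dataloader, where PyTorch exposes it).

```python
import torch
from torch.utils.data.dataloader import default_collate


def fake_target(num_classes: int) -> int:
    # Hypothetical helper: the __getitem__ expression with .item() appended,
    # which converts the 0-dim long tensor into a plain Python int
    return torch.randint(0, num_classes, size=(1,), dtype=torch.long)[0].item()


target = fake_target(10)
print(type(target))  # <class 'int'>

# Mixed with CIFAR10-style int labels, it now collates into one tensor
labels = default_collate([3, 7, target])
print(labels.dtype, labels.shape)
```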

@fmassa fmassa added the bug label Mar 9, 2021
zhifengkong added a commit to zhifengkong/vision that referenced this issue Mar 12, 2021
target -> target.item() so it's an int instead of a long tensor
@avijit9
Contributor

avijit9 commented Mar 18, 2021

Is anybody working on this? Otherwise, I can send a PR.

@pmeier
Collaborator

pmeier commented Mar 18, 2021

@avijit9 Go ahead!

@avijit9
Contributor

avijit9 commented Mar 22, 2021

@pmeier Shouldn't this issue be closed?

@pmeier
Collaborator

pmeier commented Mar 22, 2021

Indeed it should. If the PR contains a certain keyword together with the issue number, GitHub will close the issue automatically when the PR is merged.

You used a keyword in your PR that is not recognized by GitHub:

solves #3517

@pmeier pmeier closed this as completed Mar 22, 2021

5 participants