Skip to content

Dimension out of range (expected to be in range of [-1, 0] on transforms.py boxes.unbind(1) #2198

Closed
@sarmientoj24

Description

@sarmientoj24

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Create Dataset object
class OwnDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotation, transforms=None):
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotation)
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        # Own coco file
        coco = self.coco
        # Image ID
        img_id = self.ids[index]
        # List: get annotation id from coco
        ann_ids = coco.getAnnIds(imgIds=img_id)
        # Dictionary: target coco_annotation file for an image
        coco_annotation = coco.loadAnns(ann_ids)
        # path for input image
        path = coco.loadImgs(img_id)[0]['file_name']
        # open the input image
        img = Image.open(os.path.join(self.root, path))

        # number of objects in the image
        num_objs = len(coco_annotation)

        # Bounding boxes for objects
        # In coco format, bbox = [xmin, ymin, width, height]
        # In pytorch, the input should be [xmin, ymin, xmax, ymax]
        boxes = []
        for i in range(num_objs):
            xmin = coco_annotation[i]['bbox'][0]
            ymin = coco_annotation[i]['bbox'][1]
            xmax = xmin + coco_annotation[i]['bbox'][2]
            ymax = ymin + coco_annotation[i]['bbox'][3]
            boxes.append([xmin, ymin, xmax, ymax])
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Labels (In my case, I only one class: target class or background)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        # Tensorise img_id
        img_id = torch.tensor([img_id])
        # Size of bbox (Rectangular)
        areas = []
        for i in range(num_objs):
            areas.append(coco_annotation[i]['area'])
        areas = torch.as_tensor(areas, dtype=torch.float32)
        # Iscrowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Annotation is in dictionary format
        my_annotation = {}
        my_annotation["boxes"] = boxes
        my_annotation["labels"] = labels
        my_annotation["image_id"] = img_id
        my_annotation["area"] = areas
        my_annotation["iscrowd"] = iscrowd

        if self.transforms is not None:
            img = self.transforms(img)

        return img, my_annotation

    def __len__(self):
        return len(self.ids)
  1. Create DataLoader
from torchvision import transforms
def get_transform():
    custom_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return torchvision.transforms.Compose(custom_transforms)

train_data_dir = '/content/datasets/imgs/'
train_coco = '/content/datasets/outputs/train.json'


# create own Dataset
my_dataset = OwnDataset(root=train_data_dir, 
                        annotation=train_coco, 
                        transforms=get_transform()
                        )

# collate_fn needs for batch
def collate_fn(batch):
    return tuple(zip(*batch))

# Batch size
train_batch_size = 2

# own DataLoader
data_loader = torch.utils.data.DataLoader(my_dataset,
                                          batch_size=train_batch_size,
                                          shuffle=True,
                                          num_workers=4,
                                          collate_fn=collate_fn)

  1. Create a model of MobileNetV2 on top of Faster RCNN (below is my model instance)
def get_model_instance_segmentation(num_classes):
  # load a pre-trained model for classification and return
  # only the features
  backbone = torchvision.models.mobilenet_v2(pretrained=True).features
  # FasterRCNN needs to know the number of
  # output channels in a backbone. For mobilenet_v2, it's 1280
  # so we need to add it here
  backbone.out_channels = 1280
  # let's make the RPN generate 5 x 3 anchors per spatial
  # location, with 5 different sizes and 3 different aspect
  # ratios. We have a Tuple[Tuple[int]] because each feature
  # map could potentially have different sizes and
  # aspect ratios
  anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256),),
                                    aspect_ratios=((0.5, 1.0, 2.0),))
  # let's define what are the feature maps that we will
  # use to perform the region of interest cropping, as well as
  # the size of the crop after rescaling.
  # if your backbone returns a Tensor, featmap_names is expected to
  # be [0]. More generally, the backbone should return an
  # OrderedDict[Tensor], and in featmap_names you can choose which
  # feature maps to use.
  roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                  output_size=7,
                                                  sampling_ratio=2)
  # put the pieces together inside a FasterRCNN model
  model = FasterRCNN(backbone,
                   num_classes=num_classes,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
  return model
  1. Train the model with libraries/modules from pytorch such as transforms.py, utils.py, engine.py, coco_eval.py, etc.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)
total_len = num_epochs * len_dataloader
total_processed = 0

import time

time_s = time.time()
for epoch in range(num_epochs):
  model.train()
  i = 0    
  for imgs, annotations in data_loader:
    i += 1
    total_processed += 1
    imgs = list(img.to(device) for img in imgs)
    annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
    loss_dict = model(imgs, annotations)
    losses = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

    if i % 100 == 0:
      print(f'Iteration: {i}/{len_dataloader}, Loss: {losses}')
      time_elapsed = time.time() - time_s
      ave_processing_time = total_processed / time_elapsed
      remaining_records = total_len - total_processed
      time_remaining = remaining_records / ave_processing_time
      print(f"Time remaining in seconds {time_remaining}s and {time_remaining / 60} min...")
  if epoch % 2 == 0:
    print('Saving model')
    torch.save(model, '/content/gdrive/My Drive/libraries/model_full.pth')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, '/content/gdrive/My Drive/libraries/model_full_resuming.pth') 
  1. Error comes out like this:
<ipython-input-30-0ef74c49d85f> in <module>()
     77     imgs = list(img.to(device) for img in imgs)
     78     annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
---> 79     loss_dict = model(imgs, annotations)
     80     losses = sum(loss for loss in loss_dict.values())
     81 

5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     64             original_image_sizes.append((val[0], val[1]))
     65 
---> 66         images, targets = self.transform(images, targets)
     67         features = self.backbone(images.tensors)
     68         if isinstance(features, torch.Tensor):

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/transform.py in forward(self, images, targets)
     43                                  "of shape [C, H, W], got {}".format(image.shape))
     44             image = self.normalize(image)
---> 45             image, target_index = self.resize(image, target_index)
     46             images[i] = image
     47             if targets is not None and target_index is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/transform.py in resize(self, image, target)
     96 
     97         bbox = target["boxes"]
---> 98         bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
     99         target["boxes"] = bbox
    100 

/usr/local/lib/python3.6/dist-packages/torchvision/models/detection/transform.py in resize_boxes(boxes, original_size, new_size)
    218     ]
    219     ratio_height, ratio_width = ratios
--> 220     xmin, ymin, xmax, ymax = boxes.unbind(1)
    221 
    222     xmin = xmin * ratio_width

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)```

## Expected behavior

Expected to have the training as normal without the bug.

## Environment

Please copy and paste the output from our
[environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py)
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py

For security purposes, please check the contents of collect_env.py before running it.

python collect_env.py


PyTorch version: 1.5.0+cu101
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Tesla P100-PCIE-16GB
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.4
[pip3] torch==1.5.0+cu101
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.6.0+cu101
[conda] Could not collect

## Additional context

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions