
VOCSegmentation, VOCDetection, linting passing, examples. #663


Merged
8 commits merged on Dec 6, 2018
Changes from 2 commits
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
@@ -9,10 +9,11 @@
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
from .voc import VOCSegmentation, VOCDetection

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
           'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot')
           'Omniglot', 'VOCSegmentation', 'VOCDetection')
264 changes: 264 additions & 0 deletions torchvision/datasets/voc.py
@@ -0,0 +1,264 @@
import os
import sys
import tarfile
import torch.utils.data as data
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

from PIL import Image
from .utils import download_url, check_integrity

VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
DATASET_YEAR_DICT = {
    '2012': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'VOCtrainval_11-May-2012.tar', '6cd6e144f989b92b3379bac3b3de84fd',
        'VOCdevkit/VOC2012'
    ],
    '2011': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'VOCtrainval_25-May-2011.tar', '6c3384ef61512963050cb5d687e5bf1e',
        'TrainVal/VOCdevkit/VOC2011'
    ],
    '2010': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'VOCtrainval_03-May-2010.tar', 'da459979d0c395079b5c75ee67908abb',
        'VOCdevkit/VOC2010'
    ],
    '2009': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'VOCtrainval_11-May-2009.tar', '59065e4b188729180974ef6572f6a212',
        'VOCdevkit/VOC2009'
    ],
    '2008': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        'VOCtrainval_14-Jul-2008.tar', '2629fa636546599198acfcfbfcf1904a',
        'VOCdevkit/VOC2008'
    ],
    '2007': [
        'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'VOCtrainval_06-Nov-2007.tar', 'c52e279531787c972589f7e41ab4ae64',
        'VOCdevkit/VOC2007'
    ]
}


class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None):
        self.root = root
        self.year = year
        self.url = DATASET_YEAR_DICT[year][0]
        self.filename = DATASET_YEAR_DICT[year][1]
        self.md5 = DATASET_YEAR_DICT[year][2]
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set
        _base_dir = DATASET_YEAR_DICT[year][3]
        _voc_root = os.path.join(self.root, _base_dir)
        _image_dir = os.path.join(_voc_root, 'JPEGImages')
        _mask_dir = os.path.join(_voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(_voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation')

        _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(_split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        self.images = []
        self.masks = []
        with open(os.path.join(_split_f), "r") as lines:
            for line in lines:
                _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg")
                _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png")
                assert os.path.isfile(_image)
                assert os.path.isfile(_mask)
                self.images.append(_image)
                self.masks.append(_mask)

        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.masks[index])

        if self.transform is not None:
            _img = self.transform(_img)

        if self.target_transform is not None:
            _target = self.target_transform(_target)

        return _img, _target

    def __len__(self):
        return len(self.images)


class VOCDetection(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes).
        keep_difficult (boolean, optional): keep difficult instances or not.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 class_to_ind=None,
                 keep_difficult=False,
Member:
I think keep_difficult could be part of what the user passes to the target_transform.
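A minimal sketch of that idea, purely illustrative: it assumes the dataset handed back every object together with its difficult flag (for example as a list of dicts), which is not the format this revision returns, so the target shape and the drop_difficult name are hypothetical.

def drop_difficult(target):
    # keep only the objects whose (hypothetical) 'difficult' field is falsy
    return [obj for obj in target if not obj.get('difficult')]

dataset = VOCDetection('data/voc', year='2012', image_set='train',
                       target_transform=drop_difficult)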

                 transform=None,
                 target_transform=None):
        self.root = root
        self.year = year
        self.url = DATASET_YEAR_DICT[year][0]
        self.filename = DATASET_YEAR_DICT[year][1]
        self.md5 = DATASET_YEAR_DICT[year][2]
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set
        self.class_to_ind = class_to_ind or dict(
Member:
Can we remove class_to_ind? It's not used anymore.

            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult
        _base_dir = DATASET_YEAR_DICT[year][3]
        _voc_root = os.path.join(self.root, _base_dir)
        _image_dir = os.path.join(_voc_root, 'JPEGImages')
        _annotation_dir = os.path.join(_voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(_voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        _splits_dir = os.path.join(_voc_root, 'ImageSets/Main')

        _split_f = os.path.join(_splits_dir, image_set.rstrip('\n') + '.txt')
Member:
Do we need the rstrip? Just out of curiosity.


        if not os.path.exists(_split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val" or a valid '
                'image_set from the VOC ImageSets/Main folder.')

        self.images = []
        self.annotations = []
        with open(os.path.join(_split_f), "r") as lines:
            for line in lines:
Member:
Can't we instead do

with open(os.path.join(_split_f), "r") as f:
    image_names = f.readlines()

? I believe this strips out the \n in the end, and is a bit faster than the current version.

Contributor Author:
I've checked and it doesn't strip the \n.
The value of image_names is:

['2008_000008\n', '2008_000015\n', '2008_000019\n', '2008_000023\n', '2008_000028\n', '2008_000033\n', '2008_000036\n', '2008_000037\n', '2008_000041\n', '2008_000045\n', '2008_000053\n', '2008_000060\n', '2008_000066\n', '2008_000070\n', '2008_000074\n', '2008_000085\n', '2008_000089\n', '2008_000093\n', '2008_000095\n', '2008_000096\n', '2008_000097\n', '2008_000099\n', '2008_000103\n', '2008_000105\n', '2008_000109\n', '2008_000112\n'...

Member:
OK. I think what I had done then was something like

with open(os.path.join(_split_f), "r") as f:
    image_names = [x.strip() for x in f.readlines()]

but do it as you think it's better.
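Not suggested in the thread, just for completeness: read().splitlines() would also drop the trailing newlines in a single call.

with open(os.path.join(_split_f), "r") as f:
    file_names = f.read().splitlines()  # splitlines() discards the '\n' endings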

                _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg")
                _annotation = os.path.join(_annotation_dir,
                                           line.rstrip('\n') + ".xml")
                assert os.path.isfile(_image)
                assert os.path.isfile(_annotation)
                self.images.append(_image)
                self.annotations.append(_annotation)

        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a list of bounding boxes in
                relative coordinates, like ``[[xmin, ymin, xmax, ymax, ind], [...], ...]``.
        """
        _img = Image.open(self.images[index]).convert('RGB')
        _target = self._get_bboxes(ET.parse(self.annotations[index]).getroot())

Member:
Can we instead return the raw result from ET.parse().getroot()? Or maybe a dict with the full parsing of the xml?
It's up to the user to decide what they actually want for their target, and we are forcing them to use one format now (which doesn't hold all the information from the dataset, such as truncated / occluded / etc.).


        if self.transform is not None:
            _img = self.transform(_img)

        if self.target_transform is not None:
            _target = self.target_transform(_target)

        return _img, _target

    def __len__(self):
        return len(self.images)

    def _get_bboxes(self, target):
Member:
I'd probably let the user write this down. Maybe what would be the most user-friendly would be to parse the ET and return a nested dict?

Contributor Author (@bpinaya), Dec 5, 2018:
Indeed, I thought about it, maybe a dict would be better, but I saw here that they pass a Bb class. I think whatever is easier for the end user maybe. Regarding the iteration of the ET, any ideas to make it recursive and elegant? I implemented a function but it's way too hacky.

Member:
In maskrcnn-benchmark, we have a dedicated class BoxList which is used everywhere in the codebase.
My original plan was to move BoxList to torchvision, but it needs to mature a bit more before we move it here.

About the ET, I suppose there is a function call that we can use that would enable us to get it recursively? Probably something like this that I wrote for lua.

        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            width = int(target.find('size').find('width').text)
            height = int(target.find('size').find('height').text)
            bndbox = []
            for i, cur_bb in enumerate(bbox):
                bb_sz = int(cur_bb.text) - 1
                # relative coordinates
                bb_sz = bb_sz / width if i % 2 == 0 else bb_sz / height
                bndbox.append(bb_sz)

            label_ind = self.class_to_ind[name]
            bndbox.append(label_ind)
            res.append(bndbox)  # [xmin, ymin, xmax, ymax, ind]
        return res
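A rough sketch of the nested-dict idea discussed above; the parse_voc_xml name and the exact output shape are only a proposal, not part of this PR. It walks the ElementTree recursively so every annotation field (truncated, occluded, difficult, ...) survives, and the user decides in target_transform what to keep.

import collections
import xml.etree.ElementTree as ET

def parse_voc_xml(node):
    # Leaf element: return its text.
    children = list(node)
    if not children:
        return {node.tag: node.text.strip() if node.text else ''}
    # Recurse, grouping repeated child tags (e.g. several <object> entries) into lists.
    grouped = collections.defaultdict(list)
    for child in children:
        for tag, value in parse_voc_xml(child).items():
            grouped[tag].append(value)
    # Note: a tag that occurs once comes back as a single value, not a one-element list.
    return {node.tag: {tag: vals[0] if len(vals) == 1 else vals
                       for tag, vals in grouped.items()}}

# e.g. target = parse_voc_xml(ET.parse(annotation_path).getroot())['annotation']

With something along those lines, __getitem__ could simply return the dict and leave box extraction (and the difficult filtering) to the user's target_transform.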


def download_extract(url, root, filename, md5):
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)
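For reference, a quick usage sketch of the two new classes as they stand in this revision; the root path and the transforms are placeholders.

import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation, VOCDetection

seg = VOCSegmentation('data/voc', year='2012', image_set='train',
                      download=True, transform=transforms.ToTensor())
img, mask = seg[0]    # mask stays a PIL image unless a target_transform is given

det = VOCDetection('data/voc', year='2012', image_set='trainval',
                   transform=transforms.ToTensor())
img, boxes = det[0]   # boxes: [[xmin, ymin, xmax, ymax, ind], ...] in relative coordinates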