Commit 058f4bd

OxfordIIITPet dataset (#5116)

* add prototype dataset for oxford-iiit-pet
* add old-style dataset
* add tests
* fix mypy
* fix test
* remove properties and use pathlib
* target_type to target_types
* move target annotation
* add docstring
* fix test

1 parent f01b533 commit 058f4bd

File tree

7 files changed: +377 -0 lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     LSUN
     MNIST
     Omniglot
+    OxfordIIITPet
     PhotoTour
     Places365
     QMNIST

test/test_datasets.py

Lines changed: 60 additions & 0 deletions

@@ -2357,5 +2357,65 @@ def inject_fake_data(self, tmpdir, config):
         return len(image_files)
 
 
+class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.OxfordIIITPet
+    FEATURE_TYPES = (PIL.Image.Image, (int, PIL.Image.Image, tuple, type(None)))
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("trainval", "test"),
+        target_types=("category", "segmentation", ["category", "segmentation"], []),
+    )
+
+    def inject_fake_data(self, tmpdir, config):
+        base_folder = os.path.join(tmpdir, "oxford-iiit-pet")
+
+        classification_anns_meta = (
+            dict(cls="Abyssinian", label=0, species="cat"),
+            dict(cls="Keeshond", label=18, species="dog"),
+            dict(cls="Yorkshire Terrier", label=37, species="dog"),
+        )
+        split_and_classification_anns = [
+            self._meta_to_split_and_classification_ann(meta, idx)
+            for meta, idx in itertools.product(classification_anns_meta, (1, 2, 10))
+        ]
+        image_ids, *_ = zip(*split_and_classification_anns)
+
+        image_files = datasets_utils.create_image_folder(
+            base_folder, "images", file_name_fn=lambda idx: f"{image_ids[idx]}.jpg", num_examples=len(image_ids)
+        )
+
+        anns_folder = os.path.join(base_folder, "annotations")
+        os.makedirs(anns_folder)
+        split_and_classification_anns_in_split = random.choices(split_and_classification_anns, k=len(image_ids) // 2)
+        with open(os.path.join(anns_folder, f"{config['split']}.txt"), "w", newline="") as file:
+            writer = csv.writer(file, delimiter=" ")
+            for split_and_classification_ann in split_and_classification_anns_in_split:
+                writer.writerow(split_and_classification_ann)
+
+        segmentation_files = datasets_utils.create_image_folder(
+            anns_folder, "trimaps", file_name_fn=lambda idx: f"{image_ids[idx]}.png", num_examples=len(image_ids)
+        )
+
+        # The dataset has some rogue files
+        for path in image_files[:2]:
+            path.with_suffix(".mat").touch()
+        for path in segmentation_files:
+            path.with_name(f".{path.name}").touch()
+
+        return len(split_and_classification_anns_in_split)
+
+    def _meta_to_split_and_classification_ann(self, meta, idx):
+        image_id = "_".join(
+            [
+                *[(str.title if meta["species"] == "cat" else str.lower)(part) for part in meta["cls"].split()],
+                str(idx),
+            ]
+        )
+        class_id = str(meta["label"] + 1)
+        species = "1" if meta["species"] == "cat" else "2"
+        breed_id = "-1"
+        return (image_id, class_id, species, breed_id)
+
+
 if __name__ == "__main__":
     unittest.main()
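
For reference, each row that the fake-data writer above emits follows the same space-separated layout that the dataset later parses, <image_id> <class_id> <species> <breed_id>; a minimal sketch with one illustrative row:

# One generated line of trainval.txt / test.txt (values are illustrative, built from
# the Abyssinian meta above: class_id = label + 1, species "1" = cat, breed_id = "-1"):
line = "Abyssinian_1 1 1 -1"

# OxfordIIITPet uses only the first two fields and shifts the 1-based class id to 0-based:
image_id, label, *_ = line.strip().split()
assert image_id == "Abyssinian_1" and int(label) - 1 == 0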

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
 from .lsun import LSUN, LSUNClass
 from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
 from .omniglot import Omniglot
+from .oxford_iiit_pet import OxfordIIITPet
 from .phototour import PhotoTour
 from .places365 import Places365
 from .sbd import SBDataset
@@ -87,4 +88,5 @@
     "FER2013",
     "GTSRB",
     "CLEVRClassification",
+    "OxfordIIITPet",
 )
torchvision/datasets/oxford_iiit_pet.py

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+import os
+import os.path
+import pathlib
+from typing import Any, Callable, Optional, Union, Tuple
+from typing import Sequence
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+    """`Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+        target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
+            ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+                - ``category`` (int): Label for one of the 37 pet categories.
+                - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+            If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If the dataset is already downloaded, it is not downloaded again.
+    """
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+    _VALID_TARGET_TYPES = ("category", "segmentation")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        target_types: Union[Sequence[str], str] = "category",
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ):
+        self._split = verify_str_arg(split, "split", ("trainval", "test"))
+        if isinstance(target_types, str):
+            target_types = [target_types]
+        self._target_types = [
+            verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+        ]
+
+        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+        self._images_folder = self._base_folder / "images"
+        self._anns_folder = self._base_folder / "annotations"
+        self._segs_folder = self._anns_folder / "trimaps"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        image_ids = []
+        self._labels = []
+        with open(self._anns_folder / f"{self._split}.txt") as file:
+            for line in file:
+                image_id, label, *_ = line.strip().split()
+                image_ids.append(image_id)
+                self._labels.append(int(label) - 1)
+
+        self.classes = [
+            " ".join(part.title() for part in raw_cls.split("_"))
+            for raw_cls, _ in sorted(
+                {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                key=lambda image_id_and_label: image_id_and_label[1],
+            )
+        ]
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+        self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image = Image.open(self._images[idx]).convert("RGB")
+
+        target: Any = []
+        for target_type in self._target_types:
+            if target_type == "category":
+                target.append(self._labels[idx])
+            else:  # target_type == "segmentation"
+                target.append(Image.open(self._segs[idx]))
+
+        if not target:
+            target = None
+        elif len(target) == 1:
+            target = target[0]
+        else:
+            target = tuple(target)
+
+        if self.transforms:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        for folder in (self._images_folder, self._anns_folder):
+            if not (os.path.exists(folder) and os.path.isdir(folder)):
+                return False
+        else:
+            return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for url, md5 in self._RESOURCES:
+            download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
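
A minimal usage sketch for the class added above. The root directory name is hypothetical, and the snippet assumes the archives can be downloaded (or are already extracted under root/oxford-iiit-pet):

from torchvision.datasets import OxfordIIITPet

# download defaults to True in this commit, so the first use fetches images.tar.gz
# and annotations.tar.gz into <root>/oxford-iiit-pet.
dataset = OxfordIIITPet(
    root="data",  # hypothetical root directory
    split="trainval",
    target_types=["category", "segmentation"],  # list -> tuple target
)

# With two target types, the target is a (int label, PIL trimap) tuple.
image, (label, trimap) = dataset[0]
print(len(dataset), dataset.classes[label])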

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
 from .fer2013 import FER2013
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
+from .oxford_iiit_pet import OxfordIIITPet
 from .sbd import SBD
 from .semeion import SEMEION
 from .voc import VOC
torchvision/prototype/datasets/_builtin/oxford-iiit-pet.categories

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+Abyssinian
+American Bulldog
+American Pit Bull Terrier
+Basset Hound
+Beagle
+Bengal
+Birman
+Bombay
+Boxer
+British Shorthair
+Chihuahua
+Egyptian Mau
+English Cocker Spaniel
+English Setter
+German Shorthaired
+Great Pyrenees
+Havanese
+Japanese Chin
+Keeshond
+Leonberger
+Maine Coon
+Miniature Pinscher
+Newfoundland
+Persian
+Pomeranian
+Pug
+Ragdoll
+Russian Blue
+Saint Bernard
+Samoyed
+Scottish Terrier
+Shiba Inu
+Siamese
+Sphynx
+Staffordshire Bull Terrier
+Wheaten Terrier
+Yorkshire Terrier
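
The 37 names above are alphabetical and line up with the 0-based category labels (e.g. Abyssinian is label 0 and Keeshond label 18, as in the test meta). A minimal sketch of loading the file, assuming the file name and location from the header above and that it sits in the current directory (the prototype loader itself is not shown in this diff):

import pathlib

# Read the breed names; line i corresponds to category label i.
categories = pathlib.Path("oxford-iiit-pet.categories").read_text().splitlines()
assert len(categories) == 37 and categories[0] == "Abyssinian"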
