Skip to content

Commit 402eb28

Browse files
datumboxpmeier
authored and committed
[fbsync] Food101 new dataset api (#5584)
Summary: * [FEAT] Start implementing Food101 using the new datasets API. WIP. * [FEAT] Generate Food101 categories and start the test mock. * [FEAT] food101 dataset code seems to work now. * [TEST] food101 mock update. * [FIX] Some fixes thanks to running food101 tests. * [FIX] Fix mypy checks for the food101 file. * [FIX] Remove unused numpy. * [FIX] Some changes thanks to code review. * [ENH] More idiomatic dataset code thanks to code review. * [FIX] Remove unused cast. * [ENH] Set decompress and extract to True for some performance gains. * [FEAT] Use the preprocess=decompress keyword. * [ENH] Use the train and test.txt file instead of the .json variants and simplify code + update mock data. * [ENH] Better food101 mock data generation. * [FIX] Remove a useless print. Reviewed By: NicolasHug Differential Revision: D35393170 fbshipit-source-id: c7f51bdfb2e05913593cdcba9e30994557afaf87 Co-authored-by: Philip Meier <[email protected]>
1 parent e32b44d commit 402eb28

File tree

4 files changed

+231
-0
lines changed

4 files changed

+231
-0
lines changed

test/builtin_dataset_mocks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,44 @@ def country211(info, root, config):
911911
return num_examples * len(classes)
912912

913913

914+
@register_mock
def food101(info, root, config):
    """Build a tiny fake Food-101 archive under *root* and return the number of
    samples in the requested ``config.split``."""
    data_folder = root / "food-101"

    samples_per_category = 3
    categories = ["apple_pie", "baby_back_ribs", "waffles"]

    # Create fake images per category and record their ids ("<category>/<stem>"),
    # mirroring the id format the real dataset's split files use.
    image_ids = []
    for name in categories:
        files = create_image_folder(
            data_folder / "images",
            name,
            file_name_fn=lambda idx: f"{idx:04d}.jpg",
            num_examples=samples_per_category,
        )
        for file_path in files:
            image_ids.append(file_path.relative_to(file_path.parents[1]).with_suffix("").as_posix())

    meta_folder = data_folder / "meta"
    meta_folder.mkdir()

    with open(meta_folder / "classes.txt", "w") as fh:
        fh.writelines(f"{name}\n" for name in categories)

    # Deal the image ids round-robin into the train/test split files.
    splits = ["train", "test"]
    num_samples_map = {}
    for offset, split in enumerate(splits):
        split_ids = image_ids[offset :: len(splits)]
        num_samples_map[split] = len(split_ids)
        with open(meta_folder / f"{split}.txt", "w") as fh:
            fh.writelines(f"{image_id}\n" for image_id in split_ids)

    make_tar(root, f"{data_folder.name}.tar.gz", compression="gz")

    return num_samples_map[config.split]
950+
951+
914952
@register_mock
915953
def dtd(info, root, config):
916954
data_folder = root / "dtd"

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .dtd import DTD
99
from .eurosat import EuroSAT
1010
from .fer2013 import FER2013
11+
from .food101 import Food101
1112
from .gtsrb import GTSRB
1213
from .imagenet import ImageNet
1314
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
apple_pie
2+
baby_back_ribs
3+
baklava
4+
beef_carpaccio
5+
beef_tartare
6+
beet_salad
7+
beignets
8+
bibimbap
9+
bread_pudding
10+
breakfast_burrito
11+
bruschetta
12+
caesar_salad
13+
cannoli
14+
caprese_salad
15+
carrot_cake
16+
ceviche
17+
cheesecake
18+
cheese_plate
19+
chicken_curry
20+
chicken_quesadilla
21+
chicken_wings
22+
chocolate_cake
23+
chocolate_mousse
24+
churros
25+
clam_chowder
26+
club_sandwich
27+
crab_cakes
28+
creme_brulee
29+
croque_madame
30+
cup_cakes
31+
deviled_eggs
32+
donuts
33+
dumplings
34+
edamame
35+
eggs_benedict
36+
escargots
37+
falafel
38+
filet_mignon
39+
fish_and_chips
40+
foie_gras
41+
french_fries
42+
french_onion_soup
43+
french_toast
44+
fried_calamari
45+
fried_rice
46+
frozen_yogurt
47+
garlic_bread
48+
gnocchi
49+
greek_salad
50+
grilled_cheese_sandwich
51+
grilled_salmon
52+
guacamole
53+
gyoza
54+
hamburger
55+
hot_and_sour_soup
56+
hot_dog
57+
huevos_rancheros
58+
hummus
59+
ice_cream
60+
lasagna
61+
lobster_bisque
62+
lobster_roll_sandwich
63+
macaroni_and_cheese
64+
macarons
65+
miso_soup
66+
mussels
67+
nachos
68+
omelette
69+
onion_rings
70+
oysters
71+
pad_thai
72+
paella
73+
pancakes
74+
panna_cotta
75+
peking_duck
76+
pho
77+
pizza
78+
pork_chop
79+
poutine
80+
prime_rib
81+
pulled_pork_sandwich
82+
ramen
83+
ravioli
84+
red_velvet_cake
85+
risotto
86+
samosa
87+
sashimi
88+
scallops
89+
seaweed_salad
90+
shrimp_and_grits
91+
spaghetti_bolognese
92+
spaghetti_carbonara
93+
spring_rolls
94+
steak
95+
strawberry_shortcake
96+
sushi
97+
tacos
98+
takoyaki
99+
tiramisu
100+
tuna_tartare
101+
waffles
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from pathlib import Path
2+
from typing import Any, Tuple, List, Dict, Optional, BinaryIO
3+
4+
from torchdata.datapipes.iter import (
5+
IterDataPipe,
6+
Filter,
7+
Mapper,
8+
LineReader,
9+
Demultiplexer,
10+
IterKeyZipper,
11+
)
12+
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
13+
from torchvision.prototype.datasets.utils._internal import (
14+
hint_shuffling,
15+
hint_sharding,
16+
path_comparator,
17+
getitem,
18+
INFINITE_BUFFER_SIZE,
19+
)
20+
from torchvision.prototype.features import Label, EncodedImage
21+
22+
23+
class Food101(Dataset):
    """Food-101 dataset built on the prototype datasets API.

    Samples are joined from two members of a single archive: the image files
    under ``images/`` and the per-split id lists under ``meta/``.
    """

    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "food101",
            homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
            valid_options=dict(split=("train", "test")),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        archive = HttpResource(
            url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
            sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4",
            preprocess="decompress",
        )
        return [archive]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        # 0 -> image file, 1 -> metadata/split file, None -> dropped.
        path = Path(data[0])
        if path.parents[1].name == "images":
            return 0
        if path.parents[0].name == "meta":
            return 1
        return None

    def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        image_id, (path, buffer) = data
        # The category is the first segment of the id, e.g. "apple_pie/0001".
        category = image_id.split("/", 1)[0]
        return dict(
            label=Label.from_category(category, categories=self.categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _image_key(self, data: Tuple[str, Any]) -> str:
        # Reduce an image path to "<category>/<stem>" so it matches the ids
        # listed in the split files.
        path = Path(data[0])
        return path.relative_to(path.parents[1]).with_suffix("").as_posix()

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        images_dp, split_dp = Demultiplexer(
            resource_dps[0], 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
        split_dp = LineReader(split_dp, decode=True, return_path=False)
        split_dp = hint_sharding(split_dp)
        split_dp = hint_shuffling(split_dp)

        dp = IterKeyZipper(
            split_dp,
            images_dp,
            key_fn=getitem(),
            ref_key_fn=self._image_key,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self, root: Path) -> List[str]:
        dp = self.resources(self.default_config)[0].load(root)
        dp = Filter(dp, path_comparator("name", "classes.txt"))
        dp = LineReader(dp, decode=True, return_path=False)
        return list(dp)

0 commit comments

Comments
 (0)