Skip to content

Commit 01f07ee

Browse files
Dbhasin1Dbhasin1NicolasHug
authored
add Country211 prototype dataset (#5506)
* add country211 * remove unused import * map val to valid and use path comparator * remove unused import * resolve keyerror * map split names in dataset mock Co-authored-by: Dbhasin1 <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 71d2bb0 commit 01f07ee

File tree

4 files changed

+296
-0
lines changed

4 files changed

+296
-0
lines changed

test/builtin_dataset_mocks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,34 @@ def celeba(info, root, config):
878878
return CelebAMockData.generate(root)[config.split]
879879

880880

881+
@register_mock
882+
def country211(info, root, config):
883+
split_name_mapper = {
884+
"train": "train",
885+
"val": "valid",
886+
"test": "test",
887+
}
888+
split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]])
889+
split_folder.mkdir(parents=True, exist_ok=True)
890+
891+
num_examples = {
892+
"train": 3,
893+
"val": 4,
894+
"test": 5,
895+
}[config["split"]]
896+
897+
classes = ("AD", "BS", "GR")
898+
for cls in classes:
899+
create_image_folder(
900+
split_folder,
901+
name=cls,
902+
file_name_fn=lambda idx: f"{idx}.jpg",
903+
num_examples=num_examples,
904+
)
905+
make_tar(root, f"{split_folder.parent.name}.tgz", split_folder.parent, compression="gz")
906+
return num_examples * len(classes)
907+
908+
881909
@register_mock
882910
def dtd(info, root, config):
883911
data_folder = root / "dtd"

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .cifar import Cifar10, Cifar100
44
from .clevr import CLEVR
55
from .coco import Coco
6+
from .country211 import Country211
67
from .cub200 import CUB200
78
from .dtd import DTD
89
from .fer2013 import FER2013
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
AD
2+
AE
3+
AF
4+
AG
5+
AI
6+
AL
7+
AM
8+
AO
9+
AQ
10+
AR
11+
AT
12+
AU
13+
AW
14+
AX
15+
AZ
16+
BA
17+
BB
18+
BD
19+
BE
20+
BF
21+
BG
22+
BH
23+
BJ
24+
BM
25+
BN
26+
BO
27+
BQ
28+
BR
29+
BS
30+
BT
31+
BW
32+
BY
33+
BZ
34+
CA
35+
CD
36+
CF
37+
CH
38+
CI
39+
CK
40+
CL
41+
CM
42+
CN
43+
CO
44+
CR
45+
CU
46+
CV
47+
CW
48+
CY
49+
CZ
50+
DE
51+
DK
52+
DM
53+
DO
54+
DZ
55+
EC
56+
EE
57+
EG
58+
ES
59+
ET
60+
FI
61+
FJ
62+
FK
63+
FO
64+
FR
65+
GA
66+
GB
67+
GD
68+
GE
69+
GF
70+
GG
71+
GH
72+
GI
73+
GL
74+
GM
75+
GP
76+
GR
77+
GS
78+
GT
79+
GU
80+
GY
81+
HK
82+
HN
83+
HR
84+
HT
85+
HU
86+
ID
87+
IE
88+
IL
89+
IM
90+
IN
91+
IQ
92+
IR
93+
IS
94+
IT
95+
JE
96+
JM
97+
JO
98+
JP
99+
KE
100+
KG
101+
KH
102+
KN
103+
KP
104+
KR
105+
KW
106+
KY
107+
KZ
108+
LA
109+
LB
110+
LC
111+
LI
112+
LK
113+
LR
114+
LT
115+
LU
116+
LV
117+
LY
118+
MA
119+
MC
120+
MD
121+
ME
122+
MF
123+
MG
124+
MK
125+
ML
126+
MM
127+
MN
128+
MO
129+
MQ
130+
MR
131+
MT
132+
MU
133+
MV
134+
MW
135+
MX
136+
MY
137+
MZ
138+
NA
139+
NC
140+
NG
141+
NI
142+
NL
143+
NO
144+
NP
145+
NZ
146+
OM
147+
PA
148+
PE
149+
PF
150+
PG
151+
PH
152+
PK
153+
PL
154+
PR
155+
PS
156+
PT
157+
PW
158+
PY
159+
QA
160+
RE
161+
RO
162+
RS
163+
RU
164+
RW
165+
SA
166+
SB
167+
SC
168+
SD
169+
SE
170+
SG
171+
SH
172+
SI
173+
SJ
174+
SK
175+
SL
176+
SM
177+
SN
178+
SO
179+
SS
180+
SV
181+
SX
182+
SY
183+
SZ
184+
TG
185+
TH
186+
TJ
187+
TL
188+
TM
189+
TN
190+
TO
191+
TR
192+
TT
193+
TW
194+
TZ
195+
UA
196+
UG
197+
US
198+
UY
199+
UZ
200+
VA
201+
VE
202+
VG
203+
VI
204+
VN
205+
VU
206+
WS
207+
XK
208+
YE
209+
ZA
210+
ZM
211+
ZW
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pathlib
2+
from typing import Any, Dict, List, Tuple
3+
4+
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
5+
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
6+
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling
7+
from torchvision.prototype.features import EncodedImage, Label
8+
9+
10+
class Country211(Dataset):
11+
def _make_info(self) -> DatasetInfo:
12+
return DatasetInfo(
13+
"country211",
14+
homepage="https://github.com/openai/CLIP/blob/main/data/country211.md",
15+
valid_options=dict(split=("train", "val", "test")),
16+
)
17+
18+
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
19+
return [
20+
HttpResource(
21+
"https://openaipublic.azureedge.net/clip/data/country211.tgz",
22+
sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c",
23+
)
24+
]
25+
26+
_SPLIT_NAME_MAPPER = {
27+
"train": "train",
28+
"val": "valid",
29+
"test": "test",
30+
}
31+
32+
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
33+
path, buffer = data
34+
category = pathlib.Path(path).parent.name
35+
return dict(
36+
label=Label.from_category(category, categories=self.categories),
37+
path=path,
38+
image=EncodedImage.from_file(buffer),
39+
)
40+
41+
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
42+
return pathlib.Path(data[0]).parent.parent.name == split
43+
44+
def _make_datapipe(
45+
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
46+
) -> IterDataPipe[Dict[str, Any]]:
47+
dp = resource_dps[0]
48+
dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
49+
dp = hint_sharding(dp)
50+
dp = hint_shuffling(dp)
51+
return Mapper(dp, self._prepare_sample)
52+
53+
def _generate_categories(self, root: pathlib.Path) -> List[str]:
54+
resources = self.resources(self.default_config)
55+
dp = resources[0].load(root)
56+
return sorted({pathlib.Path(path).parent.name for path, _ in dp})

0 commit comments

Comments
 (0)