From b8df1c840bd63ea04002024f808884505f8bd6ec Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Tue, 1 Mar 2022 16:46:10 +0000 Subject: [PATCH 1/6] add country211 --- test/builtin_dataset_mocks.py | 24 ++ .../prototype/datasets/_builtin/__init__.py | 1 + .../datasets/_builtin/country211.categories | 211 ++++++++++++++++++ .../prototype/datasets/_builtin/country211.py | 51 +++++ 4 files changed, 287 insertions(+) create mode 100644 torchvision/prototype/datasets/_builtin/country211.categories create mode 100644 torchvision/prototype/datasets/_builtin/country211.py diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 123d8f29d3f..736da1e61da 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -11,6 +11,7 @@ import random import xml.etree.ElementTree as ET from collections import defaultdict, Counter +from logging import RootLogger import numpy as np import PIL.Image @@ -878,6 +879,29 @@ def celeba(info, root, config): return CelebAMockData.generate(root)[config.split] +@register_mock +def country211(info, root, config): + split_folder = pathlib.Path(root, "country211", config["split"]) + split_folder.mkdir(parents=True, exist_ok=True) + + num_examples = { + "train": 3, + "valid": 4, + "test": 5, + }[config["split"]] + + classes = ("AD", "BS", "GR") + for cls in classes: + create_image_folder( + split_folder, + name=cls, + file_name_fn=lambda idx: f"{idx}.jpg", + num_examples=num_examples, + ) + make_tar(root, f"{split_folder.parent.name}.tgz", split_folder.parent, compression="gz") + return num_examples * len(classes) + + @register_mock def dtd(info, root, config): data_folder = root / "dtd" diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 9fdfca904f5..d57b5555727 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -3,6 +3,7 @@ from .cifar import Cifar10, Cifar100 from .clevr import CLEVR from .coco import Coco +from .country211 import Country211 from .cub200 import CUB200 from .dtd import DTD from .fer2013 import FER2013 diff --git a/torchvision/prototype/datasets/_builtin/country211.categories b/torchvision/prototype/datasets/_builtin/country211.categories new file mode 100644 index 00000000000..6fc3e99a185 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/country211.categories @@ -0,0 +1,211 @@ +AD +AE +AF +AG +AI +AL +AM +AO +AQ +AR +AT +AU +AW +AX +AZ +BA +BB +BD +BE +BF +BG +BH +BJ +BM +BN +BO +BQ +BR +BS +BT +BW +BY +BZ +CA +CD +CF +CH +CI +CK +CL +CM +CN +CO +CR +CU +CV +CW +CY +CZ +DE +DK +DM +DO +DZ +EC +EE +EG +ES +ET +FI +FJ +FK +FO +FR +GA +GB +GD +GE +GF +GG +GH +GI +GL +GM +GP +GR +GS +GT +GU +GY +HK +HN +HR +HT +HU +ID +IE +IL +IM +IN +IQ +IR +IS +IT +JE +JM +JO +JP +KE +KG +KH +KN +KP +KR +KW +KY +KZ +LA +LB +LC +LI +LK +LR +LT +LU +LV +LY +MA +MC +MD +ME +MF +MG +MK +ML +MM +MN +MO +MQ +MR +MT +MU +MV +MW +MX +MY +MZ +NA +NC +NG +NI +NL +NO +NP +NZ +OM +PA +PE +PF +PG +PH +PK +PL +PR +PS +PT +PW +PY +QA +RE +RO +RS +RU +RW +SA +SB +SC +SD +SE +SG +SH +SI +SJ +SK +SL +SM +SN +SO +SS +SV +SX +SY +SZ +TG +TH +TJ +TL +TM +TN +TO +TR +TT +TW +TZ +UA +UG +US +UY +UZ +VA +VE +VG +VI +VN +VU +WS +XK +YE +ZA +ZM +ZW diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py new file mode 100644 index 00000000000..35bcdb53a74 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -0,0 +1,51 @@ +import functools +import pathlib +from typing import Any, Dict, List, Tuple + +from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter +from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import EncodedImage, Label + + +class Country211(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "country211", + homepage="https://github.com/openai/CLIP/blob/main/data/country211.md", + valid_options=dict(split=("train", "valid", "test")), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "https://openaipublic.azureedge.net/clip/data/country211.tgz", + sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c", + ) + ] + + def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: + path, buffer = data + category = pathlib.Path(path).parent.name + return dict( + label=Label.from_category(category, categories=self.categories), + path=path, + image=EncodedImage.from_file(buffer), + ) + + def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: + return pathlib.Path(data[0]).parent.parent.name == split + + def _make_datapipe( + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig + ) -> IterDataPipe[Dict[str, Any]]: + dp = resource_dps[0] + dp = Filter(dp, functools.partial(self._filter_split, split=config.split)) + dp = hint_sharding(dp) + dp = hint_shuffling(dp) + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + resources = self.resources(self.default_config) + dp = resources[0].load(root) + return sorted({pathlib.Path(path).parent.name for path, _ in dp}) From 2afa6c4b76cf4de7e2dc7b0ec0e8f7c63a1d60e3 Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Tue, 1 Mar 2022 17:04:56 +0000 Subject: [PATCH 2/6] remove unused import --- test/builtin_dataset_mocks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 736da1e61da..e72633997be 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -11,7 +11,6 @@ import random import xml.etree.ElementTree as ET from collections import defaultdict, Counter -from logging import RootLogger import numpy as np import PIL.Image From 4b8d51891d6a944271874d2ec684d667f0fac692 Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Wed, 2 Mar 2022 04:56:18 +0000 Subject: [PATCH 3/6] map val to valid and use path comparator --- .../prototype/datasets/_builtin/country211.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 35bcdb53a74..4788ddb844c 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -4,7 +4,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label @@ -13,7 +13,7 @@ def _make_info(self) -> DatasetInfo: return DatasetInfo( "country211", homepage="https://github.com/openai/CLIP/blob/main/data/country211.md", - valid_options=dict(split=("train", "valid", "test")), + valid_options=dict(split=("train", "val", "test")), ) def resources(self, config: DatasetConfig) -> List[OnlineResource]: @@ -24,6 +24,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) ] + _SPLIT_NAME_MAPPER = { + "train": "train", + "val": "valid", + "test": "test", + } + def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name @@ -40,7 +46,7 @@ def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, functools.partial(self._filter_split, split=config.split)) + dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split])) dp = hint_sharding(dp) dp = hint_shuffling(dp) return Mapper(dp, self._prepare_sample) From 927ae6ec62a8e3aafed4cf2b5e44e604929b832e Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Wed, 2 Mar 2022 04:57:38 +0000 Subject: [PATCH 4/6] remove unused import --- torchvision/prototype/datasets/_builtin/country211.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 4788ddb844c..7bf9d2f46d8 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -1,4 +1,3 @@ -import functools import pathlib from typing import Any, Dict, List, Tuple From 3813fb634a6a0c986707f68a771928a60c452b97 Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Wed, 2 Mar 2022 06:14:07 +0000 Subject: [PATCH 5/6] resolve keyerror --- test/builtin_dataset_mocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index e72633997be..51b139cd21f 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -885,7 +885,7 @@ def country211(info, root, config): num_examples = { "train": 3, - "valid": 4, + "val": 4, "test": 5, }[config["split"]] From 7fe356a4db6c05fb03badc55ca693eddb47e5a82 Mon Sep 17 00:00:00 2001 From: Dbhasin1 Date: Wed, 2 Mar 2022 06:36:19 +0000 Subject: [PATCH 6/6] map split names in dataset mock --- test/builtin_dataset_mocks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 51b139cd21f..31d0aadc64d 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -880,7 +880,12 @@ def celeba(info, root, config): @register_mock def country211(info, root, config): - split_folder = pathlib.Path(root, "country211", config["split"]) + split_name_mapper = { + "train": "train", + "val": "valid", + "test": "test", + } + split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]]) split_folder.mkdir(parents=True, exist_ok=True) num_examples = {