Skip to content

Commit 1feb637

Browse files
authored
fix category file generation (#5188)
1 parent e3767f8 commit 1feb637

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool:
132132
return self._classify_archive(data) == DTDDemux.IMAGES
133133

134134
def _generate_categories(self, root: pathlib.Path) -> List[str]:
135-
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
135+
resources = self.resources(self.default_config)
136+
137+
dp = resources[0].load(root)
136138
dp = Filter(dp, self._filter_images)
139+
137140
return sorted({pathlib.Path(path).parent.name for path, _ in dp})

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,13 @@ def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
145145

146146
def _generate_categories(self, root: pathlib.Path) -> List[str]:
147147
config = self.default_config
148-
dp = self.resources(config)[1].load(pathlib.Path(root) / self.name)
148+
resources = self.resources(config)
149+
150+
dp = resources[1].load(root)
149151
dp = Filter(dp, self._filter_split_and_classification_anns)
150152
dp = Filter(dp, path_comparator("name", f"{config.split}.txt"))
151153
dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
154+
152155
raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
153156
raw_categories, _ = zip(
154157
*sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))

0 commit comments

Comments
 (0)