@@ -21,6 +21,7 @@
 from torchvision import datasets
 import torch
 import shutil
+import json


 try:
@@ -839,5 +840,70 @@ def test_annotations(self):
            self.assertEqual(object, info["annotation"])


+class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.CocoDetection
+    FEATURE_TYPES = (PIL.Image.Image, list)
+
+    REQUIRED_PACKAGES = ("pycocotools",)
+
+    def inject_fake_data(self, tmpdir, config):
+        tmpdir = pathlib.Path(tmpdir)
+
+        num_images = 3
+        num_annotations_per_image = 2
+
+        image_folder = tmpdir / "images"
+        files = datasets_utils.create_image_folder(
+            tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
+        )
+        file_names = [file.relative_to(image_folder) for file in files]
+
+        annotation_folder = tmpdir / "annotations"
+        os.makedirs(annotation_folder)
+        annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
+
+        info["num_examples"] = num_images
+        return (str(image_folder), str(annotation_file)), info
+
+    def _create_annotation_file(self, root, file_names, num_annotations_per_image):
+        image_ids = [int(file_name.stem) for file_name in file_names]
+        images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
+
+        annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
+
+        content = dict(images=images, annotations=annotations)
+        return self._create_json(root, "annotations.json", content), info
+
+    def _create_annotations(self, image_ids, num_annotations_per_image):
+        annotations = datasets_utils.combinations_grid(
+            image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
+        )
+        for id, annotation in enumerate(annotations):
+            annotation["id"] = id
+        return annotations, dict()
+
+    def _create_json(self, root, name, content):
+        file = pathlib.Path(root) / name
+        with open(file, "w") as fh:
+            json.dump(content, fh)
+        return file
+
+
+class CocoCaptionsTestCase(CocoDetectionTestCase):
+    DATASET_CLASS = datasets.CocoCaptions
+
+    def _create_annotations(self, image_ids, num_annotations_per_image):
+        captions = [str(idx) for idx in range(num_annotations_per_image)]
+        annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
+        for id, annotation in enumerate(annotations):
+            annotation["id"] = id
+        return annotations, dict(captions=captions)
+
+    def test_captions(self):
+        with self.create_dataset() as (dataset, info):
+            _, captions = dataset[0]
+            self.assertEqual(tuple(captions), tuple(info["captions"]))
+
+
 if __name__ == "__main__":
     unittest.main()
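The tuple returned by inject_fake_data above, (str(image_folder), str(annotation_file)), matches the positional arguments CocoDetection takes, so a dataset built on this fake layout behaves like a regular COCO split. A minimal sketch, assuming an "images/" folder and "annotations/annotations.json" laid out as in the test (paths are illustrative, not part of the diff):

# Sketch only: mirrors the fake layout created by inject_fake_data above.
from torchvision import datasets

dataset = datasets.CocoDetection(root="images", annFile="annotations/annotations.json")

# Each sample is a PIL.Image.Image and a list of annotation dicts,
# which is what FEATURE_TYPES = (PIL.Image.Image, list) asserts.
img, target = dataset[0]
print(len(dataset), [ann["bbox"] for ann in target])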