Skip to content

Commit a6f3f95

Browse files
pmeierfmassa
andauthored
add tests for Coco (#3416)
Co-authored-by: Francisco Massa <[email protected]>
1 parent ccb7f45 commit a6f3f95

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test/test_datasets.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchvision import datasets
2222
import torch
2323
import shutil
24+
import json
2425

2526

2627
try:
@@ -839,5 +840,70 @@ def test_annotations(self):
839840
self.assertEqual(object, info["annotation"])
840841

841842

843+
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
844+
DATASET_CLASS = datasets.CocoDetection
845+
FEATURE_TYPES = (PIL.Image.Image, list)
846+
847+
REQUIRED_PACKAGES = ("pycocotools",)
848+
849+
def inject_fake_data(self, tmpdir, config):
850+
tmpdir = pathlib.Path(tmpdir)
851+
852+
num_images = 3
853+
num_annotations_per_image = 2
854+
855+
image_folder = tmpdir / "images"
856+
files = datasets_utils.create_image_folder(
857+
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
858+
)
859+
file_names = [file.relative_to(image_folder) for file in files]
860+
861+
annotation_folder = tmpdir / "annotations"
862+
os.makedirs(annotation_folder)
863+
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
864+
865+
info["num_examples"] = num_images
866+
return (str(image_folder), str(annotation_file)), info
867+
868+
def _create_annotation_file(self, root, file_names, num_annotations_per_image):
869+
image_ids = [int(file_name.stem) for file_name in file_names]
870+
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
871+
872+
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
873+
874+
content = dict(images=images, annotations=annotations)
875+
return self._create_json(root, "annotations.json", content), info
876+
877+
def _create_annotations(self, image_ids, num_annotations_per_image):
878+
annotations = datasets_utils.combinations_grid(
879+
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
880+
)
881+
for id, annotation in enumerate(annotations):
882+
annotation["id"] = id
883+
return annotations, dict()
884+
885+
def _create_json(self, root, name, content):
886+
file = pathlib.Path(root) / name
887+
with open(file, "w") as fh:
888+
json.dump(content, fh)
889+
return file
890+
891+
892+
class CocoCaptionsTestCase(CocoDetectionTestCase):
893+
DATASET_CLASS = datasets.CocoCaptions
894+
895+
def _create_annotations(self, image_ids, num_annotations_per_image):
896+
captions = [str(idx) for idx in range(num_annotations_per_image)]
897+
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
898+
for id, annotation in enumerate(annotations):
899+
annotation["id"] = id
900+
return annotations, dict(captions=captions)
901+
902+
def test_captions(self):
903+
with self.create_dataset() as (dataset, info):
904+
_, captions = dataset[0]
905+
self.assertEqual(tuple(captions), tuple(info["captions"]))
906+
907+
842908
if __name__ == "__main__":
843909
unittest.main()

0 commit comments

Comments
 (0)