Skip to content

Commit a3cf806

Browse files
committed
add tests for Coco
1 parent 22c548b commit a3cf806

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test/test_datasets.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pathlib
2020
import pickle
2121
from torchvision import datasets
22+
import json
2223

2324

2425
try:
@@ -560,5 +561,70 @@ class CIFAR100(CIFAR10TestCase):
560561
)
561562

562563

564+
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
565+
DATASET_CLASS = datasets.CocoDetection
566+
FEATURE_TYPES = (PIL.Image.Image, list)
567+
568+
REQUIRED_PACKAGES = ("pycocotools",)
569+
570+
def inject_fake_data(self, tmpdir, config):
571+
tmpdir = pathlib.Path(tmpdir)
572+
573+
num_images = 3
574+
num_annotations_per_image = 2
575+
576+
image_folder = tmpdir / "images"
577+
files = datasets_utils.create_image_folder(
578+
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
579+
)
580+
file_names = [file.relative_to(image_folder) for file in files]
581+
582+
annotation_folder = tmpdir / "annotations"
583+
os.makedirs(annotation_folder)
584+
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
585+
586+
info["num_examples"] = num_images
587+
return (str(image_folder), str(annotation_file)), info
588+
589+
def _create_annotation_file(self, root, file_names, num_annotations_per_image):
590+
image_ids = [int(file_name.stem) for file_name in file_names]
591+
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
592+
593+
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
594+
595+
content = dict(images=images, annotations=annotations)
596+
return self._create_json(root, "annotations.json", content), info
597+
598+
def _create_annotations(self, image_ids, num_annotations_per_image):
599+
annotations = datasets_utils.combinations_grid(
600+
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
601+
)
602+
for id, annotation in enumerate(annotations):
603+
annotation["id"] = id
604+
return annotations, dict()
605+
606+
def _create_json(self, root, name, content):
607+
file = pathlib.Path(root) / name
608+
with open(file, "w") as fh:
609+
json.dump(content, fh)
610+
return file
611+
612+
613+
class CocoCaptionsTestCase(CocoDetectionTestCase):
614+
DATASET_CLASS = datasets.CocoCaptions
615+
616+
def _create_annotations(self, image_ids, num_annotations_per_image):
617+
captions = [str(idx) for idx in range(num_annotations_per_image)]
618+
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
619+
for id, annotation in enumerate(annotations):
620+
annotation["id"] = id
621+
return annotations, dict(captions=captions)
622+
623+
def test_captions(self):
624+
with self.create_dataset() as (dataset, info):
625+
_, captions = dataset[0]
626+
self.assertEqual(tuple(captions), tuple(info["captions"]))
627+
628+
563629
if __name__ == "__main__":
564630
unittest.main()

0 commit comments

Comments
 (0)