|
19 | 19 | import pathlib
|
20 | 20 | import pickle
|
21 | 21 | from torchvision import datasets
|
| 22 | +import json |
22 | 23 |
|
23 | 24 |
|
24 | 25 | try:
|
@@ -560,5 +561,70 @@ class CIFAR100(CIFAR10TestCase):
|
560 | 561 | )
|
561 | 562 |
|
562 | 563 |
|
| 564 | +class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): |
| 565 | + DATASET_CLASS = datasets.CocoDetection |
| 566 | + FEATURE_TYPES = (PIL.Image.Image, list) |
| 567 | + |
| 568 | + REQUIRED_PACKAGES = ("pycocotools",) |
| 569 | + |
| 570 | + def inject_fake_data(self, tmpdir, config): |
| 571 | + tmpdir = pathlib.Path(tmpdir) |
| 572 | + |
| 573 | + num_images = 3 |
| 574 | + num_annotations_per_image = 2 |
| 575 | + |
| 576 | + image_folder = tmpdir / "images" |
| 577 | + files = datasets_utils.create_image_folder( |
| 578 | + tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images |
| 579 | + ) |
| 580 | + file_names = [file.relative_to(image_folder) for file in files] |
| 581 | + |
| 582 | + annotation_folder = tmpdir / "annotations" |
| 583 | + os.makedirs(annotation_folder) |
| 584 | + annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image) |
| 585 | + |
| 586 | + info["num_examples"] = num_images |
| 587 | + return (str(image_folder), str(annotation_file)), info |
| 588 | + |
| 589 | + def _create_annotation_file(self, root, file_names, num_annotations_per_image): |
| 590 | + image_ids = [int(file_name.stem) for file_name in file_names] |
| 591 | + images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)] |
| 592 | + |
| 593 | + annotations, info = self._create_annotations(image_ids, num_annotations_per_image) |
| 594 | + |
| 595 | + content = dict(images=images, annotations=annotations) |
| 596 | + return self._create_json(root, "annotations.json", content), info |
| 597 | + |
| 598 | + def _create_annotations(self, image_ids, num_annotations_per_image): |
| 599 | + annotations = datasets_utils.combinations_grid( |
| 600 | + image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image |
| 601 | + ) |
| 602 | + for id, annotation in enumerate(annotations): |
| 603 | + annotation["id"] = id |
| 604 | + return annotations, dict() |
| 605 | + |
| 606 | + def _create_json(self, root, name, content): |
| 607 | + file = pathlib.Path(root) / name |
| 608 | + with open(file, "w") as fh: |
| 609 | + json.dump(content, fh) |
| 610 | + return file |
| 611 | + |
| 612 | + |
| 613 | +class CocoCaptionsTestCase(CocoDetectionTestCase): |
| 614 | + DATASET_CLASS = datasets.CocoCaptions |
| 615 | + |
| 616 | + def _create_annotations(self, image_ids, num_annotations_per_image): |
| 617 | + captions = [str(idx) for idx in range(num_annotations_per_image)] |
| 618 | + annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions) |
| 619 | + for id, annotation in enumerate(annotations): |
| 620 | + annotation["id"] = id |
| 621 | + return annotations, dict(captions=captions) |
| 622 | + |
| 623 | + def test_captions(self): |
| 624 | + with self.create_dataset() as (dataset, info): |
| 625 | + _, captions = dataset[0] |
| 626 | + self.assertEqual(tuple(captions), tuple(info["captions"])) |
| 627 | + |
| 628 | + |
563 | 629 | if __name__ == "__main__":
|
564 | 630 | unittest.main()
|
0 commit comments