From ddcf19784e7c2f3132706ac4d6cd13f61b3f8841 Mon Sep 17 00:00:00 2001 From: RazaProdigy Date: Mon, 22 Jan 2024 00:51:45 +0530 Subject: [PATCH 1/2] Added Type Check for cocodetection dataset --- torchvision/datasets/coco.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 7fda9667938..fd374a8ec24 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -44,6 +44,10 @@ def _load_target(self, id: int) -> List[Any]: return self.coco.loadAnns(self.coco.getAnnIds(id)) def __getitem__(self, index: int) -> Tuple[Any, Any]: + + if not isinstance(index, int): + raise ValueError(f"Index must be of type integer, got {type(index)} instead.") + id = self.ids[index] image = self._load_image(id) target = self._load_target(id) From 6fe30ac6b585367415bfae60dd999ad555bbd68c Mon Sep 17 00:00:00 2001 From: RazaProdigy Date: Mon, 22 Jan 2024 21:55:15 +0530 Subject: [PATCH 2/2] Added slice error test for cocodetection dataset --- test/test_datasets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_datasets.py b/test/test_datasets.py index bee781d488d..4878dbaaa60 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -827,6 +827,11 @@ def test_transforms_v2_wrapper_spawn(self): with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) + def test_slice_error(self): + with self.create_dataset() as (dataset, _): + with pytest.raises(ValueError, match="Index must be of type integer"): + dataset[:2] + class CocoCaptionsTestCase(CocoDetectionTestCase): DATASET_CLASS = datasets.CocoCaptions