Skip to content

Commit 7029839

Browse files
add tests for SBDataset (#3467)
Summary: Co-authored-by: Francisco Massa <[email protected]> Reviewed By: fmassa Differential Revision: D26756273 fbshipit-source-id: 53f4cb3022a0d434104624fc26b7b6ad3dfbd8ae
1 parent 44f4834 commit 7029839

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

test/datasets_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class LazyImporter:
5656
"pycocotools",
5757
"requests",
5858
"scipy.io",
59+
"scipy.sparse",
5960
)
6061

6162
def __init__(self):

test/test_datasets.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,5 +1193,73 @@ def inject_fake_data(self, tmpdir, config):
11931193
return num_images
11941194

11951195

1196+
class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
1197+
DATASET_CLASS = datasets.SBDataset
1198+
FEATURE_TYPES = (PIL.Image.Image, (np.ndarray, PIL.Image.Image))
1199+
1200+
REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
1201+
1202+
CONFIGS = datasets_utils.combinations_grid(
1203+
image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
1204+
)
1205+
1206+
_NUM_CLASSES = 20
1207+
1208+
def inject_fake_data(self, tmpdir, config):
1209+
num_images, num_images_per_image_set = self._create_split_files(tmpdir)
1210+
1211+
sizes = self._create_target_folder(tmpdir, "cls", num_images)
1212+
1213+
datasets_utils.create_image_folder(
1214+
tmpdir, "img", lambda idx: f"{self._file_stem(idx)}.jpg", num_images, size=lambda idx: sizes[idx]
1215+
)
1216+
1217+
return num_images_per_image_set[config["image_set"]]
1218+
1219+
def _create_split_files(self, root):
1220+
root = pathlib.Path(root)
1221+
1222+
splits = dict(train=(0, 1, 2), train_noval=(0, 2), val=(3,))
1223+
1224+
for split, idcs in splits.items():
1225+
self._create_split_file(root, split, idcs)
1226+
1227+
num_images = max(itertools.chain(*splits.values())) + 1
1228+
num_images_per_split = dict([(split, len(idcs)) for split, idcs in splits.items()])
1229+
return num_images, num_images_per_split
1230+
1231+
def _create_split_file(self, root, name, idcs):
1232+
with open(root / f"{name}.txt", "w") as fh:
1233+
fh.writelines(f"{self._file_stem(idx)}\n" for idx in idcs)
1234+
1235+
def _create_target_folder(self, root, name, num_images):
1236+
io = datasets_utils.lazy_importer.scipy.io
1237+
1238+
target_folder = pathlib.Path(root) / name
1239+
os.makedirs(target_folder)
1240+
1241+
sizes = [torch.randint(1, 4, size=(2,)).tolist() for _ in range(num_images)]
1242+
for idx, size in enumerate(sizes):
1243+
content = dict(
1244+
GTcls=dict(Boundaries=self._create_boundaries(size), Segmentation=self._create_segmentation(size))
1245+
)
1246+
io.savemat(target_folder / f"{self._file_stem(idx)}.mat", content)
1247+
1248+
return sizes
1249+
1250+
def _create_boundaries(self, size):
1251+
sparse = datasets_utils.lazy_importer.scipy.sparse
1252+
return [
1253+
[sparse.csc_matrix(torch.randint(0, 2, size=size, dtype=torch.uint8).numpy())]
1254+
for _ in range(self._NUM_CLASSES)
1255+
]
1256+
1257+
def _create_segmentation(self, size):
1258+
return torch.randint(0, self._NUM_CLASSES + 1, size=size, dtype=torch.uint8).numpy()
1259+
1260+
def _file_stem(self, idx):
1261+
return f"2008_{idx:06d}"
1262+
1263+
11961264
if __name__ == "__main__":
11971265
unittest.main()

0 commit comments

Comments
 (0)