Skip to content

Commit ee8c7a0

Browse files
add tests for USPS dataset (#3466)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: fmassa Differential Revision: D26756275 fbshipit-source-id: 52ed0b84541cde380a3ac539cc49c92d57bbebab
1 parent 156f184 commit ee8c7a0

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

test/test_datasets.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import shutil
2424
import json
2525
import random
26+
import bz2
2627
import torch.nn.functional as F
2728
import string
2829
import io
@@ -1173,5 +1174,24 @@ def inject_fake_data(self, tmpdir, config):
11731174
return num_images
11741175

11751176

1177+
class USPSTestCase(datasets_utils.ImageDatasetTestCase):
1178+
DATASET_CLASS = datasets.USPS
1179+
1180+
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
1181+
1182+
def inject_fake_data(self, tmpdir, config):
1183+
num_images = 2 if config["train"] else 1
1184+
1185+
images = torch.rand(num_images, 256) * 2 - 1
1186+
labels = torch.randint(1, 11, size=(num_images,))
1187+
1188+
with bz2.open(pathlib.Path(tmpdir) / f"usps{'.t' if not config['train'] else ''}.bz2", "w") as fh:
1189+
for image, label in zip(images, labels):
1190+
line = " ".join((str(label.item()), *[f"{idx}:{pixel:.6f}" for idx, pixel in enumerate(image, 1)]))
1191+
fh.write(f"{line}\n".encode())
1192+
1193+
return num_images
1194+
1195+
11761196
if __name__ == "__main__":
11771197
unittest.main()

0 commit comments

Comments
 (0)