Skip to content

Commit 71907be

Browse files
lezwonpmeierNicolasHugmalfetsahilg06
authored
USPS dataset (#5647)
* added usps dataset * fixed type issues * fix mobilnet norm layer test (#5643) * xfail mobilnet norm layer test * fix test * More robust check in tests for 16 bits images (#5652) * Prefer nvidia channel for conda builds (#5648) To mitigate missing `libcupti.so` dependency * fix torchdata CI installation (#5657) * update urls for kinetics dataset (#5578) * update urls for kinetics dataset * update urls for kinetics dataset * remove errors * update the changes and add test option to split * added test to valid values for split arg * change .txt to .csv for annotation url of k600 Co-authored-by: Nicolas Hug <[email protected]> * Port Multi-weight support from prototype to main (#5618) * Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet. * Porting googlenet * Porting inception * Porting mnasnet * Porting mobilenetv2 * Porting mobilenetv3 * Porting regnet * Porting resnet * Porting shufflenetv2 * Porting squeezenet * Porting vgg * Porting vit * Fix docstrings * Fixing imports * Adding missing import * Fix mobilenet imports * Fix tests * Fix prototype tests * Exclude get_weight from models on test * Fix init files * Porting googlenet * Porting inception * porting mobilenetv2 * porting mobilenetv3 * porting resnet * porting shufflenetv2 * Fix test and linter * Fixing docs. * Porting Detection models (#5617) * fix inits * fix docs * Port faster_rcnn * Port fcos * Port keypoint_rcnn * Port mask_rcnn * Port retinanet * Port ssd * Port ssdlite * Fix linter * Fixing tests * Fixing tests * Fixing vgg test * Porting Optical Flow, Segmentation, Video models (#5619) * Porting raft * Porting video resnet * Porting deeplabv3 * Porting fcn and lraspp * Fixing the tests and linter * Porting docs, examples, tutorials and galleries (#5620) * Fix examples, tutorials and gallery * Update gallery/plot_optical_flow.py Co-authored-by: Nicolas Hug <[email protected]> * Fix import * Revert hardcoded normalization * fix uncommitted changes * Fix bug * Fix more bugs * Making resize optional for segmentation * Fixing preset * Fix mypy * Fixing documentation strings * Fix flake8 * minor refactoring Co-authored-by: Nicolas Hug <[email protected]> * Resolve conflict * Porting model tests (#5622) * Porting tests * Remove unnecessary variable * Fix linter * Move prototype to extended tests * Fix download models job * Update CI on Multiweight branch to use the new weight download approach (#5628) * port Pad to prototype transforms (#5621) * port Pad to prototype transforms * use literal * Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624) Co-authored-by: Anton Thomma <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> * pre-download model weights in CI docs build (#5625) * pre-download model weights in CI docs build * move changes into template * change docs image * Regenerated config.yml Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Anton Thomma <[email protected]> Co-authored-by: Anton Thomma <[email protected]> * Porting reference scripts and updating presets (#5629) * Making _preset.py classes * Remove support of targets on presets. * Rewriting the video preset * Adding tests to check that the bundled transforms are JIT scriptable * Rename all presets from *Eval to *Inference * Minor refactoring * Remove --prototype and --pretrained from reference scripts * remove pretained_backbone refs * Corrections and simplifications * Fixing bug * Fixing linter * Fix flake8 * restore documentation example * minor fixes * fix optical flow missing param * Fixing commands * Adding weights_backbone support in detection and segmentation * Updating the commands for InceptionV3 * Setting `weights_backbone` to its fully BC value (#5653) * Replace default `weights_backbone=None` with its BC values. * Fixing tests * Fix linter * Update docs. * Update preprocessing on reference scripts. * Change qat/ptq to their full values. * Refactoring preprocessing * Fix video preset * No initialization on VGG if pretrained * Fix warning messages for backbone utils. * Adding star to all preset constructors. * Fix mypy. Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Anton Thomma <[email protected]> Co-authored-by: Anton Thomma <[email protected]> * Apply suggestions from code review Co-authored-by: Philip Meier <[email protected]> * use decompressor for extracting bz2 * Apply suggestions from code review Co-authored-by: Philip Meier <[email protected]> * Apply suggestions from code review Co-authored-by: Philip Meier <[email protected]> * fixed lint fails * added tests for USPS * check image shape * fix tests * check shape on image directly * Apply suggestions from code review Co-authored-by: Philip Meier <[email protected]> * removed test and comments * Update test/test_prototype_builtin_datasets.py Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nikita Shulga <[email protected]> Co-authored-by: Sahil Goyal <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Anton Thomma <[email protected]> Co-authored-by: Anton Thomma <[email protected]>
1 parent 1af20e8 commit 71907be

File tree

4 files changed

+92
-1
lines changed

4 files changed

+92
-1
lines changed

test/builtin_dataset_mocks.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import bz2
12
import collections.abc
23
import csv
34
import functools
@@ -1431,3 +1432,21 @@ def stanford_cars(info, root, config):
14311432
make_tar(root, "car_devkit.tgz", devkit, compression="gz")
14321433

14331434
return num_samples
1435+
1436+
1437+
@register_mock
1438+
def usps(info, root, config):
1439+
num_samples = {"train": 15, "test": 7}[config.split]
1440+
1441+
with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
1442+
lines = []
1443+
for _ in range(num_samples):
1444+
label = make_tensor(1, low=1, high=11, dtype=torch.int)
1445+
values = make_tensor(256, low=-1, high=1, dtype=torch.float)
1446+
lines.append(
1447+
" ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))])
1448+
)
1449+
1450+
fh.write("\n".join(lines).encode())
1451+
1452+
return num_samples

test/test_prototype_builtin_datasets.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torchdata.datapipes.iter import IterDataPipe, Shuffler
1313
from torchvision._utils import sequence_to_str
1414
from torchvision.prototype import transforms, datasets
15-
15+
from torchvision.prototype.features import Image, Label
1616

1717
assert_samples_equal = functools.partial(
1818
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -180,3 +180,20 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
180180
for sample in dataset:
181181
label_from_path = int(Path(sample["path"]).parent.name)
182182
assert sample["label"] == label_from_path
183+
184+
185+
@parametrize_dataset_mocks(DATASET_MOCKS["usps"])
186+
class TestUSPS:
187+
def test_sample_content(self, test_home, dataset_mock, config):
188+
dataset_mock.prepare(test_home, config)
189+
190+
dataset = datasets.load(dataset_mock.name, **config)
191+
192+
for sample in dataset:
193+
assert "image" in sample
194+
assert "label" in sample
195+
196+
assert isinstance(sample["image"], Image)
197+
assert isinstance(sample["label"], Label)
198+
199+
assert sample["image"].shape == (1, 16, 16)

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
from .semeion import SEMEION
1818
from .stanford_cars import StanfordCars
1919
from .svhn import SVHN
20+
from .usps import USPS
2021
from .voc import VOC
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Any, Dict, List
2+
3+
import torch
4+
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
5+
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
6+
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
7+
from torchvision.prototype.features import Image, Label
8+
9+
10+
class USPS(Dataset):
11+
def _make_info(self) -> DatasetInfo:
12+
return DatasetInfo(
13+
"usps",
14+
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
15+
valid_options=dict(
16+
split=("train", "test"),
17+
),
18+
categories=10,
19+
)
20+
21+
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
22+
23+
_RESOURCES = {
24+
"train": HttpResource(
25+
f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f"
26+
),
27+
"test": HttpResource(
28+
f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e"
29+
),
30+
}
31+
32+
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
33+
return [USPS._RESOURCES[config.split]]
34+
35+
def _prepare_sample(self, line: str) -> Dict[str, Any]:
36+
label, *values = line.strip().split(" ")
37+
values = [float(value.split(":")[1]) for value in values]
38+
pixels = torch.tensor(values).add_(1).div_(2)
39+
return dict(
40+
image=Image(pixels.reshape(16, 16)),
41+
label=Label(int(label) - 1, categories=self.categories),
42+
)
43+
44+
def _make_datapipe(
45+
self,
46+
resource_dps: List[IterDataPipe],
47+
*,
48+
config: DatasetConfig,
49+
) -> IterDataPipe[Dict[str, Any]]:
50+
dp = Decompressor(resource_dps[0])
51+
dp = LineReader(dp, decode=True, return_path=False)
52+
dp = hint_sharding(dp)
53+
dp = hint_shuffling(dp)
54+
return Mapper(dp, self._prepare_sample)

0 commit comments

Comments
 (0)