Skip to content

Commit 1ab9030

Browse files
committed
Merge branch 'models/convnext_variants' of https://github.com/datumbox/vision into models/convnext_variants
2 parents 30a6b6d + c3971ad commit 1ab9030

25 files changed

+333
-48
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ jobs:
351351
- install_torchvision
352352
- install_prototype_dependencies
353353
- pip_install:
354-
args: scipy pycocotools
354+
args: scipy pycocotools h5py
355355
descr: Install optional dependencies
356356
- run:
357357
name: Enable prototype tests

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ supported Python versions.
2323
+==========================+==========================+=================================+
2424
| ``main`` / ``nightly`` | ``main`` / ``nightly`` | ``>=3.7``, ``<=3.9`` |
2525
+--------------------------+--------------------------+---------------------------------+
26+
| ``1.10.2`` | ``0.11.3`` | ``>=3.6``, ``<=3.9`` |
27+
+--------------------------+--------------------------+---------------------------------+
2628
| ``1.10.1`` | ``0.11.2`` | ``>=3.6``, ``<=3.9`` |
2729
+--------------------------+--------------------------+---------------------------------+
2830
| ``1.10.0`` | ``0.11.1`` | ``>=3.6``, ``<=3.9`` |

docs/source/utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ visualization <sphx_glr_auto_examples_plot_visualization_utils.py>`.
1515
draw_bounding_boxes
1616
draw_segmentation_masks
1717
draw_keypoints
18+
flow_to_image
1819
make_grid
1920
save_image

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def write_version_file():
5858
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
5959

6060
requirements = [
61+
"typing_extensions",
6162
"numpy",
6263
"requests",
6364
pytorch_dep,

test/assets/expected_flow.pt

30 KB
Binary file not shown.

test/builtin_dataset_mocks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import csv
33
import functools
44
import gzip
5+
import io
56
import itertools
67
import json
78
import lzma
@@ -1312,3 +1313,30 @@ def svhn(info, root, config):
13121313
},
13131314
)
13141315
return num_samples
1316+
1317+
1318+
@register_mock
def pcam(info, root, config):
    """Write gzip-compressed HDF5 mock files for the requested PCAM split.

    Two files are created inside ``root``: one holding random ``uint8`` images under
    the key ``"x"`` and one holding random binary targets under the key ``"y"``.

    Returns:
        The number of mock samples written for ``config.split``.
    """
    import h5py

    num_images = {"train": 2, "test": 3, "val": 4}[config.split]

    # The file names use "valid" where the config uses "val".
    split = config.split
    if split == "val":
        split = "valid"

    # Build both HDF5 files in memory first (images, then targets).
    images_io = io.BytesIO()
    with h5py.File(images_io, "w") as h5_file:
        h5_file["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)

    targets_io = io.BytesIO()
    with h5py.File(targets_io, "w") as h5_file:
        h5_file["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)

    # Create .gz compressed files
    pending = (
        (root / f"camelyonpatch_level_2_split_{split}_x.h5.gz", images_io),
        (root / f"camelyonpatch_level_2_split_{split}_y.h5.gz", targets_io),
    )
    for target_path, buffer in pending:
        payload = gzip.compress(buffer.getbuffer())
        with open(target_path, "wb") as out_file:
            out_file.write(payload)

    return num_images

test/test_prototype_builtin_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):
1818

1919

2020
def test_coverage():
21-
untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
21+
untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
2222
if untested_datasets:
2323
raise AssertionError(
2424
f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "

test/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,5 +317,30 @@ def test_draw_keypoints_errors():
317317
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
318318

319319

320+
def test_flow_to_image():
    """Compare ``utils.flow_to_image`` on a centered meshgrid flow with the stored asset."""
    h, w = 100, 100
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Same as stacking the reversed meshgrid tuple: the second grid goes first.
    flow = torch.stack((cols, rows), dim=0).float()
    flow[0] -= h / 2
    flow[1] -= w / 2

    img = utils.flow_to_image(flow)

    assets_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
    expected_img = torch.load(os.path.join(assets_dir, "expected_flow.pt"), map_location="cpu")
    assert_equal(expected_img, img)
330+
331+
332+
def test_flow_to_image_errors():
    """Check that ``utils.flow_to_image`` raises ``ValueError`` for each invalid flow below."""
    # Each entry pairs an invalid flow tensor with the error message it must trigger.
    invalid_cases = (
        (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
        (torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
    )

    for bad_flow, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            utils.flow_to_image(flow=bad_flow)
343+
344+
320345
if __name__ == "__main__":
321346
pytest.main([__file__])

torchvision/csrc/io/decoder/gpu/demuxer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class Demuxer {
119119
" in demuxer.h\n");
120120
}
121121
}
122+
122123
~Demuxer() {
123124
if (!fmtCtx) {
124125
return;
@@ -223,7 +224,7 @@ class Demuxer {
223224
int64_t time = timestamp * AV_TIME_BASE;
224225
TORCH_CHECK(
225226
0 <= av_seek_frame(fmtCtx, -1, time, flag),
226-
"avformat_open_input() failed at line ",
227+
"av_seek_frame() failed at line ",
227228
__LINE__,
228229
" in demuxer.h\n");
229230
}

torchvision/datasets/stl10.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os.path
2-
from typing import Any, Callable, Optional, Tuple
2+
from typing import Any, Callable, Optional, Tuple, cast
33

44
import numpy as np
55
from PIL import Image
@@ -65,10 +65,12 @@ def __init__(
6565
self.labels: Optional[np.ndarray]
6666
if self.split == "train":
6767
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
68+
self.labels = cast(np.ndarray, self.labels)
6869
self.__load_folds(folds)
6970

7071
elif self.split == "train+unlabeled":
7172
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
73+
self.labels = cast(np.ndarray, self.labels)
7274
self.__load_folds(folds)
7375
unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
7476
self.data = np.concatenate((self.data, unlabeled_data))

torchvision/models/segmentation/deeplabv3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .. import mobilenetv3
88
from .. import resnet
9-
from ..feature_extraction import create_feature_extractor
9+
from .._utils import IntermediateLayerGetter
1010
from ._utils import _SimpleSegmentationModel, _load_weights
1111
from .fcn import FCNHead
1212

@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
121121
return_layers = {"layer4": "out"}
122122
if aux:
123123
return_layers["layer3"] = "aux"
124-
backbone = create_feature_extractor(backbone, return_layers)
124+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
125125

126126
aux_classifier = FCNHead(1024, num_classes) if aux else None
127127
classifier = DeepLabHead(2048, num_classes)
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
144144
return_layers = {str(out_pos): "out"}
145145
if aux:
146146
return_layers[str(aux_pos)] = "aux"
147-
backbone = create_feature_extractor(backbone, return_layers)
147+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
148148

149149
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
150150
classifier = DeepLabHead(out_inplanes, num_classes)

torchvision/models/segmentation/fcn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import nn
44

55
from .. import resnet
6-
from ..feature_extraction import create_feature_extractor
6+
from .._utils import IntermediateLayerGetter
77
from ._utils import _SimpleSegmentationModel, _load_weights
88

99

@@ -57,7 +57,7 @@ def _fcn_resnet(
5757
return_layers = {"layer4": "out"}
5858
if aux:
5959
return_layers["layer3"] = "aux"
60-
backbone = create_feature_extractor(backbone, return_layers)
60+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
6161

6262
aux_classifier = FCNHead(1024, num_classes) if aux else None
6363
classifier = FCNHead(2048, num_classes)

torchvision/models/segmentation/lraspp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ...utils import _log_api_usage_once
88
from .. import mobilenetv3
9-
from ..feature_extraction import create_feature_extractor
9+
from .._utils import IntermediateLayerGetter
1010
from ._utils import _load_weights
1111

1212

@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
9090
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
9191
low_channels = backbone[low_pos].out_channels
9292
high_channels = backbone[high_pos].out_channels
93-
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
93+
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
9494

9595
return LRASPP(backbone, low_channels, high_channels, num_classes)
9696

torchvision/prototype/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
from ._home import home
1212

1313
# Load this last, since some parts depend on the above being loaded first
14-
from ._api import register, _list as list, info, load, find # usort: skip
14+
from ._api import register, list_datasets, info, load, find # usort: skip
1515
from ._folder import from_data_folder, from_image_folder

torchvision/prototype/datasets/_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def register(dataset: Dataset) -> None:
2323
register(obj())
2424

2525

26-
# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
27-
def _list() -> List[str]:
26+
def list_datasets() -> List[str]:
2827
return sorted(DATASETS.keys())
2928

3029

@@ -39,7 +38,7 @@ def find(name: str) -> Dataset:
3938
word=name,
4039
possibilities=DATASETS.keys(),
4140
alternative_hint=lambda _: (
42-
"You can use torchvision.datasets.list() to get a list of all available datasets."
41+
"You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
4342
),
4443
)
4544
) from error

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .imagenet import ImageNet
1111
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
1212
from .oxford_iiit_pet import OxfordIITPet
13+
from .pcam import PCAM
1314
from .sbd import SBD
1415
from .semeion import SEMEION
1516
from .svhn import SVHN
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import io
2+
from collections import namedtuple
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
4+
5+
import torch
6+
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
7+
from torchvision.prototype import features
8+
from torchvision.prototype.datasets.utils import (
9+
Dataset,
10+
DatasetConfig,
11+
DatasetInfo,
12+
OnlineResource,
13+
DatasetType,
14+
GDriveResource,
15+
)
16+
from torchvision.prototype.datasets.utils._internal import (
17+
hint_sharding,
18+
hint_shuffling,
19+
)
20+
from torchvision.prototype.features import Label
21+
22+
23+
class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
    """Datapipe that opens each incoming file handle with h5py and yields its entries.

    Args:
        datapipe: Source datapipe of ``(str, file handle)`` pairs; only the handle is used.
        key: If given, entries are taken from ``file[key]`` instead of from the file itself.
            Note: this key thing might be very specific to the PCAM dataset.
    """

    def __init__(
        self,
        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
        key: Optional[str] = None,
    ) -> None:
        self.key = key
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
        # Imported here rather than at module level, so h5py is only required on iteration.
        import h5py

        for _, handle in self.datapipe:
            with h5py.File(handle) as h5_file:
                entries = h5_file if self.key is None else h5_file[self.key]
                yield from entries
40+
41+
42+
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
43+
44+
45+
class PCAM(Dataset):
    """Prototype dataset definition for PCAM (homepage: https://github.com/basveeling/pcam).

    Every split is backed by two gzip-compressed ``.h5`` files hosted on Google Drive:
    one with the images (read under the key ``"x"``) and one with the targets (read
    under the key ``"y"``).
    """

    def _make_info(self) -> DatasetInfo:
        """Return the static description of the dataset (name, type, splits, dependencies)."""
        split_options = {"split": ("train", "test", "val")}
        return DatasetInfo(
            "pcam",
            type=DatasetType.RAW,
            homepage="https://github.com/basveeling/pcam",
            categories=2,
            valid_options=split_options,
            dependencies=["h5py"],
        )

    # Maps each split to its (images, targets) pair of download descriptors.
    _RESOURCES = {
        "train": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_train_x.h5.gz",
                gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",
                sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_train_y.h5.gz",
                gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a",
            ),
        ),
        "test": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_test_x.h5.gz",
                gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_test_y.h5.gz",
                gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef",
            ),
        ),
        "val": (
            _Resource(  # Images
                file_name="camelyonpatch_level_2_split_valid_x.h5.gz",
                gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2",
            ),
            _Resource(  # Targets
                file_name="camelyonpatch_level_2_split_valid_y.h5.gz",
                gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c",
            ),
        ),
    }

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the ``[images resource, targets resource]`` list for ``config.split``."""
        return [
            GDriveResource(
                file_name=entry.file_name,
                id=entry.gdrive_id,
                sha256=entry.sha256,
                decompress=True,
            )
            for entry in self._RESOURCES[config.split]
        ]

    def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
        """Turn one ``(image, target)`` pair into the sample dictionary."""
        image, target = data  # They're both numpy arrays at this point
        sample = dict(
            image=features.Image(image),
            label=Label(target.item()),
        )
        return sample

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Zip the image and target readers and map each pair to a sample dictionary.

        ``config`` and ``decoder`` are accepted but not used by this implementation.
        """
        images_dp, targets_dp = resource_dps

        image_reader = PCAMH5Reader(images_dp, key="x")
        target_reader = PCAMH5Reader(targets_dp, key="y")

        samples_dp = Zipper(image_reader, target_reader)
        samples_dp = hint_shuffling(hint_sharding(samples_dp))
        return Mapper(samples_dp, self._collate_and_decode)

torchvision/prototype/datasets/generate_category_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def parse_args(argv=None):
4949
args = parser.parse_args(argv or sys.argv[1:])
5050

5151
if not args.names:
52-
args.names = datasets.list()
52+
args.names = datasets.list_datasets()
5353

5454
return args
5555

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class DatasetType(enum.Enum):
2424

2525

2626
class DatasetConfig(FrozenBunch):
27+
# This needs to be Frozen because we often pass configs as partial(func, config=config)
28+
# and partial() requires the parameters to be hashable.
2729
pass
2830

2931

0 commit comments

Comments
 (0)