
Commit fad04f4

Merge branch 'main' into revamp-prototype-features-transforms
2 parents 8079e44 + ac1f0ff

13 files changed: +144, -48 lines

README.rst

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ supported Python versions.
 +==========================+==========================+=================================+
 | ``main`` / ``nightly``   | ``main`` / ``nightly``   | ``>=3.7``, ``<=3.9``            |
 +--------------------------+--------------------------+---------------------------------+
+| ``1.10.2``               | ``0.11.3``               | ``>=3.6``, ``<=3.9``            |
++--------------------------+--------------------------+---------------------------------+
 | ``1.10.1``               | ``0.11.2``               | ``>=3.6``, ``<=3.9``            |
 +--------------------------+--------------------------+---------------------------------+
 | ``1.10.0``               | ``0.11.1``               | ``>=3.6``, ``<=3.9``            |
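
The new row records the pairing for this release: torchvision 0.11.3 is built against torch 1.10.2. Mixing versions across rows is a common source of import-time errors, so this table is the reference to check when pinning both packages.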

setup.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def write_version_file():
     pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
 
 requirements = [
+    "typing_extensions",
     "numpy",
     "requests",
     pytorch_dep,
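
typing_extensions becomes a runtime dependency here because torchvision/transforms/functional_pil.py (below) now imports Literal from it; Literal only joined the standard typing module in Python 3.8, and torchvision still supports older interpreters.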
Binary file (939 Bytes) not shown.

test/test_models.py

Lines changed: 30 additions & 0 deletions
@@ -8,6 +8,7 @@
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
+from typing import Any
 
 import pytest
 import torch
@@ -514,6 +515,35 @@ def test_generalizedrcnn_transform_repr():
     assert t.__repr__() == expected_string
 
 
+test_vit_conv_stem_configs = [
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
+]
+
+
+def vitc_b_16(**kwargs: Any):
+    return models.VisionTransformer(
+        image_size=224,
+        patch_size=16,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=768,
+        mlp_dim=3072,
+        conv_stem_configs=test_vit_conv_stem_configs,
+        **kwargs,
+    )
+
+
+@pytest.mark.parametrize("model_fn", [vitc_b_16])
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_vitc_models(model_fn, dev):
+    test_classification_model(model_fn, dev)
+
+
 @pytest.mark.parametrize("model_fn", get_models_from_module(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):
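
The new test wires a ViT-C variant (ViT-B/16 with a convolutional stem, per the paper referenced in vision_transformer.py below) through the shared classification-model test. A minimal standalone sketch of what it exercises, assuming a torchvision build containing this commit; the batch size of 2 is arbitrary:

import torch
from torchvision import models
from torchvision.models.vision_transformer import ConvStemConfig

# Six conv layers with strides 2, 2, 1, 2, 1, 2 (cumulative stride 16 = patch_size)
stem = [
    ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
]
model = models.VisionTransformer(
    image_size=224, patch_size=16, num_layers=12, num_heads=12,
    hidden_dim=768, mlp_dim=3072, conv_stem_configs=stem,
).eval()
with torch.no_grad():
    logits = model(torch.rand(2, 3, 224, 224))
assert logits.shape == (2, 1000)  # default num_classes=1000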

test/test_utils.py

Lines changed: 25 additions & 12 deletions
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
         utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
 
 
-def test_flow_to_image():
+@pytest.mark.parametrize("batch", (True, False))
+def test_flow_to_image(batch):
     h, w = 100, 100
     flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
     flow = torch.stack(flow[::-1], dim=0).float()
     flow[0] -= h / 2
     flow[1] -= w / 2
+
+    if batch:
+        flow = torch.stack([flow, flow])
+
     img = utils.flow_to_image(flow)
+    assert img.shape == ((2, 3, h, w) if batch else (3, h, w))
+
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
     expected_img = torch.load(path, map_location="cpu")
-    assert_equal(expected_img, img)
 
+    if batch:
+        expected_img = torch.stack([expected_img, expected_img])
+
+    assert_equal(expected_img, img)
 
-def test_flow_to_image_errors():
-    wrong_flow1 = torch.full((3, 10, 10), 0, dtype=torch.float)
-    wrong_flow2 = torch.full((2, 10), 0, dtype=torch.float)
-    wrong_flow3 = torch.full((2, 10, 30), 0, dtype=torch.int)
 
-    with pytest.raises(ValueError, match="Input flow should have shape"):
-        utils.flow_to_image(flow=wrong_flow1)
-    with pytest.raises(ValueError, match="Input flow should have shape"):
-        utils.flow_to_image(flow=wrong_flow2)
-    with pytest.raises(ValueError, match="Flow should be of dtype torch.float"):
-        utils.flow_to_image(flow=wrong_flow3)
+@pytest.mark.parametrize(
+    "input_flow, match",
+    (
+        (torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
+        (torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
+    ),
+)
+def test_flow_to_image_errors(input_flow, match):
+    with pytest.raises(ValueError, match=match):
+        utils.flow_to_image(flow=input_flow)
 
 
 if __name__ == "__main__":
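
The rewritten test covers the new batch dimension in utils.flow_to_image: the function now accepts either a single (2, H, W) flow field or a batch of shape (N, 2, H, W), returning a uint8 RGB image of shape (3, H, W) or (N, 3, H, W) respectively. A minimal usage sketch, assuming a build with this batch support; the random values are arbitrary:

import torch
from torchvision.utils import flow_to_image

flow = torch.randn(2, 100, 100)        # one flow field: (2, H, W), float
batched = torch.stack([flow, flow])    # batch of two: (N, 2, H, W)

img = flow_to_image(flow)
imgs = flow_to_image(batched)
assert img.shape == (3, 100, 100)      # single RGB visualization
assert imgs.shape == (2, 3, 100, 100)  # one visualization per batch element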

torchvision/datasets/hmdb51.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 
 class HMDB51(VisionDataset):
     """
-    `HMDB51 <http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
     dataset.
 
     HMDB51 is an action recognition video dataset.
@@ -47,9 +47,9 @@ class HMDB51(VisionDataset):
         - label (int): class of the video clip
     """
 
-    data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
     splits = {
-        "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
         "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
     }
     TRAIN_TAG = 1

torchvision/datasets/stl10.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,5 @@
 import os.path
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple, cast
 
 import numpy as np
 from PIL import Image
@@ -65,10 +65,12 @@ def __init__(
         self.labels: Optional[np.ndarray]
         if self.split == "train":
             self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
             self.__load_folds(folds)
 
         elif self.split == "train+unlabeled":
             self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
             self.__load_folds(folds)
             unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
             self.data = np.concatenate((self.data, unlabeled_data))
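
The cast calls exist purely for the type checker: __loadfile returns labels as Optional[np.ndarray], and cast narrows that to np.ndarray before __load_folds consumes it. At runtime, cast(T, x) simply returns x. A short illustration of the pattern:

from typing import Optional, cast

import numpy as np

labels: Optional[np.ndarray] = np.zeros(5, dtype=np.int64)
narrowed = cast(np.ndarray, labels)  # no runtime effect; silences Optional warnings
print(narrowed.shape)  # (5,)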

torchvision/models/segmentation/deeplabv3.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 
 from .. import mobilenetv3
 from .. import resnet
-from ..feature_extraction import create_feature_extractor
+from .._utils import IntermediateLayerGetter
 from ._utils import _SimpleSegmentationModel, _load_weights
 from .fcn import FCNHead
 
@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
     return_layers = {"layer4": "out"}
     if aux:
         return_layers["layer3"] = "aux"
-    backbone = create_feature_extractor(backbone, return_layers)
+    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
 
     aux_classifier = FCNHead(1024, num_classes) if aux else None
     classifier = DeepLabHead(2048, num_classes)
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
     return_layers = {str(out_pos): "out"}
     if aux:
         return_layers[str(aux_pos)] = "aux"
-    backbone = create_feature_extractor(backbone, return_layers)
+    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
 
     aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
     classifier = DeepLabHead(out_inplanes, num_classes)
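
This swaps the FX-tracing-based create_feature_extractor back for the older IntermediateLayerGetter; the fcn.py and lraspp.py diffs below make the same substitution. IntermediateLayerGetter keeps the backbone's top-level children up to the last requested layer and returns the named intermediate activations in an OrderedDict, with no symbolic tracing involved. A minimal sketch of its behavior with a plain ResNet-50 (note it lives in the private module torchvision.models._utils):

import torch
from torchvision.models import resnet50
from torchvision.models._utils import IntermediateLayerGetter

backbone = resnet50()
getter = IntermediateLayerGetter(backbone, return_layers={"layer3": "aux", "layer4": "out"})

features = getter(torch.rand(1, 3, 224, 224))  # OrderedDict of tensors
print(features["aux"].shape)  # torch.Size([1, 1024, 14, 14])
print(features["out"].shape)  # torch.Size([1, 2048, 7, 7])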

torchvision/models/segmentation/fcn.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from torch import nn
 
 from .. import resnet
-from ..feature_extraction import create_feature_extractor
+from .._utils import IntermediateLayerGetter
 from ._utils import _SimpleSegmentationModel, _load_weights
 
 
@@ -57,7 +57,7 @@ def _fcn_resnet(
     return_layers = {"layer4": "out"}
     if aux:
         return_layers["layer3"] = "aux"
-    backbone = create_feature_extractor(backbone, return_layers)
+    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
 
     aux_classifier = FCNHead(1024, num_classes) if aux else None
     classifier = FCNHead(2048, num_classes)

torchvision/models/segmentation/lraspp.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from ...utils import _log_api_usage_once
 from .. import mobilenetv3
-from ..feature_extraction import create_feature_extractor
+from .._utils import IntermediateLayerGetter
 from ._utils import _load_weights
 
 
@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
     high_pos = stage_indices[-1]  # use C5 which has output_stride = 16
     low_channels = backbone[low_pos].out_channels
     high_channels = backbone[high_pos].out_channels
-    backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
+    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
 
     return LRASPP(backbone, low_channels, high_channels, num_classes)

torchvision/models/vision_transformer.py

Lines changed: 47 additions & 9 deletions
@@ -1,12 +1,13 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, NamedTuple, Optional
 
 import torch
 import torch.nn as nn
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
 from ..utils import _log_api_usage_once
 
 __all__ = [
@@ -25,6 +26,14 @@
 }
 
 
+class ConvStemConfig(NamedTuple):
+    out_channels: int
+    kernel_size: int
+    stride: int
+    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
+    activation_layer: Callable[..., nn.Module] = nn.ReLU
+
+
 class MLPBlock(nn.Sequential):
     """Transformer MLP block."""
 
@@ -134,6 +143,7 @@ def __init__(
         num_classes: int = 1000,
         representation_size: Optional[int] = None,
         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -148,11 +158,31 @@ def __init__(
         self.representation_size = representation_size
         self.norm_layer = norm_layer
 
-        input_channels = 3
-
-        # The conv_proj is a more efficient version of reshaping, permuting
-        # and projecting the input
-        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+        if conv_stem_configs is not None:
+            # As per https://arxiv.org/abs/2106.14881
+            seq_proj = nn.Sequential()
+            prev_channels = 3
+            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
+                seq_proj.add_module(
+                    f"conv_bn_relu_{i}",
+                    ConvNormActivation(
+                        in_channels=prev_channels,
+                        out_channels=conv_stem_layer_config.out_channels,
+                        kernel_size=conv_stem_layer_config.kernel_size,
+                        stride=conv_stem_layer_config.stride,
+                        norm_layer=conv_stem_layer_config.norm_layer,
+                        activation_layer=conv_stem_layer_config.activation_layer,
+                    ),
+                )
+                prev_channels = conv_stem_layer_config.out_channels
+            seq_proj.add_module(
+                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
+            )
+            self.conv_proj: nn.Module = seq_proj
+        else:
+            self.conv_proj = nn.Conv2d(
+                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
+            )
 
         seq_length = (image_size // patch_size) ** 2
 
@@ -184,9 +214,17 @@ def __init__(
         self._init_weights()
 
     def _init_weights(self):
-        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
-        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-        nn.init.zeros_(self.conv_proj.bias)
+        if isinstance(self.conv_proj, nn.Conv2d):
+            # Init the patchify stem
+            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.conv_proj.bias)
+        else:
+            # Init the last 1x1 conv of the conv stem
+            nn.init.normal_(
+                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
+            )
+            nn.init.zeros_(self.conv_proj.conv_last.bias)
 
         if hasattr(self.heads, "pre_logits"):
             fan_in = self.heads.pre_logits.in_features
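
One constraint implicit in the new stem path: seq_length is still computed as (image_size // patch_size) ** 2, so the conv stem's cumulative stride must equal patch_size or the positional embedding will not match the number of tokens the stem produces. The six-layer config in test_models.py above satisfies this; a quick check of the arithmetic:

import math

strides = [2, 2, 1, 2, 1, 2]  # strides of test_vit_conv_stem_configs above
assert math.prod(strides) == 16  # equals patch_size, so 224 // 16 = 14 tokens per side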

torchvision/transforms/functional_pil.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 from PIL import Image, ImageOps, ImageEnhance
+from typing_extensions import Literal
 
 try:
     import accimage
@@ -130,7 +131,7 @@ def pad(
     img: Image.Image,
     padding: Union[int, List[int], Tuple[int, ...]],
     fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
-    padding_mode: str = "constant",
+    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
 ) -> Image.Image:
 
     if not _is_pil_image(img):
@@ -189,7 +190,7 @@ def pad(
     if img.mode == "P":
         palette = img.getpalette()
         img = np.asarray(img)
-        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
+        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
         img = Image.fromarray(img)
         img.putpalette(palette)
         return img
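
Two small tightenings here: padding_mode is narrowed from str to a Literal of the four modes this function supports, so a static checker can flag typos before they surface as runtime ValueErrors, and mode is passed to np.pad by keyword for readability. A sketch of what the Literal annotation buys (pad_modes_demo is a hypothetical helper, not part of torchvision):

from typing_extensions import Literal

PaddingMode = Literal["constant", "edge", "reflect", "symmetric"]

def pad_modes_demo(padding_mode: PaddingMode) -> str:
    return padding_mode

pad_modes_demo("reflect")  # accepted
pad_modes_demo("refelct")  # flagged by mypy as an invalid Literal value (still runs)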
