Skip to content

Commit 2bbb112

Browse files
committed
Merge remote-tracking branch 'origin/models/convnext_variants' into models/convnext_variants
2 parents cafa02d + f803797 commit 2bbb112

File tree

6 files changed

+128
-38
lines changed

6 files changed

+128
-38
lines changed
939 Bytes
Binary file not shown.

test/test_models.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import warnings
99
from collections import OrderedDict
1010
from tempfile import TemporaryDirectory
11+
from typing import Any
1112

1213
import pytest
1314
import torch
@@ -514,6 +515,35 @@ def test_generalizedrcnn_transform_repr():
514515
assert t.__repr__() == expected_string
515516

516517

518+
test_vit_conv_stem_configs = [
519+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
520+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
521+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
522+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
523+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
524+
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
525+
]
526+
527+
528+
def vitc_b_16(**kwargs: Any):
529+
return models.VisionTransformer(
530+
image_size=224,
531+
patch_size=16,
532+
num_layers=12,
533+
num_heads=12,
534+
hidden_dim=768,
535+
mlp_dim=3072,
536+
conv_stem_configs=test_vit_conv_stem_configs,
537+
**kwargs,
538+
)
539+
540+
541+
@pytest.mark.parametrize("model_fn", [vitc_b_16])
542+
@pytest.mark.parametrize("dev", cpu_and_gpu())
543+
def test_vitc_models(model_fn, dev):
544+
test_classification_model(model_fn, dev)
545+
546+
517547
@pytest.mark.parametrize("model_fn", get_models_from_module(models))
518548
@pytest.mark.parametrize("dev", cpu_and_gpu())
519549
def test_classification_model(model_fn, dev):

test/test_utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -317,29 +317,42 @@ def test_draw_keypoints_errors():
317317
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
318318

319319

320-
def test_flow_to_image():
320+
@pytest.mark.parametrize("batch", (True, False))
321+
def test_flow_to_image(batch):
321322
h, w = 100, 100
322323
flow = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
323324
flow = torch.stack(flow[::-1], dim=0).float()
324325
flow[0] -= h / 2
325326
flow[1] -= w / 2
327+
328+
if batch:
329+
flow = torch.stack([flow, flow])
330+
326331
img = utils.flow_to_image(flow)
332+
assert img.shape == (2, 3, h, w) if batch else (3, h, w)
333+
327334
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
328335
expected_img = torch.load(path, map_location="cpu")
329-
assert_equal(expected_img, img)
330336

337+
if batch:
338+
expected_img = torch.stack([expected_img, expected_img])
339+
340+
assert_equal(expected_img, img)
331341

332-
def test_flow_to_image_errors():
333-
wrong_flow1 = torch.full((3, 10, 10), 0, dtype=torch.float)
334-
wrong_flow2 = torch.full((2, 10), 0, dtype=torch.float)
335-
wrong_flow3 = torch.full((2, 10, 30), 0, dtype=torch.int)
336342

337-
with pytest.raises(ValueError, match="Input flow should have shape"):
338-
utils.flow_to_image(flow=wrong_flow1)
339-
with pytest.raises(ValueError, match="Input flow should have shape"):
340-
utils.flow_to_image(flow=wrong_flow2)
341-
with pytest.raises(ValueError, match="Flow should be of dtype torch.float"):
342-
utils.flow_to_image(flow=wrong_flow3)
343+
@pytest.mark.parametrize(
344+
"input_flow, match",
345+
(
346+
(torch.full((3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
347+
(torch.full((5, 3, 10, 10), 0, dtype=torch.float), "Input flow should have shape"),
348+
(torch.full((2, 10), 0, dtype=torch.float), "Input flow should have shape"),
349+
(torch.full((5, 2, 10), 0, dtype=torch.float), "Input flow should have shape"),
350+
(torch.full((2, 10, 30), 0, dtype=torch.int), "Flow should be of dtype torch.float"),
351+
),
352+
)
353+
def test_flow_to_image_errors(input_flow, match):
354+
with pytest.raises(ValueError, match=match):
355+
utils.flow_to_image(flow=input_flow)
343356

344357

345358
if __name__ == "__main__":

torchvision/datasets/hmdb51.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
class HMDB51(VisionDataset):
1313
"""
14-
`HMDB51 <http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
14+
`HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
1515
dataset.
1616
1717
HMDB51 is an action recognition video dataset.
@@ -47,9 +47,9 @@ class HMDB51(VisionDataset):
4747
- label (int): class of the video clip
4848
"""
4949

50-
data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
50+
data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
5151
splits = {
52-
"url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
52+
"url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
5353
"md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
5454
}
5555
TRAIN_TAG = 1

torchvision/models/vision_transformer.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import math
22
from collections import OrderedDict
33
from functools import partial
4-
from typing import Any, Callable, Optional
4+
from typing import Any, Callable, List, NamedTuple, Optional
55

66
import torch
77
import torch.nn as nn
88

99
from .._internally_replaced_utils import load_state_dict_from_url
10+
from ..ops.misc import ConvNormActivation
1011
from ..utils import _log_api_usage_once
1112

1213
__all__ = [
@@ -25,6 +26,14 @@
2526
}
2627

2728

29+
class ConvStemConfig(NamedTuple):
30+
out_channels: int
31+
kernel_size: int
32+
stride: int
33+
norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
34+
activation_layer: Callable[..., nn.Module] = nn.ReLU
35+
36+
2837
class MLPBlock(nn.Sequential):
2938
"""Transformer MLP block."""
3039

@@ -134,6 +143,7 @@ def __init__(
134143
num_classes: int = 1000,
135144
representation_size: Optional[int] = None,
136145
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
146+
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
137147
):
138148
super().__init__()
139149
_log_api_usage_once(self)
@@ -148,11 +158,31 @@ def __init__(
148158
self.representation_size = representation_size
149159
self.norm_layer = norm_layer
150160

151-
input_channels = 3
152-
153-
# The conv_proj is a more efficient version of reshaping, permuting
154-
# and projecting the input
155-
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
161+
if conv_stem_configs is not None:
162+
# As per https://arxiv.org/abs/2106.14881
163+
seq_proj = nn.Sequential()
164+
prev_channels = 3
165+
for i, conv_stem_layer_config in enumerate(conv_stem_configs):
166+
seq_proj.add_module(
167+
f"conv_bn_relu_{i}",
168+
ConvNormActivation(
169+
in_channels=prev_channels,
170+
out_channels=conv_stem_layer_config.out_channels,
171+
kernel_size=conv_stem_layer_config.kernel_size,
172+
stride=conv_stem_layer_config.stride,
173+
norm_layer=conv_stem_layer_config.norm_layer,
174+
activation_layer=conv_stem_layer_config.activation_layer,
175+
),
176+
)
177+
prev_channels = conv_stem_layer_config.out_channels
178+
seq_proj.add_module(
179+
"conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
180+
)
181+
self.conv_proj: nn.Module = seq_proj
182+
else:
183+
self.conv_proj = nn.Conv2d(
184+
in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
185+
)
156186

157187
seq_length = (image_size // patch_size) ** 2
158188

@@ -184,9 +214,17 @@ def __init__(
184214
self._init_weights()
185215

186216
def _init_weights(self):
187-
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
188-
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
189-
nn.init.zeros_(self.conv_proj.bias)
217+
if isinstance(self.conv_proj, nn.Conv2d):
218+
# Init the patchify stem
219+
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
220+
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
221+
nn.init.zeros_(self.conv_proj.bias)
222+
else:
223+
# Init the last 1x1 conv of the conv stem
224+
nn.init.normal_(
225+
self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
226+
)
227+
nn.init.zeros_(self.conv_proj.conv_last.bias)
190228

191229
if hasattr(self.heads, "pre_logits"):
192230
fan_in = self.heads.pre_logits.in_features

torchvision/utils.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -397,42 +397,51 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
397397
Converts a flow to an RGB image.
398398
399399
Args:
400-
flow (Tensor): Flow of shape (2, H, W) and dtype torch.float.
400+
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
401401
402402
Returns:
403-
img (Tensor(3, H, W)): Image Tensor of dtype uint8 where each color corresponds to a given flow direction.
403+
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
404+
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
404405
"""
405406

406407
if flow.dtype != torch.float:
407408
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
408409

409-
if flow.ndim != 3 or flow.size(0) != 2:
410-
raise ValueError(f"Input flow should have shape (2, H, W), got {flow.shape}.")
410+
orig_shape = flow.shape
411+
if flow.ndim == 3:
412+
flow = flow[None] # Add batch dim
411413

412-
max_norm = torch.sum(flow ** 2, dim=0).sqrt().max()
414+
if flow.ndim != 4 or flow.shape[1] != 2:
415+
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
416+
417+
max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
413418
epsilon = torch.finfo((flow).dtype).eps
414419
normalized_flow = flow / (max_norm + epsilon)
415-
return _normalized_flow_to_image(normalized_flow)
420+
img = _normalized_flow_to_image(normalized_flow)
421+
422+
if len(orig_shape) == 3:
423+
img = img[0] # Remove batch dim
424+
return img
416425

417426

418427
@torch.no_grad()
419428
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
420429

421430
"""
422-
Converts a normalized flow to an RGB image.
431+
Converts a batch of normalized flow to an RGB image.
423432
424433
Args:
425-
normalized_flow (torch.Tensor): Normalized flow tensor of shape (2, H, W)
434+
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
426435
Returns:
427-
img (Tensor(3, H, W)): Flow visualization image of dtype uint8.
436+
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
428437
"""
429438

430-
_, H, W = normalized_flow.shape
431-
flow_image = torch.zeros((3, H, W), dtype=torch.uint8)
439+
N, _, H, W = normalized_flow.shape
440+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
432441
colorwheel = _make_colorwheel() # shape [55x3]
433442
num_cols = colorwheel.shape[0]
434-
norm = torch.sum(normalized_flow ** 2, dim=0).sqrt()
435-
a = torch.atan2(-normalized_flow[1], -normalized_flow[0]) / torch.pi
443+
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
444+
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
436445
fk = (a + 1) / 2 * (num_cols - 1)
437446
k0 = torch.floor(fk).to(torch.long)
438447
k1 = k0 + 1
@@ -445,7 +454,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
445454
col1 = tmp[k1] / 255.0
446455
col = (1 - f) * col0 + f * col1
447456
col = 1 - norm * (1 - col)
448-
flow_image[c, :, :] = torch.floor(255 * col)
457+
flow_image[:, c, :, :] = torch.floor(255 * col)
449458
return flow_image
450459

451460

0 commit comments

Comments
 (0)