
Commit 4a9f05a

YosuaMichael authored and facebook-github-bot committed
[fbsync] Add MViT architecture in TorchVision (#6198)
Summary:

* Adding MViT v2 architecture (#6105)
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` to the supported meta-data
* Replace Tuple[int, int, int] with List[int] to support the 2D case more easily
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage
* Update todos
* Performance optimizations
* Add support for MViT v1 (#6179)
* Switch implementation to v1 variant
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the networks are configured
* Temporarily removing v2
* Adding weights
* Expand _squeeze/_unsqueeze to support arbitrary dims
* Update references script
* Fix tests
* Fixing frames and preprocessing
* Fix std/mean values in transforms
* Add permanent Dropout and update the weights
* Update accuracies
* Fix documentation
* Remove unnecessary expected file
* Skip big model test
* Rewrite the configuration logic to reduce LOC
* Fix mypy

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
1 parent b7dcb23 · commit 4a9f05a

File tree

8 files changed: +591 −3 lines

docs/source/models.rst

1 addition & 0 deletions

@@ -465,6 +465,7 @@ pre-trained weights:
 .. toctree::
    :maxdepth: 1
 
+   models/video_mvit
    models/video_resnet
 
 |

docs/source/models/video_mvit.rst

26 additions & 0 deletions (new file)

+Video MViT
+==========
+
+.. currentmodule:: torchvision.models.video
+
+The MViT model is based on the
+`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
+<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
+<https://arxiv.org/abs/2104.11227>`__ papers.
+
+
+Model builders
+--------------
+
+The following model builders can be used to instantiate an MViT model, with or
+without pre-trained weights. All the model builders internally rely on the
+``torchvision.models.video.MViT`` base class. Please refer to the `source
+code
+<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for
+more details about this class.
+
+.. autosummary::
+   :toctree: generated/
+   :template: function.rst
+
+   mvit_v1_b
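For orientation, here is a minimal sketch of using the builder documented above. The dummy input shape (1, 3, 16, 224, 224) is taken from the test/test_models.py change in this same commit; instantiating without weights is simply the smallest possible call (the commit also ships pre-trained weights, not shown here).

import torch
from torchvision.models.video import mvit_v1_b

# Minimal sketch: build the new model without pre-trained weights and run a
# dummy clip through it. The (N, C, T, H, W) = (1, 3, 16, 224, 224) layout
# matches the input_shape registered for mvit_v1_b in test/test_models.py.
model = mvit_v1_b().eval()
clip = torch.rand(1, 3, 16, 224, 224)
with torch.inference_mode():
    logits = model(clip)
print(logits.shape)  # one row of class logits per clip in the batch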

references/video_classification/train.py

3 additions & 2 deletions

@@ -152,7 +152,7 @@ def main(args):
         split="train",
         step_between_clips=1,
         transform=transform_train,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -189,7 +189,7 @@ def main(args):
         split="val",
         step_between_clips=1,
         transform=transform_test,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
         extensions=(
             "avi",
             "mp4",
@@ -324,6 +324,7 @@ def parse_args():
     parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
+    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
     parser.add_argument(
         "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
     )
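The change above replaces the hard-coded frame rate with a CLI flag. As a sanity check, here is a small standalone sketch (hypothetical, mirroring the parse_args pattern above) of how --frame-rate surfaces on the parsed namespace that main() reads.

import argparse

# argparse maps the dashed flag --frame-rate to the attribute args.frame_rate,
# which main() forwards to the datasets as frame_rate=args.frame_rate.
parser = argparse.ArgumentParser()
parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
args = parser.parse_args(["--frame-rate", "30"])
assert args.frame_rate == 30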
Binary file (939 Bytes) not shown.

test/test_extended_models.py

4 additions & 1 deletion

@@ -87,6 +87,7 @@ def test_schema_meta_validation(model_fn):
         "license",
         "_metrics",
         "min_size",
+        "min_temporal_size",
         "num_params",
         "recipe",
         "unquantized",
@@ -180,7 +181,7 @@ def test_transforms_jit(model_fn):
         "input_shape": (1, 3, 520, 520),
     },
     "video": {
-        "input_shape": (1, 4, 3, 112, 112),
+        "input_shape": (1, 3, 4, 112, 112),
     },
     "optical_flow": {
         "input_shape": (1, 3, 128, 128),
@@ -194,6 +195,8 @@ def test_transforms_jit(model_fn):
     if module_name == "optical_flow":
         args = (x, x)
     else:
+        if module_name == "video":
+            x = x.permute(0, 2, 1, 3, 4)
         args = (x,)
 
     problematic_weights = []
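The shape swap and the permute above encode a layout convention: the canonical test input for video models is channels-first, while this test feeds the video transform presets a time-first tensor. A minimal sketch of the conversion (the interpretation of the axes is inferred from the diff, not stated in it):

import torch

# (N, C, T, H, W): the new canonical test input for video models.
x = torch.rand(1, 3, 4, 112, 112)
# (N, T, C, H, W): what the test hands to the video presets after the permute.
x_for_preset = x.permute(0, 2, 1, 3, 4)
assert x_for_preset.shape == (1, 4, 3, 112, 112)  # the old input_shape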

test/test_models.py

5 additions & 0 deletions

@@ -309,6 +309,9 @@ def _check_input_backprop(model, inputs):
         "image_size": 56,
         "input_shape": (1, 3, 56, 56),
     },
+    "mvit_v1_b": {
+        "input_shape": (1, 3, 16, 224, 224),
+    },
 }
 # speeding up slow models:
 slow_models = [
@@ -830,6 +833,8 @@ def test_video_model(model_fn, dev):
         "num_classes": 50,
     }
     model_name = model_fn.__name__
+    if SKIP_BIG_MODEL and model_name in skipped_big_models:
+        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
     kwargs = {**defaults, **_model_params.get(model_name, {})}
     num_classes = kwargs.get("num_classes")
     input_shape = kwargs.pop("input_shape")
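The skip guard above relies on SKIP_BIG_MODEL and skipped_big_models, which are defined elsewhere in test/test_models.py. A minimal sketch of how they plausibly fit together (the env-var parsing and the set contents are assumptions; the variable name comes from the skip message, and "Skip big model test" in the commit summary suggests the new MViT model is the member):

import os
import pytest

# Assumed definitions: skipping enabled by default, opt back in with SKIP_BIG_MODEL=0.
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
skipped_big_models = {"mvit_v1_b"}  # assumed membership

def maybe_skip_big_model(model_name: str) -> None:
    # Same guard as in test_video_model above.
    if SKIP_BIG_MODEL and model_name in skipped_big_models:
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")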

torchvision/models/video/__init__.py

1 addition & 0 deletions

@@ -1 +1,2 @@
+from .mvit import *
 from .resnet import *
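With the star re-export in place, the new names resolve from the video namespace. A short sketch (mvit_v1_b and the MViT base class are confirmed by the docs page above; any other exports of mvit.py are outside this diff):

from torchvision.models.video import MViT, mvit_v1_b

model = mvit_v1_b()
assert isinstance(model, MViT)  # all builders rely on the MViT base class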
