Add MViT architecture in TorchVision #6198


Merged: 11 commits, Jun 24, 2022
1 change: 1 addition & 0 deletions docs/source/models.rst
@@ -465,6 +465,7 @@ pre-trained weights:
.. toctree::
:maxdepth: 1

+   models/video_mvit
models/video_resnet

|
26 changes: 26 additions & 0 deletions docs/source/models/video_mvit.rst
@@ -0,0 +1,26 @@
Video MViT
==========

.. currentmodule:: torchvision.models.video

The MViT model is based on the
`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
<https://arxiv.org/abs/2112.01526>`__ and `Multiscale Vision Transformers
<https://arxiv.org/abs/2104.11227>`__ papers.


Model builders
--------------

The following model builders can be used to instantiate an MViT model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.MViT`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/video/mvit.py>`_ for
more details about this class.

.. autosummary::
:toctree: generated/
:template: function.rst

mvit_v1_b
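
For reviewers, a minimal sketch of how the new builder is intended to be called once this PR lands. The `MViT_V1_B_Weights` enum name follows torchvision's usual weights-naming convention and is an assumption here; the input layout matches the `(1, 3, 16, 224, 224)` shape registered in `test_models.py` below.

```python
import torch
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights  # enum name assumed

# Build with pretrained weights, or pass weights=None for random init.
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
model.eval()

# MViT consumes clips shaped (batch, channels, frames, height, width).
clip = torch.rand(1, 3, 16, 224, 224)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)  # e.g. torch.Size([1, 400]) for Kinetics-400 weights
```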
5 changes: 3 additions & 2 deletions references/video_classification/train.py
@@ -152,7 +152,7 @@ def main(args):
split="train",
step_between_clips=1,
transform=transform_train,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
extensions=(
"avi",
"mp4",
@@ -189,7 +189,7 @@ def main(args):
split="val",
step_between_clips=1,
transform=transform_test,
-        frame_rate=15,
+        frame_rate=args.frame_rate,
extensions=(
"avi",
"mp4",
@@ -324,6 +324,7 @@ def parse_args():
parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
parser.add_argument(
"--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
)
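
The hunks above replace a hardcoded `frame_rate=15` with the new `--frame-rate` flag, so recipes can resample clips at other rates. A hedged sketch of the resulting flow, using `torchvision.datasets.Kinetics`, which accepts the keyword arguments shown in the hunks (the root path below is a placeholder):

```python
import argparse

from torchvision.datasets import Kinetics

parser = argparse.ArgumentParser()
parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
args = parser.parse_args(["--frame-rate", "30"])  # e.g. train at 30 fps

# Train and val datasets now receive the parsed value instead of a literal 15;
# each clip is temporally resampled to args.frame_rate fps.
dataset = Kinetics(
    root="datasets/kinetics400",  # placeholder path
    frames_per_clip=16,
    num_classes="400",
    split="train",
    step_between_clips=1,
    frame_rate=args.frame_rate,
    extensions=("avi", "mp4"),
)
```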
Binary file added test/expect/ModelTester.test_mvit_v1_b_expect.pkl
5 changes: 4 additions & 1 deletion test/test_extended_models.py
@@ -87,6 +87,7 @@ def test_schema_meta_validation(model_fn):
"license",
"_metrics",
"min_size",
"min_temporal_size",
"num_params",
"recipe",
"unquantized",
@@ -180,7 +181,7 @@ def test_transforms_jit(model_fn):
"input_shape": (1, 3, 520, 520),
},
"video": {
"input_shape": (1, 4, 3, 112, 112),
"input_shape": (1, 3, 4, 112, 112),
},
"optical_flow": {
"input_shape": (1, 3, 128, 128),
@@ -194,6 +195,8 @@
if module_name == "optical_flow":
args = (x, x)
else:
if module_name == "video":
x = x.permute(0, 2, 1, 3, 4)
args = (x,)

problematic_weights = []
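
The swapped `input_shape` plus the new `permute` reflect a layout mismatch: the video models take `(N, C, T, H, W)` tensors, while the video transform presets consume `(N, T, C, H, W)`. As I read the change, the test now starts from the model layout and swaps axes before calling the presets:

```python
import torch

# Canonical video-model layout: (batch, channels, frames, height, width).
x = torch.rand(1, 3, 4, 112, 112)

# Transform presets expect (batch, frames, channels, height, width),
# so swap the channel and time axes first.
x = x.permute(0, 2, 1, 3, 4)
print(x.shape)  # torch.Size([1, 4, 3, 112, 112])
```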
5 changes: 5 additions & 0 deletions test/test_models.py
@@ -309,6 +309,9 @@ def _check_input_backprop(model, inputs):
"image_size": 56,
"input_shape": (1, 3, 56, 56),
},
"mvit_v1_b": {
"input_shape": (1, 3, 16, 224, 224),
},
}
# speeding up slow models:
slow_models = [
@@ -830,6 +833,8 @@ def test_video_model(model_fn, dev):
"num_classes": 50,
}
model_name = model_fn.__name__
+    if SKIP_BIG_MODEL and model_name in skipped_big_models:
+        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
input_shape = kwargs.pop("input_shape")
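
`mvit_v1_b` is large enough that `test_video_model` now honors the suite's `SKIP_BIG_MODEL` escape hatch. A sketch of the gating pattern (the env-var default and the contents of `skipped_big_models` are assumptions, not copied from the file):

```python
import os

import pytest

SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"  # assumed default
skipped_big_models = {"mvit_v1_b"}  # assumed to include the new model

def maybe_skip_big_model(model_name: str) -> None:
    if SKIP_BIG_MODEL and model_name in skipped_big_models:
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
```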
1 change: 1 addition & 0 deletions torchvision/models/video/__init__.py
@@ -1 +1,2 @@
+from .mvit import *
from .resnet import *