
Add support to MViT v1 #6179


Merged: 15 commits merged into pytorch:mvit on Jun 23, 2022

Conversation

@datumbox (Contributor) commented on Jun 16, 2022

  • Add support for MViTv1.
  • Add support for MViTv2 Attention with a residual_pool (sketched below).
  • Change the way we configure the networks.
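
For context on the second bullet: residual_pool corresponds to the residual pooling connection introduced by MViTv2, where the pooled query tensor is added back to the attention output before the output projection. A minimal conceptual sketch (not the torchvision implementation) of what the flag toggles:

import torch

def attention_output(attn: torch.Tensor, q_pooled: torch.Tensor, residual_pool: bool) -> torch.Tensor:
    # attn is the softmax(Q K^T) V result and q_pooled the pooled query; both have
    # the same shape after pooling. With residual_pool=True (MViTv2 style) the
    # pooled query is added back; with residual_pool=False (MViTv1) it is not.
    if residual_pool:
        attn = attn + q_pooled
    return attn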

Reference test:

torchrun --nproc_per_node=8 train.py --data-path /datasets01/kinetics/070618/400/ --batch-size 1 --test-only --clip-len 16 --frame-rate 4 --clips-per-video 5 --model mvit_v1_b --weights MViT_V1_B_Weights.KINETICS400_V1 --cache-dataset

The above check shows reduced accuracy compared to the expected one. There seems to be a regression in our reference script, the Kinetics dataset, or the video decoder. The accuracy of the model using TorchVision's implementation was verified with the SlowFast reference scripts:

INFO:test_net:testing done: _ak78.47 Top1 Acc: 78.47 Top5 Acc: 93.65 MEM: 1.97 dataset: k400

@datumbox changed the title from "Switch implementation to v1 variant" to "Switch implementation to MViTv1 variant" on Jun 16, 2022
@datumbox mentioned this pull request on Jun 16, 2022
@datumbox changed the title from "Switch implementation to MViTv1 variant" to "Add support to MViT v1" on Jun 17, 2022
@datumbox (Contributor, Author) commented on Jun 17, 2022

Similar to #6105 (comment), we verify the implementation as follows:

import collections

import torch
from pytorchvideo.models.hub.vision_transformers import mvit_base_16x4
from pytorchvideo.models.vision_transformers import create_multiscale_vision_transformers
from torchvision.models.video import mvit as TorchVision


class PyTorchVideo:
    @staticmethod
    def mvit_v1_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_base_16x4(**kwargs):
        # return mvit_base_16x4(pretrained=True)
        m = create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=True,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )
        # https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT_B_16x4.pyth
        d = torch.load("./MVIT_B_16x4.pyth")["model_state"]
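        # strict=False is required here; see facebookresearch/pytorchvideo#206.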
        m.load_state_dict(d, strict=False)
        return m

    @staticmethod
    def mvit_v2_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=32,
            depth=24,
            embed_dim_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            atten_head_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            pool_q_stride_size=[
                [0, 1, 1, 1],
                [1, 1, 1, 1],
                [2, 1, 2, 2],
                [3, 1, 1, 1],
                [4, 1, 1, 1],
                [5, 1, 2, 2],
                [6, 1, 1, 1],
                [7, 1, 1, 1],
                [8, 1, 1, 1],
                [9, 1, 1, 1],
                [10, 1, 1, 1],
                [11, 1, 1, 1],
                [12, 1, 1, 1],
                [13, 1, 1, 1],
                [14, 1, 1, 1],
                [15, 1, 1, 1],
                [16, 1, 1, 1],
                [17, 1, 1, 1],
                [18, 1, 1, 1],
                [19, 1, 1, 1],
                [20, 1, 1, 1],
                [21, 1, 2, 2],
                [22, 1, 1, 1],
                [23, 1, 1, 1],
            ],
            droppath_rate_block=0.3,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )


def ptv_to_tv_weights(state_dict, separate_qkv):
    d = dict(state_dict)

    # merge qkv if necessary
    if separate_qkv:
        components = collections.defaultdict(dict)
        for k in list(d.keys()):
            for pattern in ["q", "k", "v"]:
                if f".attn.{pattern}." in k:
                    group = k.rsplit(".", 2)[0]
                    components[group][k] = d.pop(k)
                    break
        for group in components.keys():
            for typ in ["weight", "bias"]:
                l = []
                for pattern in ["q", "k", "v"]:
                    l.append(components[group].pop(f"{group}.{pattern}.{typ}"))
                d[f"{group}.qkv.{typ}"] = torch.cat(l, dim=0)

    # remapping keys
    mapping = collections.OrderedDict(
        [
            ("patch_embed.patch_model.weight", "conv_proj.weight"),
            ("patch_embed.patch_model.bias", "conv_proj.bias"),
            ("cls_positional_encoding.cls_token", "pos_encoding.class_token"),
            ("cls_positional_encoding.pos_embed_spatial", "pos_encoding.spatial_pos"),
            ("cls_positional_encoding.pos_embed_temporal", "pos_encoding.temporal_pos"),
            ("cls_positional_encoding.pos_embed_class", "pos_encoding.class_pos"),
            ("attn.proj.weight", "attn.project.0.weight"),
            ("attn.proj.bias", "attn.project.0.bias"),
            ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
            ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
            ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
            ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
            ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
            ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
            ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
            ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
            ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
            ("mlp.fc1.weight", "mlp.0.weight"),
            ("mlp.fc1.bias", "mlp.0.bias"),
            ("mlp.fc2.weight", "mlp.3.weight"),
            ("mlp.fc2.bias", "mlp.3.bias"),
            ("norm_embed.weight", "norm.weight"),
            ("norm_embed.bias", "norm.bias"),
            ("head.proj.weight", "head.1.weight"),
            ("head.proj.bias", "head.1.bias"),
            ("proj.weight", "project.weight"),
            ("proj.bias", "project.bias"),
        ]
    )
    for k in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in k:
                new_key = k.replace(pattern, replacement)
                d[new_key] = d.pop(k)
                break

    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
    d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
    d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]

    # removing unnecessary keys
    for k in list(d.keys()):
        if "attn._attention_pool_" in k:
            del d[k]
    return d


def compare_models(ptv_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)

    ptv_m = ptv_model_fn().eval()
    exp_result = ptv_m(x).sum()

    separate_qkv = isinstance(ptv_m.blocks[0].attn.qkv, torch.nn.Identity)

    d = ptv_m.state_dict()
    d = ptv_to_tv_weights(d, separate_qkv)

    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()

    torch.testing.assert_close(result, exp_result, rtol=0, atol=1e-6)
    print("OK")


compare_models(PyTorchVideo.mvit_v1_b, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_base_16x4, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
# compare_models(PyTorchVideo.mvit_v2_b, TorchVision.mvit_v2_b, (1, 3, 32, 224, 224))
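
A side note on the qkv merge in ptv_to_tv_weights above: concatenating the separate q/k/v projection weights along dim 0 produces a fused projection whose output is the concatenation of the three individual projections. A standalone check with toy dimensions (illustration only):

import torch

# Toy-sized sanity check of the qkv merge: the fused weight built with
# torch.cat([wq, wk, wv], dim=0) matches applying the three projections
# separately and concatenating the results.
embed_dim = 8
wq, wk, wv = (torch.randn(embed_dim, embed_dim) for _ in range(3))
w_qkv = torch.cat([wq, wk, wv], dim=0)

x = torch.randn(2, embed_dim)
torch.testing.assert_close(
    x @ w_qkv.T,
    torch.cat([x @ wq.T, x @ wk.T, x @ wv.T], dim=-1),
)
print("qkv merge OK")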

And to verify speed:

import time


def benchmark(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)

    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        if i > n * warmup:
            s.append(t)

    print(model_fn.__name__, torch.tensor(s).median())


device = "cuda"
batch_size = 4
n = 100

print(f"device={device}, batch_size={batch_size}, n={n}")
for name, backend in [("TorchVision", TorchVision), ("PyTorchVideo", PyTorchVideo)]:
    print(name)
    benchmark(backend.mvit_v1_b, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_b, (batch_size, 3, 32, 224, 224), device, n=n)

Output (less is better):

device=cuda, batch_size=4, n=100
TorchVision
mvit_v1_b tensor(0.0298)
mvit_v2_b tensor(0.1094)
PyTorchVideo
mvit_v1_b tensor(0.0310)
mvit_v2_b tensor(0.1127)

@datumbox (Contributor, Author) commented:

Confirm that the weights are ported correctly from PyTorchVideo:

import torch
from torchvision.models.video import mvit_v1_b
from pytorchvideo.models.hub.vision_transformers import mvit_base_16x4


mvt = mvit_base_16x4().eval()
# Load weights from: https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT_B_16x4.pyth
# strict=False due to https://github.com/facebookresearch/pytorchvideo/issues/206
mvt.load_state_dict(torch.load("MVIT_B_16x4.pyth")["model_state"], strict=False)  


tv = mvit_v1_b(weights="DEFAULT").eval()

batch = torch.randn((4, 3, 16, 224, 224))
y1 = mvt(batch)
y2 = tv(batch)
assert torch.equal(y1, y2)

print("OK")

@datumbox merged commit 3e7683f into pytorch:mvit on Jun 23, 2022
@datumbox deleted the mvitv1 branch on June 23, 2022 at 14:11
datumbox added a commit that referenced this pull request on Jun 24, 2022
* Adding MViT v2 architecture (#6105)

* Adding mvitv2 architecture

* Fixing memory issues on tests and minor refactorings.

* Adding input validation

* Adding docs and minor refactoring

* Add `min_temporal_size` in the supported meta-data.

* Switch Tuple[int, int, int] to List[int] to more easily support the 2D case

* Adding more docs and references

* Change naming conventions of classes to follow the same pattern as MobileNetV3

* Fix test breakage.

* Update todos

* Performance optimizations.

* Add support to MViT v1 (#6179)

* Switch implementation to v1 variant.

* Fix docs

* Adding back a v2 pseudovariant

* Changing the way the networks are configured.

* Temporarily removing v2

* Adding weights.

* Expand _squeeze/_unsqueeze to support arbitrary dims.

* Update references script.

* Fix tests.

* Fixing frames and preprocessing.

* Fix std/mean values in transforms.

* Add permanent Dropout and update the weights.

* Update accuracies.

* Fix documentation

* Remove unnecessary expected file.

* Skip big model test

* Rewrite the configuration logic to reduce LOC.

* Fix mypy
facebook-github-bot pushed a commit that referenced this pull request Jun 27, 2022
Summary:
* Adding MViT v2 architecture (#6105)

* Add support to MViT v1 (#6179)

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899