Add support to MViT v1 #6179
Merged
Conversation
Similar to #6105 (comment), we verify the implementation as follows:

```python
import collections
import torch
from pytorchvideo.models.hub.vision_transformers import mvit_base_16x4
from pytorchvideo.models.vision_transformers import create_multiscale_vision_transformers
from torchvision.models.video import mvit as TorchVision
class PyTorchVideo:
    @staticmethod
    def mvit_v1_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )

    @staticmethod
    def mvit_base_16x4(**kwargs):
        # return mvit_base_16x4(pretrained=True)
        m = create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=16,
            depth=16,
            embed_dim_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            atten_head_mul=[[1, 2.0], [3, 2.0], [14, 2.0]],
            pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            droppath_rate_block=0.2,
            # additional params for PyTorch Video
            residual_pool=False,
            separate_qkv=True,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )
        # https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT_B_16x4.pyth
        d = torch.load("./MVIT_B_16x4.pyth")["model_state"]
        m.load_state_dict(d, strict=False)
        return m

    @staticmethod
    def mvit_v2_b(**kwargs):
        return create_multiscale_vision_transformers(
            spatial_size=(224, 224),
            temporal_size=32,
            depth=24,
            embed_dim_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            atten_head_mul=[[2, 2.0], [5, 2.0], [21, 2.0]],
            pool_q_stride_size=[
                [0, 1, 1, 1],
                [1, 1, 1, 1],
                [2, 1, 2, 2],
                [3, 1, 1, 1],
                [4, 1, 1, 1],
                [5, 1, 2, 2],
                [6, 1, 1, 1],
                [7, 1, 1, 1],
                [8, 1, 1, 1],
                [9, 1, 1, 1],
                [10, 1, 1, 1],
                [11, 1, 1, 1],
                [12, 1, 1, 1],
                [13, 1, 1, 1],
                [14, 1, 1, 1],
                [15, 1, 1, 1],
                [16, 1, 1, 1],
                [17, 1, 1, 1],
                [18, 1, 1, 1],
                [19, 1, 1, 1],
                [20, 1, 1, 1],
                [21, 1, 2, 2],
                [22, 1, 1, 1],
                [23, 1, 1, 1],
            ],
            droppath_rate_block=0.3,
            # additional params for PyTorch Video
            residual_pool=True,
            separate_qkv=False,
            pool_kv_stride_adaptive=(1, 8, 8),
            pool_kvq_kernel=(3, 3, 3),
            head_dropout_rate=0.0,
            **kwargs,
        )
def ptv_to_tv_weights(state_dict, separate_qkv):
    d = dict(state_dict)
    # merge qkv if necessary
    if separate_qkv:
        components = collections.defaultdict(dict)
        for k in list(d.keys()):
            for pattern in ["q", "k", "v"]:
                if f".attn.{pattern}." in k:
                    group = k.rsplit(".", 2)[0]
                    components[group][k] = d.pop(k)
                    break
        for group in components.keys():
            for typ in ["weight", "bias"]:
                l = []
                for pattern in ["q", "k", "v"]:
                    l.append(components[group].pop(f"{group}.{pattern}.{typ}"))
                d[f"{group}.qkv.{typ}"] = torch.cat(l, dim=0)
    # remapping keys
    mapping = collections.OrderedDict(
        [
            ("patch_embed.patch_model.weight", "conv_proj.weight"),
            ("patch_embed.patch_model.bias", "conv_proj.bias"),
            ("cls_positional_encoding.cls_token", "pos_encoding.class_token"),
            ("cls_positional_encoding.pos_embed_spatial", "pos_encoding.spatial_pos"),
            ("cls_positional_encoding.pos_embed_temporal", "pos_encoding.temporal_pos"),
            ("cls_positional_encoding.pos_embed_class", "pos_encoding.class_pos"),
            ("attn.proj.weight", "attn.project.0.weight"),
            ("attn.proj.bias", "attn.project.0.bias"),
            ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
            ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
            ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
            ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
            ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
            ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
            ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
            ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
            ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
            ("mlp.fc1.weight", "mlp.0.weight"),
            ("mlp.fc1.bias", "mlp.0.bias"),
            ("mlp.fc2.weight", "mlp.3.weight"),
            ("mlp.fc2.bias", "mlp.3.bias"),
            ("norm_embed.weight", "norm.weight"),
            ("norm_embed.bias", "norm.bias"),
            ("head.proj.weight", "head.1.weight"),
            ("head.proj.bias", "head.1.bias"),
            ("proj.weight", "project.weight"),
            ("proj.bias", "project.bias"),
        ]
    )
    for k in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in k:
                new_key = k.replace(pattern, replacement)
                d[new_key] = d.pop(k)
                break
    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
    d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
    d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]
    # removing unnecessary keys
    for k in list(d.keys()):
        if "attn._attention_pool_" in k:
            del d[k]
    return d
def compare_models(ptv_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)
    ptv_m = ptv_model_fn().eval()
    exp_result = ptv_m(x).sum()
    separate_qkv = isinstance(ptv_m.blocks[0].attn.qkv, torch.nn.Identity)
    d = ptv_m.state_dict()
    d = ptv_to_tv_weights(d, separate_qkv)
    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()
    torch.testing.assert_close(result, exp_result, rtol=0, atol=1e-6)
    print("OK")
compare_models(PyTorchVideo.mvit_v1_b, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
compare_models(PyTorchVideo.mvit_base_16x4, TorchVision.mvit_v1_b, (1, 3, 16, 224, 224))
# compare_models(PyTorchVideo.mvit_v2_b, TorchVision.mvit_v2_b, (1, 3, 32, 224, 224))
```

And to verify speed:

```python
import time
def benchmark(model_fn, input_shape, device, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)
    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        if i > n * warmup:
            s.append(t)
    print(model_fn.__name__, torch.tensor(s).median())
device = "cuda"
batch_size = 4
n = 100
print(f"device={device}, batch_size={batch_size}, n={n}")
for name, backend in [("TorchVision", TorchVision), ("PyTorchVideo", PyTorchVideo)]:
    print(name)
    benchmark(backend.mvit_v1_b, (batch_size, 3, 16, 224, 224), device, n=n)
    benchmark(backend.mvit_v2_b, (batch_size, 3, 32, 224, 224), device, n=n)
```

Output (less is better):
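One caveat when timing on GPU: CUDA kernels launch asynchronously, so the plain wall-clock loop above can under-report the forward-pass latency. Below is a hedged sketch of a variant that synchronizes around each forward pass; the `benchmark_sync` helper is illustrative and not part of the original comment.

```python
import time

import torch


def benchmark_sync(model_fn, input_shape, device="cuda", n=100, warmup=10):
    # Same idea as the benchmark above, but with explicit CUDA synchronization
    # before and after each forward pass, and no_grad to avoid autograd overhead.
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape, device=device)
    times = []
    with torch.no_grad():
        for i in range(n):
            if device == "cuda":
                torch.cuda.synchronize()
            start = time.time()
            m(x)
            if device == "cuda":
                torch.cuda.synchronize()
            if i >= warmup:
                times.append(time.time() - start)
    print(model_fn.__name__, torch.tensor(times).median())
```

It can be dropped into the same per-backend loop as above, e.g. `benchmark_sync(TorchVision.mvit_v1_b, (batch_size, 3, 16, 224, 224))`.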
Confirm that the weights are ported correctly from PyTorch Video:

```python
import torch
from torchvision.models.video import mvit_v1_b
from pytorchvideo.models.hub.vision_transformers import mvit_base_16x4
mvt = mvit_base_16x4().eval()
# Load weights from: https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/MVIT_B_16x4.pyth
# strict=False due to https://github.com/facebookresearch/pytorchvideo/issues/206
mvt.load_state_dict(torch.load("MVIT_B_16x4.pyth")["model_state"], strict=False)
tv = mvit_v1_b(weights="DEFAULT").eval()
batch = torch.randn((4, 3, 16, 224, 224))
y1 = mvt(batch)
y2 = tv(batch)
assert torch.equal(y1, y2)
print("OK") |
datumbox added a commit that referenced this pull request on Jun 24, 2022:
* Adding MViT v2 architecture (#6105)
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings.
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` in the supported meta-data.
* Switch Tuple[int, int, int] with List[int] to support easier the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage.
* Update todos
* Performance optimizations.
* Add support to MViT v1 (#6179)
* Switch implementation to v1 variant.
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the network are configured.
* Temporarily removing v2
* Adding weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update references script.
* Fix tests.
* Fixing frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.
* Fix documentation
* Remove unnecessary expected file.
* Skip big model test
* Rewrite the configuration logic to reduce LOC.
* Fix mypy
facebook-github-bot pushed a commit that referenced this pull request on Jun 27, 2022:
Summary:

* Adding MViT v2 architecture (#6105)
* Adding mvitv2 architecture
* Fixing memory issues on tests and minor refactorings.
* Adding input validation
* Adding docs and minor refactoring
* Add `min_temporal_size` in the supported meta-data.
* Switch Tuple[int, int, int] with List[int] to support easier the 2D case
* Adding more docs and references
* Change naming conventions of classes to follow the same pattern as MobileNetV3
* Fix test breakage.
* Update todos
* Performance optimizations.
* Add support to MViT v1 (#6179)
* Switch implementation to v1 variant.
* Fix docs
* Adding back a v2 pseudovariant
* Changing the way the network are configured.
* Temporarily removing v2
* Adding weights.
* Expand _squeeze/_unsqueeze to support arbitrary dims.
* Update references script.
* Fix tests.
* Fixing frames and preprocessing.
* Fix std/mean values in transforms.
* Add permanent Dropout and update the weights.
* Update accuracies.
* Fix documentation
* Remove unnecessary expected file.
* Skip big model test
* Rewrite the configuration logic to reduce LOC.
* Fix mypy

Reviewed By: NicolasHug

Differential Revision: D37450352

fbshipit-source-id: 5c0bf1065351d8dd612012902117fd866db02899
Reference test:

The above check shows reduced accuracy compared to the expected one. There seems to be a regression in our reference script, the Kinetics dataset, or the video decoder. The accuracy of the model using TorchVision's implementation was verified using the SlowFast reference scripts:
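For context, the check in question is a plain top-1/top-5 accuracy evaluation over Kinetics-400 validation clips. A rough sketch of such an evaluation loop is below; the `val_loader`, assumed to yield preprocessed `(B, C, T, H, W)` clips and integer labels, is hypothetical and not part of this PR.

```python
import torch


@torch.no_grad()
def evaluate(model, val_loader, device="cuda"):
    # Accumulate top-1 and top-5 hits over the validation clips.
    model = model.to(device).eval()
    top1 = top5 = total = 0
    for clips, targets in val_loader:
        clips, targets = clips.to(device), targets.to(device)
        logits = model(clips)
        _, pred = logits.topk(5, dim=1)          # (B, 5) highest-scoring classes
        correct = pred.eq(targets.view(-1, 1))   # broadcast targets against the top-5
        top1 += correct[:, 0].sum().item()
        top5 += correct.any(dim=1).sum().item()
        total += targets.numel()
    return 100.0 * top1 / total, 100.0 * top5 / total
```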