diff --git a/docs/source/models.rst b/docs/source/models.rst index 0edb394f08f..769c2d2721b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -465,7 +465,7 @@ pre-trained weights: .. toctree:: :maxdepth: 1 - models/video_mvitv2 + models/video_mvit models/video_resnet | diff --git a/docs/source/models/video_mvitv2.rst b/docs/source/models/video_mvit.rst similarity index 72% rename from docs/source/models/video_mvitv2.rst rename to docs/source/models/video_mvit.rst index e9ad556ded7..713ca769f0b 100644 --- a/docs/source/models/video_mvitv2.rst +++ b/docs/source/models/video_mvit.rst @@ -12,17 +12,15 @@ The MViT V2 model is based on the Model builders -------------- -The following model builders can be used to instantiate a MViTV2 model, with or +The following model builders can be used to instantiate a MViT model, with or without pre-trained weights. All the model builders internally rely on the -``torchvision.models.video.MViTV2`` base class. Please refer to the `source +``torchvision.models.video.MViT`` base class. Please refer to the `source code -`_ for +`_ for more details about this class. .. autosummary:: :toctree: generated/ :template: function.rst - mvit_v2_t - mvit_v2_s - mvit_v2_b + mvit_v1_b diff --git a/references/video_classification/train.py b/references/video_classification/train.py index e1df08cbe4a..a746470be9b 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -152,7 +152,7 @@ def main(args): split="train", step_between_clips=1, transform=transform_train, - frame_rate=15, + frame_rate=args.frame_rate, extensions=( "avi", "mp4", @@ -189,7 +189,7 @@ def main(args): split="val", step_between_clips=1, transform=transform_test, - frame_rate=15, + frame_rate=args.frame_rate, extensions=( "avi", "mp4", @@ -324,6 +324,7 @@ def parse_args(): parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip") + parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate") parser.add_argument( "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider" ) diff --git a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v1_b_expect.pkl similarity index 54% rename from test/expect/ModelTester.test_mvit_v2_s_expect.pkl rename to test/expect/ModelTester.test_mvit_v1_b_expect.pkl index 48c342eec2d..cc6592c97bd 100644 Binary files a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl and b/test/expect/ModelTester.test_mvit_v1_b_expect.pkl differ diff --git a/test/expect/ModelTester.test_mvit_v2_t_expect.pkl b/test/expect/ModelTester.test_mvit_v2_t_expect.pkl deleted file mode 100644 index 384fe05b50c..00000000000 Binary files a/test/expect/ModelTester.test_mvit_v2_t_expect.pkl and /dev/null differ diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 05efd39099e..7961d173e3f 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -181,7 +181,7 @@ def test_transforms_jit(model_fn): "input_shape": (1, 3, 520, 520), }, "video": { - "input_shape": (1, 4, 3, 112, 112), + "input_shape": (1, 3, 4, 112, 112), }, "optical_flow": { "input_shape": (1, 3, 128, 128), @@ -195,6 +195,8 @@ def test_transforms_jit(model_fn): if module_name == "optical_flow": 
args = (x, x) else: + if module_name == "video": + x = x.permute(0, 2, 1, 3, 4) args = (x,) problematic_weights = [] diff --git a/test/test_models.py b/test/test_models.py index c8d81216aa0..63e4870801f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -309,15 +309,9 @@ def _check_input_backprop(model, inputs): "image_size": 56, "input_shape": (1, 3, 56, 56), }, - "mvit_v2_t": { + "mvit_v1_b": { "input_shape": (1, 3, 16, 224, 224), }, - "mvit_v2_s": { - "input_shape": (1, 3, 16, 224, 224), - }, - "mvit_v2_b": { - "input_shape": (1, 3, 32, 224, 224), - }, } # speeding up slow models: slow_models = [ @@ -347,7 +341,6 @@ def _check_input_backprop(model, inputs): skipped_big_models = { "vit_h_14", "regnet_y_128gf", - "mvit_v2_b", } # The following contains configuration and expected values to be used tests that are model specific diff --git a/torchvision/models/video/__init__.py b/torchvision/models/video/__init__.py index 393a26ccbbe..8990f64a1dc 100644 --- a/torchvision/models/video/__init__.py +++ b/torchvision/models/video/__init__.py @@ -1,2 +1,2 @@ -from .mvitv2 import * +from .mvit import * from .resnet import * diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvit.py similarity index 59% rename from torchvision/models/video/mvitv2.py rename to torchvision/models/video/mvit.py index 206c745f466..a3b543cf8ae 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvit.py @@ -1,4 +1,5 @@ import math +from dataclasses import dataclass from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple @@ -7,26 +8,32 @@ import torch.nn as nn from ...ops import StochasticDepth, MLP +from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import WeightsEnum +from .._api import WeightsEnum, Weights +from .._meta import _KINETICS400_CATEGORIES from .._utils import _ovewrite_named_param __all__ = [ - "MViTV2", - "MViT_V2_T_Weights", - "MViT_V2_S_Weights", - "MViT_V2_B_Weights", - "mvit_v2_t", - "mvit_v2_s", - "mvit_v2_b", + "MViT", + "MViT_V1_B_Weights", + "mvit_v1_b", ] -# TODO: check if we should implement relative pos embedding (Section 4.1 in the paper). 
Ref: -# https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py#L45 -# TODO: add weights -# TODO: test on references +# TODO: Consider handle 2d input if Temporal is 1 + + +@dataclass +class MSBlockConfig: + num_heads: int + input_channels: int + output_channels: int + kernel_q: List[int] + kernel_kv: List[int] + stride_q: List[int] + stride_kv: List[int] def _prod(s: Sequence[int]) -> int: @@ -36,18 +43,18 @@ def _prod(s: Sequence[int]) -> int: return product -def _unsqueeze(x: torch.Tensor) -> Tuple[torch.Tensor, int]: +def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]: tensor_dim = x.dim() - if tensor_dim == 3: - x = x.unsqueeze(1) - elif tensor_dim != 4: + if tensor_dim == target_dim - 1: + x = x.unsqueeze(expand_dim) + elif tensor_dim != target_dim: raise ValueError(f"Unsupported input dimension {x.shape}") return x, tensor_dim -def _squeeze(x: torch.Tensor, tensor_dim: int) -> torch.Tensor: - if tensor_dim == 3: - x = x.squeeze(1) +def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor: + if tensor_dim == target_dim - 1: + x = x.squeeze(expand_dim) return x @@ -74,7 +81,7 @@ def __init__( self.norm_before_pool = norm_before_pool def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: - x, tensor_dim = _unsqueeze(x) + x, tensor_dim = _unsqueeze(x, 4, 1) # Separate the class token and reshape the input class_token, x = torch.tensor_split(x, indices=(1,), dim=2) @@ -95,7 +102,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten if not self.norm_before_pool and self.norm_act is not None: x = self.norm_act(x) - x = _squeeze(x, tensor_dim) + x = _squeeze(x, 4, 1, tensor_dim) return x, (T, H, W) @@ -108,6 +115,7 @@ def __init__( kernel_kv: List[int], stride_q: List[int], stride_kv: List[int], + residual_pool: bool, dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: @@ -116,6 +124,7 @@ def __init__( self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) + self.residual_pool = residual_pool self.qkv = nn.Linear(embed_dim, 3 * embed_dim) layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)] @@ -182,7 +191,9 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).add_(q) + x = torch.matmul(attn, v) + if self.residual_pool: + x.add_(q) x = x.transpose(1, 2).reshape(B, -1, C) x = self.project(x) @@ -192,13 +203,8 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class MultiscaleBlock(nn.Module): def __init__( self, - input_channels: int, - output_channels: int, - num_heads: int, - kernel_q: List[int], - kernel_kv: List[int], - stride_q: List[int], - stride_kv: List[int], + cnf: MSBlockConfig, + residual_pool: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, @@ -206,30 +212,31 @@ def __init__( super().__init__() self.pool_skip: Optional[nn.Module] = None - if _prod(stride_q) > 1: - kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + if _prod(cnf.stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q] padding_skip = [int(k // 2) for k in kernel_skip] self.pool_skip = Pool( - nn.MaxPool3d(kernel_skip, stride=stride_q, 
padding=padding_skip), None # type: ignore[arg-type] + nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] ) - self.norm1 = norm_layer(input_channels) - self.norm2 = norm_layer(input_channels) + self.norm1 = norm_layer(cnf.input_channels) + self.norm2 = norm_layer(cnf.input_channels) self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) self.attn = MultiscaleAttention( - input_channels, - num_heads, - kernel_q=kernel_q, - kernel_kv=kernel_kv, - stride_q=stride_q, - stride_kv=stride_kv, + cnf.input_channels, + cnf.num_heads, + kernel_q=cnf.kernel_q, + kernel_kv=cnf.kernel_kv, + stride_q=cnf.stride_q, + stride_kv=cnf.stride_kv, + residual_pool=residual_pool, dropout=dropout, norm_layer=norm_layer, ) self.mlp = MLP( - input_channels, - [4 * input_channels, output_channels], + cnf.input_channels, + [4 * cnf.input_channels, cnf.output_channels], activation_layer=nn.GELU, dropout=dropout, inplace=None, @@ -238,8 +245,8 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.project: Optional[nn.Module] = None - if input_channels != output_channels: - self.project = nn.Linear(input_channels, output_channels) + if cnf.input_channels != cnf.output_channels: + self.project = nn.Linear(cnf.input_channels, cnf.output_channels) def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] @@ -274,18 +281,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat((class_token, x), dim=1).add_(pos_embedding) -class MViTV2(nn.Module): +class MViT(nn.Module): def __init__( self, spatial_size: Tuple[int, int], temporal_size: int, - embed_channels: List[int], - blocks: List[int], - heads: List[int], - pool_kv_stride: List[int], - pool_q_stride: List[int], - pool_kvq_kernel: List[int], - dropout: float = 0.0, + block_setting: Sequence[MSBlockConfig], + residual_pool: bool, + dropout: float = 0.5, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 400, @@ -293,17 +296,13 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: """ - MViT V2 main class. + MViT main class. Args: spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. temporal_size (int): The temporal size ``T`` of the input. - embed_channels (list of ints): A list with the embedding dimensions of each block group. - blocks (list of ints): A list with the number of blocks of each block group. - heads (list of ints): A list with the number of heads of each block group. - pool_kv_stride (list of ints): The initiale pooling stride of the first block. - pool_q_stride (list of ints): The pooling stride which reduces q in each block group. - pool_kvq_kernel (list of ints): The pooling kernel for the attention. + block_setting (sequence of MSBlockConfig): The Network structure. + residual_pool (bool): If True, use MViTv2 pooling residual connection. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -317,9 +316,9 @@ def __init__( # We remove any experimental configuration that didn't make it to the final variants of the models. To represent # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. 
_log_api_usage_once(self) - num_blocks = len(blocks) - if num_blocks != len(embed_channels) or num_blocks != len(heads): - raise ValueError("The parameters 'embed_channels', 'blocks' and 'heads' must have equal length.") + total_stage_blocks = len(block_setting) + if total_stage_blocks == 0: + raise ValueError("The configuration parameter can't be empty.") if block is None: block = MultiscaleBlock @@ -330,7 +329,7 @@ def __init__( # Patch Embedding module self.conv_proj = nn.Conv3d( in_channels=3, - out_channels=embed_channels[0], + out_channels=block_setting[0].input_channels, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3), @@ -338,58 +337,33 @@ def __init__( # Spatio-Temporal Class Positional Encoding self.pos_encoding = PositionalEncoding( - embed_size=embed_channels[0], + embed_size=block_setting[0].input_channels, spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]), temporal_size=temporal_size // self.conv_proj.stride[0], ) # Encoder module self.blocks = nn.ModuleList() - stage_block_id = 0 - pool_countdown = blocks[0] - input_channels = output_channels = embed_channels[0] - stride_kv = pool_kv_stride - total_stage_blocks = sum(blocks) - for i, num_subblocks in enumerate(blocks): - for j in range(num_subblocks): - next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i - output_channels = embed_channels[next_block_index] - - stride_q = [1, 1, 1] - if pool_countdown == 0: - stride_q = pool_q_stride - pool_countdown = blocks[next_block_index] - - stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)] - - # adjust stochastic depth probability based on the depth of the stage block - sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - - self.blocks.append( - block( - input_channels=input_channels, - output_channels=output_channels, - num_heads=heads[i], - kernel_q=pool_kvq_kernel, - kernel_kv=pool_kvq_kernel, - stride_q=stride_q, - stride_kv=stride_kv, - dropout=attention_dropout, - stochastic_depth_prob=sd_prob, - norm_layer=norm_layer, - ) + for stage_block_id, cnf in enumerate(block_setting): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + cnf=cnf, + residual_pool=residual_pool, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, ) - input_channels = output_channels - stage_block_id += 1 - pool_countdown -= 1 - self.norm = norm_layer(output_channels) + ) + self.norm = norm_layer(block_setting[-1].output_channels) # Classifier module - layers: List[nn.Module] = [] - if dropout > 0.0: - layers.append(nn.Dropout(dropout, inplace=True)) - layers.append(nn.Linear(output_channels, num_classes)) - self.head = nn.Sequential(*layers) + self.head = nn.Sequential( + nn.Dropout(dropout, inplace=True), + nn.Linear(block_setting[-1].output_channels, num_classes), + ) for m in self.modules(): if isinstance(m, nn.Linear): @@ -426,32 +400,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _mvitv2( - embed_channels: List[int], - blocks: List[int], - heads: List[int], +def _mvit( + block_setting: List[MSBlockConfig], stochastic_depth_prob: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -) -> MViTV2: +) -> MViT: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == 
weights.meta["min_size"][1] - _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"][0]) + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"]) _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) spatial_size = kwargs.pop("spatial_size", (224, 224)) temporal_size = kwargs.pop("temporal_size", 16) - model = MViTV2( + model = MViT( spatial_size=spatial_size, temporal_size=temporal_size, - embed_channels=embed_channels, - blocks=blocks, - heads=heads, - pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]), - pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]), - pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]), + block_setting=block_setting, + residual_pool=kwargs.pop("residual_pool", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -462,126 +430,210 @@ def _mvitv2( return model -class MViT_V2_T_Weights(WeightsEnum): - pass - - -class MViT_V2_S_Weights(WeightsEnum): - pass - - -class MViT_V2_B_Weights(WeightsEnum): - pass - - -def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: - """ - Constructs a tiny MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. - - Args: - weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_T_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.video.MViT_V2_T_Weights - :members: - """ - weights = MViT_V2_T_Weights.verify(weights) - - return _mvitv2( - spatial_size=(224, 224), - temporal_size=16, - embed_channels=[96, 192, 384, 768], - blocks=[1, 2, 5, 2], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), - weights=weights, - progress=progress, - **kwargs, +class MViT_V1_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md", + "_docs": """These weights support 16-frame clip inputs and were ported from the paper.""", + "num_params": 36610672, + "_metrics": { + "Kinetics-400": { + "acc@1": 78.47, + "acc@5": 93.65, + } + }, + }, ) + DEFAULT = KINETICS400_V1 -def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: +def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: """ - Constructs a small MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. + Constructs a base MViTV1 architecture from + `Multiscale Vision Transformers `__. 
Args: - weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + :class:`~torchvision.models.video.MViT_V1_B_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` base class. Please refer to the `source code - `_ + `_ for more details about this class. - .. autoclass:: torchvision.models.video.MViT_V2_S_Weights + .. autoclass:: torchvision.models.video.MViT_V1_B_Weights :members: """ - weights = MViT_V2_S_Weights.verify(weights) - - return _mvitv2( + weights = MViT_V1_B_Weights.verify(weights) + + block_setting = [ + MSBlockConfig( + num_heads=1, + input_channels=96, + output_channels=192, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 8, 8], + ), + MSBlockConfig( + num_heads=2, + input_channels=192, + output_channels=192, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 4, 4], + ), + MSBlockConfig( + num_heads=2, + input_channels=192, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 4, 4], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=768, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=8, + input_channels=768, + output_channels=768, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 1, 1], + ), + MSBlockConfig( + num_heads=8, + input_channels=768, + output_channels=768, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], 
+ stride_kv=[1, 1, 1], + ), + ] + + return _mvit( spatial_size=(224, 224), temporal_size=16, - embed_channels=[96, 192, 384, 768], - blocks=[1, 2, 11, 2], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), - weights=weights, - progress=progress, - **kwargs, - ) - - -def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: - """ - Constructs a base MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. - - Args: - weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_B_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.video.MViT_V2_B_Weights - :members: - """ - weights = MViT_V2_B_Weights.verify(weights) - - return _mvitv2( - spatial_size=(224, 224), - temporal_size=32, - embed_channels=[96, 192, 384, 768], - blocks=[2, 3, 16, 3], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3), + block_setting=block_setting, + residual_pool=False, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, **kwargs,
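
For reviewers, a minimal usage sketch of the new `mvit_v1_b` builder and the `MViT_V1_B_Weights.KINETICS400_V1` weights introduced above. The `(1, 3, 16, 224, 224)` clip shape, i.e. `(B, C, T, H, W)`, follows the `test_models.py` config in this diff; the example feeds a random tensor rather than the `VideoClassification` preset, just to illustrate the expected input/output shapes, so treat it as a sketch rather than a reference pipeline.

```python
import torch
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

# Build the base MViT-V1 model with the Kinetics-400 weights added in this PR.
weights = MViT_V1_B_Weights.KINETICS400_V1
model = mvit_v1_b(weights=weights).eval()

# A 16-frame clip in (B, C, T, H, W) layout, matching the test config above.
clip = torch.rand(1, 3, 16, 224, 224)

with torch.inference_mode():
    logits = model(clip)                     # shape (1, 400): Kinetics-400 scores
    label = weights.meta["categories"][logits.argmax().item()]
    print(label)
```

Note that when weights are supplied, `_mvit` overrides `num_classes`, `spatial_size` and `temporal_size` from the weights meta, so passing conflicting values for those kwargs will raise.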