diff --git a/references/detection/presets.py b/references/detection/presets.py index 779f3f218ca..80b5b14cd19 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -15,9 +15,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104 elif data_augmentation == "lsj": self.transforms = T.Compose( [ - T.ScaleJitter(target_size=(1024, 1024)), - T.FixedSizeCrop(size=(1024, 1024), fill=mean), T.RandomHorizontalFlip(p=hflip_prob), + T.ScaleJitter(target_size=(1024, 1024)), + T.FixedSizeCrop(size=(1024, 1024), fill=0), T.PILToTensor(), T.ConvertImageDtype(torch.float), ] diff --git a/references/detection/train.py b/references/detection/train.py index dea483c5f75..0d839282f79 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -126,6 +126,12 @@ def get_args_parser(add_help=True): parser.add_argument( "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)" ) + parser.add_argument( + "--image-min-size", default=800, type=int, help="resize images so that the smallest side is equal to this" + ) + parser.add_argument( + "--image-max-size", default=1333, type=int, help="resize images so that the largest side is less than this" + ) parser.add_argument( "--sync-bn", dest="sync_bn", @@ -210,7 +216,11 @@ def main(args): ) print("Creating model") - kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} + kwargs = { + "trainable_backbone_layers": args.trainable_backbone_layers, + "min_size": args.image_min_size, + "max_size": args.image_max_size, + } if args.data_augmentation in ["multiscale", "lsj"]: kwargs["_skip_resize"] = True if "rcnn" in args.model: @@ -246,6 +256,25 @@ def main(args): ) elif opt_name == "adamw": optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) + elif opt_name == "vitdet": + from torchvision.models.vision_transformer import get_default_optimizer_params, get_vit_lr_decay_rate + from functools import partial + + optimizer = torch.optim.AdamW( + params=get_default_optimizer_params( + model, + # params.model is meant to be set to the model object, before instantiating + # the optimizer. + base_lr=args.lr, + weight_decay_norm=0.0, + # TODO: Adjust num_layers for specific model. Currently this assumes ViT-B. + lr_factor_func=partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7), + overrides={"pos_embedding": {"weight_decay": 0.0}}, + ), + lr=args.lr, + betas=(0.9, 0.999), + weight_decay=0.1, + ) else: raise RuntimeError(f"Invalid optimizer {args.opt}. 
Only SGD and AdamW are supported.") diff --git a/references/detection/transforms.py b/references/detection/transforms.py index d26bf6eac85..42e4fe4512a 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -406,7 +406,7 @@ def forward(self, img, target=None): class RandomShortestSize(nn.Module): def __init__( self, - min_size: Union[List[int], Tuple[int], int], + min_size: Union[List[int], Tuple[int, ...], int], max_size: int, interpolation: InterpolationMode = InterpolationMode.BILINEAR, ): diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 87ae8627fbe..e1ed1d61dee 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -3,9 +3,14 @@ from torch import nn, Tensor from torchvision.ops import misc as misc_nn_ops -from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool +from torchvision.ops.feature_pyramid_network import ( + ExtraFPNBlock, + FeaturePyramidNetwork, + LastLevelMaxPool, + SimpleFeaturePyramidNetwork, +) -from .. import mobilenet, resnet +from .. import mobilenet, resnet, vision_transformer from .._api import _get_enum_from_fn, WeightsEnum from .._utils import handle_legacy_interface, IntermediateLayerGetter @@ -59,6 +64,46 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return x +class BackboneWithSimpleFPN(nn.Module): + """ + Adds a Simple FPN on top of a model. + Args: + backbone (nn.Module) + in_channels_list (int): number of channels from the provided feature map. + out_channels (int): number of channels in the SFPN. + extra_blocks (ExtraFPNBlock or None): if provided, extra operations will + be performed. It is expected to take the fpn features, the original + features and the names of features as input, and returns + a new list of feature maps and their corresponding names. By + default a ``LastLevelMaxPool`` is used. + """ + + def __init__( + self, + backbone: nn.Module, + in_channels: int, + out_channels: int, + extra_blocks: Optional[ExtraFPNBlock] = None, + ) -> None: + super().__init__() + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + self.backbone = backbone + self.fpn = SimpleFeaturePyramidNetwork( + in_channels=in_channels, + out_channels=out_channels, + extra_blocks=extra_blocks, + ) + self.out_channels = out_channels + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + x = self.backbone(x) + x = self.fpn(x) + return x + + @handle_legacy_interface( weights=( "pretrained", @@ -240,3 +285,54 @@ def _mobilenet_extractor( ) m.out_channels = out_channels # type: ignore[assignment] return m + + +def vit_sfpn_backbone( + *, + backbone_name: str, + weights: Optional[WeightsEnum], + extra_blocks: Optional[ExtraFPNBlock] = None, +) -> BackboneWithSimpleFPN: + """ + Constructs a specified ViT backbone with SFPN on top. + + Examples:: + + >>> from torchvision.models.detection.backbone_utils import vit_sfpn_backbone + >>> backbone = vit_sfpn_backbone('vit_b_16', weights=ViT_B_16_Weights.DEFAULT) + >>> # get some dummy image + >>> x = torch.rand(1,3,64,64) + >>> # compute the output + >>> output = backbone(x) + >>> print([(k, v.shape) for k, v in output.items()]) + >>> # returns + >>> # TODO + + Args: + backbone_name (string): resnet architecture. 
Possible values are 'vit_b_16', 'vit_b_32', 'vit_l_16', + 'vit_l_32', 'vit_h_14' + weights (WeightsEnum, optional): The pretrained weights for the model + extra_blocks (ExtraFPNBlock or None): if provided, extra operations will + be performed. It is expected to take the sfpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names. By + default a ``LastLevelMaxPool`` is used. + """ + backbone = vision_transformer.__dict__[backbone_name](weights=weights, include_head=False) + return _vit_sfpn_extractor(backbone, extra_blocks) + + +def _vit_sfpn_extractor( + backbone: vision_transformer.VisionTransformer, + extra_blocks: Optional[ExtraFPNBlock] = None, +) -> BackboneWithSimpleFPN: + # TODO: set trainable layers? + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + in_channels = backbone.hidden_dim + out_channels = 256 + return BackboneWithSimpleFPN( + backbone, in_channels, out_channels, extra_blocks=extra_blocks + ) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 695dd4d63ec..8ac62bfeedd 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -10,8 +10,9 @@ from .._meta import _COCO_CATEGORIES from .._utils import _ovewrite_value_param, handle_legacy_interface from ..resnet import resnet50, ResNet50_Weights +from ..vision_transformer import ViT_B_16_Weights, vit_b_16 from ._utils import overwrite_eps -from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .backbone_utils import _resnet_fpn_extractor, _vit_sfpn_extractor, _validate_trainable_layers from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead @@ -585,3 +586,97 @@ def maskrcnn_resnet50_fpn_v2( model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) return model + + +@register_model() +def maskrcnn_vit_b_16_sfpn( + *, + weights: Optional[WeightsEnum] = None, # TODO: Change this to MaskRCNN_ViT_B_16_SFPN_V1_Weights + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ViT_B_16_Weights] = ViT_B_16_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> MaskRCNN: + """Mask R-CNN model with a ViT-B/16 backbone from the `ViT `_ paper. + + .. betastatus:: detection module + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses for both the RPN and the R-CNN, and the mask loss. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. 
The fields of the ``Dict`` are as + follows, where ``N`` is the number of detected instances: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each instance + - scores (``Tensor[N]``): the scores or each instance + - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to + obtain the final segmentation masks, the soft masks can be thresholded, generally + with a value of 0.5 (``mask >= 0.5``) + + For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`. + + Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size. + + Example:: + + >>> model = torchvision.models.detection.maskrcnn_vit_b_16_sfpn(weights=MaskRCNN_ViT_B_16_SFPN_Weights.DEFAULT) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + >>> + >>> # optionally, if you want to export the model to ONNX: + >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11) + + Args: + weights (:class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The + pretrained weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights + :members: + """ + # TODO: Add ViTDet weights. 
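+    # Note (editorial sketch, not enforced by this code): the ViT-B/16 backbone below is built
+    # with ``image_size=1024`` to match the 1024x1024 ``FixedSizeCrop`` of the "lsj" preset in
+    # references/detection; the new ``--image-min-size``/``--image-max-size`` flags in train.py
+    # would then typically both be set to 1024 as well, which is an assumed configuration rather
+    # than something this builder checks.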
+ # weights = MaskRCNN_ViT_B_16_SFPN_V1_Weights.verify(weights) + weights_backbone = ViT_B_16_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False, image_size=1024) + backbone = _vit_sfpn_extractor(backbone) + model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index f2983ef9db0..39f8aac862b 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -1,7 +1,9 @@ +import copy +import itertools import math -from collections import OrderedDict +from collections import OrderedDict, defaultdict from functools import partial -from typing import Any, Callable, Dict, List, NamedTuple, Optional +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple import torch import torch.nn as nn @@ -29,6 +31,381 @@ ] +class LayerNorm(torch.nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + TODO: Remove this? The pytorch version doesn't work because it MUST have channels as last dimension.. + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) + self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def get_vit_lr_decay_rate(name: str, lr_decay_rate: float = 1.0, num_layers: int = 12): + """ + Calculate lr decay rate for different ViT blocks. + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py . + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if "backbone.backbone" in name: + if ".pos_embedding" in name or ".conv_proj" in name: + layer_id = 0 + elif ".encoder_layer" in name and ".residual." not in name: + layer_id = int(name[name.find(".encoder_layer_") + len(".encoder_layer_") :].split(".")[0]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_default_optimizer_params( + model: torch.nn.Module, + base_lr: Optional[float] = None, + weight_decay: Optional[float] = None, + weight_decay_norm: Optional[float] = None, + bias_lr_factor: Optional[float] = 1.0, + weight_decay_bias: Optional[float] = None, + lr_factor_func: Optional[Callable] = None, + overrides: Optional[Dict[str, Dict[str, float]]] = None, +) -> List[Dict[str, Any]]: + """ + Get default param list for optimizer, with support for a few types of + overrides. 
If no overrides needed, this is equivalent to `model.parameters()`. + + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py . + + Args: + base_lr: lr for every group by default. Can be omitted to use the one in optimizer. + weight_decay: weight decay for every group by default. Can be omitted to use the one + in optimizer. + weight_decay_norm: override weight decay for params in normalization layers + bias_lr_factor: multiplier of lr for bias parameters. + weight_decay_bias: override weight decay for bias parameters. + lr_factor_func: function to calculate lr decay rate by mapping the parameter names to + corresponding lr decay rate. Note that setting this option requires + also setting ``base_lr``. + overrides: if not `None`, provides values for optimizer hyperparameters + (LR, weight decay) for module parameters with a given name; e.g. + ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and + weight decay values for all module parameters named `embedding`. + + For common detection models, ``weight_decay_norm`` is the only option + needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings + from Detectron1 that are not found useful. + + Example: + :: + torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), + lr=0.01, weight_decay=1e-4, momentum=0.9) + """ + if overrides is None: + overrides = {} + defaults = {} + if base_lr is not None: + defaults["lr"] = base_lr + if weight_decay is not None: + defaults["weight_decay"] = weight_decay + bias_overrides = {} + if bias_lr_factor is not None and bias_lr_factor != 1.0: + # NOTE: unlike Detectron v1, we now by default make bias hyperparameters + # exactly the same as regular weights. + if base_lr is None: + raise ValueError("bias_lr_factor requires base_lr") + bias_overrides["lr"] = base_lr * bias_lr_factor + if weight_decay_bias is not None: + bias_overrides["weight_decay"] = weight_decay_bias + if len(bias_overrides): + if "bias" in overrides: + raise ValueError("Conflicting overrides for 'bias'") + overrides["bias"] = bias_overrides + if lr_factor_func is not None: + if base_lr is None: + raise ValueError("lr_factor_func requires base_lr") + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + # NaiveSyncBatchNorm inherits from BatchNorm2d + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + LayerNorm, + ) + params: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + + hyperparams = copy.copy(defaults) + if isinstance(module, norm_module_types) and weight_decay_norm is not None: + hyperparams["weight_decay"] = weight_decay_norm + if lr_factor_func is not None: + hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}") + + hyperparams.update(overrides.get(module_param_name, {})) + params.append({"params": [value], **hyperparams}) + return reduce_param_groups(params) + + +def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + """ + # Transform parameter groups into per-parameter structure. 
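+    # For example (with hypothetical parameters p1 and p2),
+    #   [{"params": [p1, p2], "lr": 0.1}, {"params": [p2], "weight_decay": 0.0}]
+    # becomes one group per parameter, with the later entry merged into p2's group:
+    #   [{"params": [p1], "lr": 0.1}, {"params": [p2], "lr": 0.1, "weight_decay": 0.0}]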
+ # Later items in `params` can overwrite parameters set in previous items. + ret = defaultdict(dict) + for item in params: + assert "params" in item + cur_params = {x: y for x, y in item.items() if x != "params" and x != "param_names"} + if "param_names" in item: + for param_name, param in zip(item["param_names"], item["params"]): + ret[param].update({"param_names": [param_name], "params": [param], **cur_params}) + else: + for param in item["params"]: + ret[param].update({"params": [param], **cur_params}) + return list(ret.values()) + + +def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py . + """ + # Reorganize the parameter groups and merge duplicated groups. + # The number of parameter groups needs to be as small as possible in order + # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead + # of using a parameter_group per single parameter, we reorganize the + # parameter groups and merge duplicated groups. This approach speeds + # up multi-tensor optimizer significantly. + params = _expand_param_groups(params) + groups = defaultdict(list) # re-group all parameter groups by their hyperparams + for item in params: + cur_params = tuple((x, y) for x, y in item.items() if x != "params" and x != "param_names") + groups[cur_params].append({"params": item["params"]}) + if "param_names" in item: + groups[cur_params][-1]["param_names"] = item["param_names"] + + ret = [] + for param_keys, param_values in groups.items(): + cur = {kv[0]: kv[1] for kv in param_keys} + cur["params"] = list( + itertools.chain.from_iterable([params["params"] for params in param_values]) + ) + if len(param_values) > 0 and "param_names" in param_values[0]: + cur["param_names"] = list( + itertools.chain.from_iterable([params["param_names"] for params in param_values]) + ) + ret.append(cur) + return ret + + +def window_partition(x: torch.Tensor, window_size: int): + """ Partition into non-overlapping windows with padding if needed. + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py . + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]): + """ Window unpartition into original sequences and removing padding. + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py . + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. 
+ Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py . + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = torch.nn.functional.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int] +): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py . + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + +class ViTAttention(nn.Module): + """ + Multi-head Attention block with relative position embeddings. + Original version from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py . 
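+    A minimal shape check (illustrative only; inputs and outputs use the (B, H, W, C) layout
+    expected by this block)::
+
+        >>> attn = ViTAttention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
+        >>> x = torch.rand(2, 14, 14, 768)
+        >>> attn(x).shape
+        torch.Size([2, 14, 14, 768])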
+ """ + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + if not rel_pos_zero_init: + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + class ConvStemConfig(NamedTuple): out_channels: int kernel_size: int @@ -92,25 +469,48 @@ def __init__( hidden_dim: int, mlp_dim: int, dropout: float, - attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + window_size: int = 0, + input_size: int = 0, ): super().__init__() self.num_heads = num_heads # Attention block self.ln_1 = norm_layer(hidden_dim) - self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.self_attention = ViTAttention( + dim=hidden_dim, + num_heads=num_heads, + qkv_bias=True, + use_rel_pos=True, + rel_pos_zero_init=True, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) self.dropout = nn.Dropout(dropout) + self.window_size = window_size # MLP block self.ln_2 = norm_layer(hidden_dim) self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) def forward(self, input: torch.Tensor): - torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + torch._assert(input.dim() == 4, f"Expected (batch_size, h, w, hidden_dim) got {input.shape}") x = self.ln_1(input) - x, _ = self.self_attention(x, x, x, need_weights=False) + + # Window partition + pad_hw = (0, 0) + H = 0 + W = 0 + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.self_attention(x) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + x = self.dropout(x) x = x + input @@ -124,19 +524,19 @@ class Encoder(nn.Module): def __init__( self, - seq_length: int, num_layers: int, num_heads: int, hidden_dim: int, + patch_size: int, + image_size: int, + 
window_size: int, + window_block_indices: List[int], mlp_dim: int, dropout: float, - attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): super().__init__() - # Note that batch_size is on the first dim because - # we have batch_first=True in nn.MultiAttention() by default - self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) layers: OrderedDict[str, nn.Module] = OrderedDict() for i in range(num_layers): @@ -145,16 +545,15 @@ def __init__( hidden_dim, mlp_dim, dropout, - attention_dropout, norm_layer, + window_size=window_size if i in window_block_indices else 0, + input_size=(image_size // patch_size, image_size // patch_size), ) self.layers = nn.Sequential(layers) - self.ln = norm_layer(hidden_dim) def forward(self, input: torch.Tensor): - torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") - input = input + self.pos_embedding - return self.ln(self.layers(self.dropout(input))) + torch._assert(input.dim() == 4, f"Expected (batch_size, h, w, hidden_dim) got {input.shape}") + return self.layers(self.dropout(input)) class VisionTransformer(nn.Module): @@ -168,12 +567,15 @@ def __init__( num_heads: int, hidden_dim: int, mlp_dim: int, + window_size: int, + window_block_indices: List[int], dropout: float = 0.0, - attention_dropout: float = 0.0, num_classes: int = 1000, + pretrain_image_size: int = 224, representation_size: Optional[int] = None, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, + include_head: bool = True, ): super().__init__() _log_api_usage_once(self) @@ -182,7 +584,6 @@ def __init__( self.patch_size = patch_size self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim - self.attention_dropout = attention_dropout self.dropout = dropout self.num_classes = num_classes self.representation_size = representation_size @@ -214,33 +615,49 @@ def __init__( in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size ) - seq_length = (image_size // patch_size) ** 2 - # Add a class token self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) - seq_length += 1 + + # Note that batch_size is on the first dim because + # we have batch_first=True in nn.MultiAttention() by default + patch_dim = pretrain_image_size // patch_size + self.pos_embedding = nn.Parameter( + torch.empty(1, hidden_dim, patch_dim, patch_dim).normal_(std=0.02) # from BERT + ) self.encoder = Encoder( - seq_length, num_layers, num_heads, hidden_dim, + patch_size, + image_size, + window_size, + window_block_indices, mlp_dim, dropout, - attention_dropout, norm_layer, ) - self.seq_length = seq_length - heads_layers: OrderedDict[str, nn.Module] = OrderedDict() - if representation_size is None: - heads_layers["head"] = nn.Linear(hidden_dim, num_classes) - else: - heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) - heads_layers["act"] = nn.Tanh() - heads_layers["head"] = nn.Linear(representation_size, num_classes) + self.heads = None + if include_head: + heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + if representation_size is None: + heads_layers["head"] = nn.Linear(hidden_dim, num_classes) + else: + heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) + heads_layers["act"] = nn.Tanh() + heads_layers["head"] = nn.Linear(representation_size, num_classes) - self.heads = 
nn.Sequential(heads_layers) + self.heads = nn.Sequential(heads_layers) + + if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): + fan_in = self.heads.pre_logits.in_features + nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.heads.pre_logits.bias) + + if isinstance(self.heads.head, nn.Linear): + nn.init.zeros_(self.heads.head.weight) + nn.init.zeros_(self.heads.head.bias) if isinstance(self.conv_proj, nn.Conv2d): # Init the patchify stem @@ -256,53 +673,50 @@ def __init__( if self.conv_proj.conv_last.bias is not None: nn.init.zeros_(self.conv_proj.conv_last.bias) - if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): - fan_in = self.heads.pre_logits.in_features - nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) - nn.init.zeros_(self.heads.pre_logits.bias) - - if isinstance(self.heads.head, nn.Linear): - nn.init.zeros_(self.heads.head.weight) - nn.init.zeros_(self.heads.head.bias) - - def _process_input(self, x: torch.Tensor) -> torch.Tensor: - n, c, h, w = x.shape - p = self.patch_size - torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!") - torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!") - n_h = h // p - n_w = w // p + def forward(self, x: torch.Tensor): + # Compute patches from the image # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + # Where n_h and n_w are the number of patches in the height and width x = self.conv_proj(x) - # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) - x = x.reshape(n, self.hidden_dim, n_h * n_w) - # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) - # The self attention layer expects inputs in the format (N, S, E) - # where S is the source sequence length, N is the batch size, E is the - # embedding dimension - x = x.permute(0, 2, 1) - - return x + # Add the positional embedding, but resize to match the image first + pos_embedding = torch.nn.functional.interpolate( + self.pos_embedding, + size=(x.shape[2], x.shape[3]), + mode="bicubic", + align_corners=False, + ) + x = x + pos_embedding - def forward(self, x: torch.Tensor): - # Reshape and permute the input tensor - x = self._process_input(x) - n = x.shape[0] + # Flatten the patches + # n, hidden_dim, n_h, n_w = x.shape + # (n, hidden_dim, n_h, n_w) -> (n, n_h, n_w, hidden_dim) + x = x.permute(0, 2, 3, 1) + # # (n, n_h, n_w, hidden_dim) -> (n, (n_h * n_w), hidden_dim) + # x = x.reshape(n, (n_h * n_w), hidden_dim) - # Expand the class token to the full batch - batch_class_token = self.class_token.expand(n, -1, -1) - x = torch.cat([batch_class_token, x], dim=1) + # TODO: Fix the classifier token. 
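+        # At this point x has shape (n, n_h, n_w, hidden_dim); e.g. a 1024x1024 input with
+        # patch_size=16 and hidden_dim=768 (ViT-B/16) yields (n, 64, 64, 768), the 4-D layout
+        # that the windowed ViTAttention blocks in the encoder expect.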
+ # # Add the class token + # batch_class_token = self.class_token.expand(n, -1, -1) + # x = torch.cat([batch_class_token, x], dim=1) + # Encode the patches x = self.encoder(x) - # Classifier "token" as used by standard language architectures - x = x[:, 0] - - x = self.heads(x) - - return x + if self.heads is not None: + # Classifier "token" as used by standard language architectures + x = x[:, 0] + x = self.heads(x) + return x + else: + # # Skip classifier token + # x = x[:, 1:] + # # (n, (n_h * n_w), hidden_dim) -> (n, n_h, n_w, hidden_dim) + # x = x.reshape(n, n_h, n_w, hidden_dim) + # (n, n_h, n_w, hidden_dim) -> (n, hidden_dim, n_h, n_w) + x = x.permute(0, 3, 1, 2) + return x def _vision_transformer( @@ -311,15 +725,18 @@ def _vision_transformer( num_heads: int, hidden_dim: int, mlp_dim: int, + window_size: int, + window_block_indices: List[int], weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - assert weights.meta["min_size"][0] == weights.meta["min_size"][1] - _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0]) + # assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + # _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0]) image_size = kwargs.pop("image_size", 224) + include_head = kwargs.pop("include_head", True) model = VisionTransformer( image_size=image_size, @@ -328,11 +745,30 @@ def _vision_transformer( num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim, + window_size=window_size, + window_block_indices=window_block_indices, + include_head=include_head, **kwargs, ) if weights: - model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + state_dict = weights.get_state_dict(progress=progress, check_hash=True) + + # Remove head if we don't include the head + if not include_head and "heads.head.weight" in state_dict and "heads.head.bias" in state_dict: + del state_dict["heads.head.weight"] + del state_dict["heads.head.bias"] + + # Fix compatibility with legacy state dict + if "encoder.pos_embedding" in state_dict: + pos_embedding = state_dict["encoder.pos_embedding"] + pos_embedding = pos_embedding.permute(0, 2, 1) + pos_embedding = pos_embedding[:, :, 1:] + pos_embedding = pos_embedding.reshape(1, -1, 14, 14) + state_dict["pos_embedding"] = pos_embedding + del state_dict["encoder.pos_embedding"] + + model.load_state_dict(state_dict) return model @@ -644,6 +1080,18 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru num_heads=12, hidden_dim=768, mlp_dim=3072, + window_size=14, + window_block_indices=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], weights=weights, progress=progress, **kwargs, @@ -678,6 +1126,18 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru num_heads=12, hidden_dim=768, mlp_dim=3072, + window_size=14, + window_block_indices=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], weights=weights, progress=progress, **kwargs, @@ -712,6 +1172,9 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru num_heads=16, hidden_dim=1024, mlp_dim=4096, + window_size=14, + # 5, 11, 17, 23 for global attention + window_block_indices=list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)), weights=weights, progress=progress, **kwargs, @@ -746,6 +1209,9 @@ def vit_l_32(*, 
weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru num_heads=16, hidden_dim=1024, mlp_dim=4096, + window_size=14, + # 5, 11, 17, 23 for global attention + window_block_indices=list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)), weights=weights, progress=progress, **kwargs, @@ -780,6 +1246,9 @@ def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = Tru num_heads=16, hidden_dim=1280, mlp_dim=5120, + window_size=14, + # 7, 15, 23, 31 for global attention + window_block_indices=list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)), weights=weights, progress=progress, **kwargs, diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..c58e290fef7 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -16,7 +16,7 @@ from .deform_conv import deform_conv2d, DeformConv2d from .diou_loss import distance_box_iou_loss from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d -from .feature_pyramid_network import FeaturePyramidNetwork +from .feature_pyramid_network import FeaturePyramidNetwork, SimpleFeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation @@ -54,6 +54,7 @@ "PSRoIPool", "MultiScaleRoIAlign", "FeaturePyramidNetwork", + "SimpleFeaturePyramidNetwork", "sigmoid_focal_loss", "stochastic_depth", "StochasticDepth", diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 2e7aef0e2fa..ac14cf3a240 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -1,9 +1,11 @@ from collections import OrderedDict from typing import Callable, Dict, List, Optional, Tuple +import torch import torch.nn.functional as F from torch import nn, Tensor +from ..models.vision_transformer import LayerNorm from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once @@ -204,6 +206,140 @@ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: return out +class SimpleFeaturePyramidNetwork(nn.Module): + """ + Module that adds a Simple FPN from on top of a set of feature maps. This is based on + `"Exploring Plain Vision Transformer Backbones for Object Detection" `_. + + Unlike regular FPN, Simple FPN expects a single feature map, + on which the Simple FPN will be added. + + Args: + in_channels (int): number of channels for the input feature map that + is passed to the module + out_channels (int): number of channels of the Simple FPN representation + extra_blocks (ExtraFPNBlock or None): if provided, extra operations will + be performed. It is expected to take the fpn features, the original + features and the names of the original features as input, and returns + a new list of feature maps and their corresponding names + norm_layer (callable, optional): Module specifying the normalization layer to use. 
Default: LayerNorm + + Examples:: + + >>> m = torchvision.ops.SimpleFeaturePyramidNetwork(10, 5) + >>> # get some dummy data + >>> x = torch.rand(1, 10, 64, 64) + >>> # compute the Simple FPN on top of x + >>> output = m(x) + >>> print([(k, v.shape) for k, v in output.items()]) + >>> # returns + >>> [('feat0', torch.Size([1, 5, 64, 64])), + >>> ('feat2', torch.Size([1, 5, 16, 16])), + >>> ('feat3', torch.Size([1, 5, 8, 8]))] + + """ + + _version = 2 + + def __init__( + self, + in_channels: int, + out_channels: int, + extra_blocks: Optional[ExtraFPNBlock] = None, + norm_layer: Optional[Callable[..., nn.Module]] = LayerNorm, + ): + super().__init__() + _log_api_usage_once(self) + self.blocks = nn.ModuleList() + if in_channels <= 0: + raise ValueError("in_channels <= 0 is currently not supported") + + for block_index in range(4): + layers = [] + + current_in_channels = in_channels + if block_index == 0: + layers.extend([ + nn.ConvTranspose2d( + in_channels, + in_channels // 2, + kernel_size=2, + stride=2, + ), + LayerNorm(in_channels // 2), + nn.GELU(), + nn.ConvTranspose2d( + in_channels // 2, + in_channels // 4, + kernel_size=2, + stride=2, + ), + ]) + current_in_channels = in_channels // 4 + elif block_index == 1: + layers.append( + nn.ConvTranspose2d( + in_channels, + in_channels // 2, + kernel_size=2, + stride=2, + ), + ) + current_in_channels = in_channels // 2 + elif block_index == 2: + # nothing to do for this scale + pass + elif block_index == 3: + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + + layers.extend([ + Conv2dNormActivation( + current_in_channels, + out_channels, + kernel_size=1, + padding=0, + norm_layer=norm_layer, + activation_layer=None + ), + Conv2dNormActivation( + out_channels, + out_channels, + kernel_size=3, + norm_layer=norm_layer, + activation_layer=None + ) + ]) + + self.blocks.append(nn.Sequential(*layers)) + + if extra_blocks is not None: + if not isinstance(extra_blocks, ExtraFPNBlock): + raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}") + self.extra_blocks = extra_blocks + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + """ + Computes the Simple FPN for a feature map. + + Args: + x (Tensor): input feature map. + + Returns: + results (list[Tensor]): feature maps after FPN layers. + They are ordered from highest resolution first. + """ + results = [block(x) for block in self.blocks] + names = [f"{i}" for i in range(len(self.blocks))] + + if self.extra_blocks is not None: + results, names = self.extra_blocks(results, [x], names) + + # make it back an OrderedDict + out = OrderedDict([(k, v) for k, v in zip(names, results)]) + + return out + + class LastLevelMaxPool(ExtraFPNBlock): """ Applies a max_pool2d (not actual max_pool2d, we just subsample) on top of the last feature map