
ViTDet object detection + segmentation implementation #7690

Open · wants to merge 13 commits into base: main
4 changes: 2 additions & 2 deletions references/detection/presets.py
@@ -15,9 +15,9 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
- T.ScaleJitter(target_size=(1024, 1024)),
- T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
+ T.ScaleJitter(target_size=(1024, 1024)),
+ T.FixedSizeCrop(size=(1024, 1024), fill=0),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
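For reference, a minimal usage sketch of the updated "lsj" preset (assumes the working directory is references/detection/; the image and target are made up, and the 1024x1024 output comes from ScaleJitter + FixedSizeCrop):

import torch
from PIL import Image

import presets

preset = presets.DetectionPresetTrain(data_augmentation="lsj")
img = Image.new("RGB", (640, 480))
target = {
    "boxes": torch.tensor([[10.0, 20.0, 200.0, 240.0]]),
    "labels": torch.tensor([1]),
}
img, target = preset(img, target)
print(img.shape)  # expected: torch.Size([3, 1024, 1024]), now padded with 0 instead of the pixel mean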
31 changes: 30 additions & 1 deletion references/detection/train.py
@@ -126,6 +126,12 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
)
parser.add_argument(
"--image-min-size", default=800, type=int, help="resize images so that the smallest side is equal to this"
)
parser.add_argument(
"--image-max-size", default=1333, type=int, help="resize images so that the largest side is less than this"
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
@@ -210,7 +216,11 @@ def main(args):
)

print("Creating model")
- kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
+ kwargs = {
+     "trainable_backbone_layers": args.trainable_backbone_layers,
+     "min_size": args.image_min_size,
+     "max_size": args.image_max_size,
+ }
if args.data_augmentation in ["multiscale", "lsj"]:
kwargs["_skip_resize"] = True
if "rcnn" in args.model:
@@ -246,6 +256,25 @@ def main(args):
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
elif opt_name == "vitdet":
from torchvision.models.vision_transformer import get_default_optimizer_params, get_vit_lr_decay_rate
from functools import partial

optimizer = torch.optim.AdamW(
params=get_default_optimizer_params(
model,
# the model object is passed so that per-parameter groups (layer-wise LR
# factors and weight-decay overrides) can be built from its named parameters.
base_lr=args.lr,
weight_decay_norm=0.0,
# TODO: Adjust num_layers for specific model. Currently this assumes ViT-B.
lr_factor_func=partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7),
overrides={"pos_embedding": {"weight_decay": 0.0}},
),
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=0.1,
)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
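The get_vit_lr_decay_rate and get_default_optimizer_params helpers imported in the "vitdet" branch are expected to come from this PR's changes to torchvision.models.vision_transformer (not shown in this diff). As an illustration only, the layer-wise LR decay idea looks roughly like this for torchvision's ViT parameter names; the real helper may differ in its matching rules:

def illustrative_vit_lr_decay_rate(name: str, lr_decay_rate: float = 0.7, num_layers: int = 12) -> float:
    # Patch/position embeddings get the strongest decay, each later encoder block a
    # larger multiplier, and everything outside the backbone the full base LR.
    layer_id = num_layers + 1
    if "conv_proj" in name or "pos_embedding" in name:
        layer_id = 0
    elif "encoder_layer_" in name:
        # torchvision ViT blocks are named "...encoder.layers.encoder_layer_<i>..."
        layer_id = int(name.split("encoder_layer_")[1].split(".")[0]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)

# e.g. the last block (index 11) trains at 0.7x the base LR, block 0 at 0.7**12 (~0.014x),
# and the RPN/ROI heads at the full base LR.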

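The new --image-min-size/--image-max-size flags map onto the standard min_size/max_size keyword arguments of the detection builders, which configure the model's GeneralizedRCNNTransform; for the LSJ recipe both would typically be 1024, with in-model resizing skipped via the _skip_resize kwarg set above. A quick sketch with an existing builder (weights left as None to avoid downloads):

import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=None,
    min_size=1024,
    max_size=1024,
)
print(model.transform)  # GeneralizedRCNNTransform configured with the given min/max size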
2 changes: 1 addition & 1 deletion references/detection/transforms.py
@@ -406,7 +406,7 @@ def forward(self, img, target=None):
class RandomShortestSize(nn.Module):
def __init__(
self,
- min_size: Union[List[int], Tuple[int], int],
+ min_size: Union[List[int], Tuple[int, ...], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
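The relaxed annotation matches how the presets call this transform with several candidate short-side sizes; a small sketch (run from references/detection/):

import torch
import transforms as T  # references/detection/transforms.py

resize = T.RandomShortestSize(min_size=(640, 672, 704, 736, 768, 800), max_size=1333)
img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
out, _ = resize(img, None)  # shortest side becomes one of the candidates, capped by max_size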
100 changes: 98 additions & 2 deletions torchvision/models/detection/backbone_utils.py
@@ -3,9 +3,14 @@

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
- from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+ from torchvision.ops.feature_pyramid_network import (
+     ExtraFPNBlock,
+     FeaturePyramidNetwork,
+     LastLevelMaxPool,
+     SimpleFeaturePyramidNetwork,
+ )

- from .. import mobilenet, resnet
+ from .. import mobilenet, resnet, vision_transformer
from .._api import _get_enum_from_fn, WeightsEnum
from .._utils import handle_legacy_interface, IntermediateLayerGetter

@@ -59,6 +64,46 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
return x


class BackboneWithSimpleFPN(nn.Module):
"""
Adds a Simple FPN on top of a model.
Args:
backbone (nn.Module)
in_channels (int): number of channels of the feature map provided by the backbone.
out_channels (int): number of channels in the SFPN.
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
be performed. It is expected to take the fpn features, the original
features and the names of features as input, and returns
a new list of feature maps and their corresponding names. By
default a ``LastLevelMaxPool`` is used.
"""

def __init__(
self,
backbone: nn.Module,
in_channels: int,
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super().__init__()

if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

self.backbone = backbone
self.fpn = SimpleFeaturePyramidNetwork(
in_channels=in_channels,
out_channels=out_channels,
extra_blocks=extra_blocks,
)
self.out_channels = out_channels

def forward(self, x: Tensor) -> Dict[str, Tensor]:
x = self.backbone(x)
x = self.fpn(x)
return x


@handle_legacy_interface(
weights=(
"pretrained",
@@ -240,3 +285,54 @@ def _mobilenet_extractor(
)
m.out_channels = out_channels # type: ignore[assignment]
return m


def vit_sfpn_backbone(
*,
backbone_name: str,
weights: Optional[WeightsEnum],
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithSimpleFPN:
"""
Constructs a specified ViT backbone with SFPN on top.

Examples::

>>> from torchvision.models.detection.backbone_utils import vit_sfpn_backbone
>>> backbone = vit_sfpn_backbone('vit_b_16', weights=ViT_B_16_Weights.DEFAULT)
>>> # get some dummy image
>>> x = torch.rand(1,3,64,64)
>>> # compute the output
>>> output = backbone(x)
>>> print([(k, v.shape) for k, v in output.items()])
>>> # returns
>>> # TODO

Args:
backbone_name (string): ViT architecture. Possible values are 'vit_b_16', 'vit_b_32', 'vit_l_16',
'vit_l_32', 'vit_h_14'
weights (WeightsEnum, optional): The pretrained weights for the model
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
be performed. It is expected to take the sfpn features, the original
features and the names of the original features as input, and returns
a new list of feature maps and their corresponding names. By
default a ``LastLevelMaxPool`` is used.
"""
backbone = vision_transformer.__dict__[backbone_name](weights=weights, include_head=False)
return _vit_sfpn_extractor(backbone, extra_blocks)


def _vit_sfpn_extractor(
backbone: vision_transformer.VisionTransformer,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithSimpleFPN:
# TODO: set trainable layers?

if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

in_channels = backbone.hidden_dim
out_channels = 256
return BackboneWithSimpleFPN(
backbone, in_channels, out_channels, extra_blocks=extra_blocks
)
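A hedged usage sketch for the new helper (weights left as None to avoid a download; the exact feature-map keys and strides come from the new SimpleFeaturePyramidNetwork and are not assumed here):

import torch
from torchvision.models.detection.backbone_utils import vit_sfpn_backbone

backbone = vit_sfpn_backbone(backbone_name="vit_b_16", weights=None)
x = torch.rand(1, 3, 224, 224)
features = backbone(x)  # Dict[str, Tensor]; every level has backbone.out_channels == 256 channels
for name, feat in features.items():
    print(name, feat.shape)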
97 changes: 96 additions & 1 deletion torchvision/models/detection/mask_rcnn.py
@@ -10,8 +10,9 @@
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
+ from ..vision_transformer import ViT_B_16_Weights, vit_b_16
from ._utils import overwrite_eps
- from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+ from .backbone_utils import _resnet_fpn_extractor, _vit_sfpn_extractor, _validate_trainable_layers
from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead


@@ -585,3 +586,97 @@ def maskrcnn_resnet50_fpn_v2(
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

return model


@register_model()
def maskrcnn_vit_b_16_sfpn(
*,
weights: Optional[WeightsEnum] = None, # TODO: Change this to MaskRCNN_ViT_B_16_SFPN_V1_Weights
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ViT_B_16_Weights] = ViT_B_16_Weights.IMAGENET1K_V1,
**kwargs: Any,
) -> MaskRCNN:
"""Mask R-CNN model with a ViT-B/16 backbone from the `ViT <https://arxiv.org/abs/2010.11929>`_ paper.

.. betastatus:: detection module

The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes.

The behavior of the model changes depending on whether it is in training or evaluation mode.

During training, the model expects both the input tensors and the targets (list of dictionaries),
containing:

- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (``Int64Tensor[N]``): the class label for each ground-truth box
- masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance

The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
losses for both the RPN and the R-CNN, and the mask loss.

During inference, the model requires only the input tensors, and returns the post-processed
predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
follows, where ``N`` is the number of detected instances:

- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
- labels (``Int64Tensor[N]``): the predicted labels for each instance
- scores (``Tensor[N]``): the scores of each instance
- masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
obtain the final segmentation masks, the soft masks can be thresholded, generally
with a value of 0.5 (``mask >= 0.5``)

For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.

Mask R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

Example::

>>> model = torchvision.models.detection.maskrcnn_vit_b_16_sfpn(weights=MaskRCNN_ViT_B_16_SFPN_Weights.DEFAULT)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
>>>
>>> # optionally, if you want to export the model to ONNX:
>>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)

Args:
weights (:class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The
pretrained weights for the backbone.
**kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
for more details about this class.

.. autoclass:: torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights
:members:
"""
# TODO: Add ViTDet weights.
# weights = MaskRCNN_ViT_B_16_SFPN_V1_Weights.verify(weights)
weights_backbone = ViT_B_16_Weights.verify(weights_backbone)

if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False, image_size=1024)
backbone = _vit_sfpn_extractor(backbone)
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))

return model
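Since no detection weights exist yet for this builder, a hedged fine-tuning sketch (assumes the builder is exported like the other Mask R-CNN builders; num_classes, image size, and the dummy target are illustrative, and min_size/max_size mirror the 1024x1024 LSJ setup from the reference training changes above):

import torch
from torchvision.models.detection import maskrcnn_vit_b_16_sfpn

model = maskrcnn_vit_b_16_sfpn(
    weights=None, weights_backbone=None, num_classes=3, min_size=1024, max_size=1024
)
model.train()
images = [torch.rand(3, 1024, 1024)]
targets = [
    {
        "boxes": torch.tensor([[100.0, 100.0, 400.0, 400.0]]),
        "labels": torch.tensor([1]),
        "masks": torch.zeros(1, 1024, 1024, dtype=torch.uint8),
    }
]
loss_dict = model(images, targets)  # RPN + ROI box/classification + mask losses
print(sorted(loss_dict.keys()))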