Commit d80c4eb

First rough implementation of ViTDet.

1 parent 48e5ad7
3 files changed: +483 −69 lines

3 files changed

+483
-69
lines changed

torchvision/models/detection/backbone_utils.py

Lines changed: 98 additions & 2 deletions
@@ -3,9 +3,14 @@

 from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops
-from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+from torchvision.ops.feature_pyramid_network import (
+    ExtraFPNBlock,
+    FeaturePyramidNetwork,
+    LastLevelMaxPool,
+    SimpleFeaturePyramidNetwork,
+)

-from .. import mobilenet, resnet
+from .. import mobilenet, resnet, vision_transformer
 from .._api import _get_enum_from_fn, WeightsEnum
 from .._utils import handle_legacy_interface, IntermediateLayerGetter

@@ -59,6 +64,46 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return x


+class BackboneWithSimpleFPN(nn.Module):
+    """
+    Adds a Simple FPN on top of a model.
+
+    Args:
+        backbone (nn.Module): the backbone module, expected to return a single feature map.
+        in_channels (int): number of channels of the provided feature map.
+        out_channels (int): number of channels in the SFPN.
+        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
+            be performed. It is expected to take the fpn features, the original
+            features and the names of the features as input, and return
+            a new list of feature maps and their corresponding names. By
+            default a ``LastLevelMaxPool`` is used.
+    """
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        in_channels: int,
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+    ) -> None:
+        super().__init__()
+
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        self.backbone = backbone
+        self.fpn = SimpleFeaturePyramidNetwork(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            extra_blocks=extra_blocks,
+        )
+        self.out_channels = out_channels
+
+    def forward(self, x: Tensor) -> List[Tensor]:
+        x = self.backbone(x)
+        x = self.fpn(x)
+        return x
+
+
 @handle_legacy_interface(
     weights=(
         "pretrained",
@@ -240,3 +285,54 @@ def _mobilenet_extractor(
     )
     m.out_channels = out_channels  # type: ignore[assignment]
     return m
+
+
+def vit_sfpn_backbone(
+    *,
+    backbone_name: str,
+    weights: Optional[WeightsEnum],
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithSimpleFPN:
+    """
+    Constructs a specified ViT backbone with SFPN on top.
+
+    Examples::
+
+        >>> from torchvision.models.detection.backbone_utils import vit_sfpn_backbone
+        >>> backbone = vit_sfpn_backbone(backbone_name='vit_b_16', weights=ViT_B_16_Weights.DEFAULT)
+        >>> # get some dummy image
+        >>> x = torch.rand(1, 3, 64, 64)
+        >>> # compute the output
+        >>> output = backbone(x)
+        >>> print([(k, v.shape) for k, v in output.items()])
+        >>> # returns
+        >>> #   TODO
+
+    Args:
+        backbone_name (string): ViT architecture. Possible values are 'vit_b_16', 'vit_b_32', 'vit_l_16',
+            'vit_l_32' or 'vit_h_14'.
+        weights (WeightsEnum, optional): The pretrained weights for the model.
+        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
+            be performed. It is expected to take the sfpn features, the original
+            features and the names of the original features as input, and return
+            a new list of feature maps and their corresponding names. By
+            default a ``LastLevelMaxPool`` is used.
+    """
+    backbone = vision_transformer.__dict__[backbone_name](weights=weights, include_head=False)
+    return _vit_sfpn_extractor(backbone, extra_blocks)
+
+
+def _vit_sfpn_extractor(
+    backbone: vision_transformer.VisionTransformer,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithSimpleFPN:
+    # TODO: set trainable layers?
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    in_channels = backbone.hidden_dim
+    out_channels = 256
+    return BackboneWithSimpleFPN(
+        backbone, in_channels, out_channels, extra_blocks=extra_blocks
+    )
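
Since the expected output in the docstring above is still a TODO, here is a hedged smoke test for the new builder. The 256-channel maps follow from the out_channels = 256 set in _vit_sfpn_extractor; the exact output keys ('0'..'3' plus 'pool' from the default LastLevelMaxPool) are assumptions based on torchvision's usual FPN conventions, since SimpleFeaturePyramidNetwork is defined outside this diff.

import torch
from torchvision.models.detection.backbone_utils import vit_sfpn_backbone

# Hypothetical smoke test against this branch (weights=None, so nothing is
# downloaded). vit_b_16 is a patch-16 model with a 224x224 default image
# size, so a 224x224 input matches its position embeddings and yields a
# 768-channel stride-16 (14x14) map for the SFPN to expand.
backbone = vit_sfpn_backbone(backbone_name="vit_b_16", weights=None)
x = torch.rand(1, 3, 224, 224)
out = backbone(x)
# Assumed result: an ordered dict of 256-channel maps, keyed "0".."3" for
# the pyramid levels plus "pool" for the LastLevelMaxPool extra block.
for name, feat in out.items():
    print(name, feat.shape)
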

torchvision/models/detection/mask_rcnn.py

Lines changed: 95 additions & 1 deletion
@@ -10,8 +10,9 @@
 from .._meta import _COCO_CATEGORIES
 from .._utils import _ovewrite_value_param, handle_legacy_interface
 from ..resnet import resnet50, ResNet50_Weights
+from ..vision_transformer import ViT_B_16_Weights, vit_b_16
 from ._utils import overwrite_eps
-from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _vit_sfpn_extractor, _validate_trainable_layers
 from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead

@@ -585,3 +586,96 @@ def maskrcnn_resnet50_fpn_v2(
         model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

     return model
+
+
+def maskrcnn_vit_b_16_sfpn(
+    *,
+    weights: Optional[WeightsEnum] = None,  # TODO: Change this to MaskRCNN_ViT_B_16_SFPN_V1_Weights
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ViT_B_16_Weights] = ViT_B_16_Weights.IMAGENET1K_V1,
+    **kwargs: Any,
+) -> MaskRCNN:
+    """Mask R-CNN model with a ViT-B/16 backbone from the `ViT <https://arxiv.org/abs/2010.11929>`_ paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and the targets (a list of dictionaries),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the mask loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores of each instance
+        - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
+          obtain the final segmentation masks, the soft masks can be thresholded, generally
+          with a value of 0.5 (``mask >= 0.5``)
+
+    For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
+
+    Mask R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.maskrcnn_vit_b_16_sfpn(weights=MaskRCNN_ViT_B_16_SFPN_V1_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version=11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The
+            pretrained weights for the backbone.
+        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights
+        :members:
+    """
+    # TODO: Add ViTDet weights.
+    # weights = MaskRCNN_ViT_B_16_SFPN_V1_Weights.verify(weights)
+    weights_backbone = ViT_B_16_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False)
+    backbone = _vit_sfpn_extractor(backbone)
+    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
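
A sketch of exercising the training path against this branch. The fixed 224x224 transform size is an assumption: vit_b_16 expects inputs matching its 224x224 image_size, while MaskRCNN's default GeneralizedRCNNTransform resizes to min_size=800, so the sketch pins min_size and max_size instead.

import torch
from torchvision.models.detection.mask_rcnn import maskrcnn_vit_b_16_sfpn

# Sketch only: weights_backbone=None avoids any download. min_size/max_size
# are forwarded through MaskRCNN's **kwargs to GeneralizedRCNNTransform so
# resized inputs stay 224x224 (a multiple of 32, so batching adds no
# padding), matching the ViT-B/16 position embeddings.
model = maskrcnn_vit_b_16_sfpn(weights_backbone=None, num_classes=3, min_size=224, max_size=224)
model.train()
images = [torch.rand(3, 224, 224)]
targets = [{
    "boxes": torch.tensor([[20.0, 30.0, 120.0, 140.0]]),   # one ground-truth box
    "labels": torch.tensor([1]),
    "masks": torch.zeros(1, 224, 224, dtype=torch.uint8),  # its binary mask
}]
loss_dict = model(images, targets)  # RPN, box, and mask losses
print(sorted(loss_dict))
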
