 from .._meta import _COCO_CATEGORIES
 from .._utils import _ovewrite_value_param, handle_legacy_interface
 from ..resnet import resnet50, ResNet50_Weights
+from ..vision_transformer import ViT_B_16_Weights, vit_b_16
 from ._utils import overwrite_eps
-from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _vit_sfpn_extractor, _validate_trainable_layers
 from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead


@@ -585,3 +586,96 @@ def maskrcnn_resnet50_fpn_v2(
     model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

     return model
+
+
+def maskrcnn_vit_b_16_sfpn(
+    *,
+    weights: Optional[WeightsEnum] = None,  # TODO: Change this to MaskRCNN_ViT_B_16_SFPN_V1_Weights
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ViT_B_16_Weights] = ViT_B_16_Weights.IMAGENET1K_V1,
+    **kwargs: Any,
+) -> MaskRCNN:
+    """Mask R-CNN model with a ViT-B/16 backbone from the `ViT <https://arxiv.org/abs/2010.11929>`_ paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a list of target dictionaries,
+    each containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the mask loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores of each instance
+        - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
+          obtain the final segmentation masks, the soft masks can be thresholded, generally
+          with a value of 0.5 (``mask >= 0.5``)
+
+    For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
+
+    Mask R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.maskrcnn_vit_b_16_sfpn(weights=MaskRCNN_ViT_B_16_SFPN_V1_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version=11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The
+            pretrained weights for the backbone.
+        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.MaskRCNN_ViT_B_16_SFPN_V1_Weights
+        :members:
+    """
+    # TODO: Add ViTDet weights.
+    # weights = MaskRCNN_ViT_B_16_SFPN_V1_Weights.verify(weights)
+    weights_backbone = ViT_B_16_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False)
+    backbone = _vit_sfpn_extractor(backbone)
+    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
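
Below is a minimal usage sketch for the new builder. It assumes `maskrcnn_vit_b_16_sfpn` ends up exported from `torchvision.models.detection` like the other builders in this file; since no `MaskRCNN_ViT_B_16_SFPN_V1_Weights` entries exist yet (see the TODO above), `weights` stays `None`, and `weights_backbone=None` is passed here only to skip the ImageNet download.

```python
import torch
from torchvision.models.detection import maskrcnn_vit_b_16_sfpn  # added by this diff

# No detection weights are published yet, so the heads are randomly initialized;
# weights_backbone=None also avoids downloading the ImageNet ViT-B/16 checkpoint.
model = maskrcnn_vit_b_16_sfpn(weights=None, weights_backbone=None, num_classes=91)

# Inference: a list of CHW float tensors in [0, 1]; images may have different sizes.
model.eval()
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(images)  # List[Dict[str, Tensor]] with boxes, labels, scores, masks

# Training: pass one target dict per image with boxes, labels and masks.
model.train()
targets = [
    {
        "boxes": torch.tensor([[10.0, 20.0, 200.0, 250.0]]),
        "labels": torch.tensor([1], dtype=torch.int64),
        "masks": torch.zeros((1, 300, 400), dtype=torch.uint8),
    },
    {
        "boxes": torch.tensor([[30.0, 40.0, 350.0, 380.0]]),
        "labels": torch.tensor([2], dtype=torch.int64),
        "masks": torch.zeros((1, 500, 400), dtype=torch.uint8),
    },
]
losses = model(images, targets)  # Dict[str, Tensor] of RPN, box and mask losses
```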
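The `_vit_sfpn_extractor` imported at the top of the diff is added elsewhere in this PR and is not shown here. For background, the "simple feature pyramid" it refers to (ViTDet, https://arxiv.org/abs/2203.16527) builds every pyramid level from the single stride-16 ViT feature map, instead of tapping intermediate stages the way `_resnet_fpn_extractor` does. The module below is only a conceptual sketch of that idea under assumed channel widths and level keys, not the helper added in `backbone_utils`.

```python
# Conceptual sketch only: NOT the _vit_sfpn_extractor introduced by this PR.
from collections import OrderedDict

import torch
from torch import nn


class SimpleFeaturePyramid(nn.Module):
    """Derive stride {4, 8, 16, 32} maps from a single stride-16 ViT feature map."""

    def __init__(self, in_channels: int = 768, out_channels: int = 256):
        super().__init__()
        self.out_channels = out_channels
        # Stride 4: upsample 4x with two transposed convolutions.
        self.scale4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2),
        )
        # Stride 8: upsample 2x.
        self.scale8 = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # Stride 16: use the ViT output as-is.
        self.scale16 = nn.Identity()
        # Stride 32: downsample 2x.
        self.scale32 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Project every level to a common channel width, as an FPN would.
        self.proj = nn.ModuleDict(
            {
                "0": nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
                "1": nn.Conv2d(in_channels // 2, out_channels, kernel_size=1),
                "2": nn.Conv2d(in_channels, out_channels, kernel_size=1),
                "3": nn.Conv2d(in_channels, out_channels, kernel_size=1),
            }
        )

    def forward(self, x: torch.Tensor) -> "OrderedDict[str, torch.Tensor]":
        # x: the final ViT feature map reshaped to [B, C, H/16, W/16].
        levels = [self.scale4(x), self.scale8(x), self.scale16(x), self.scale32(x)]
        return OrderedDict((k, self.proj[k](v)) for k, v in zip(["0", "1", "2", "3"], levels))


# For a 224x224 image, the ViT-B/16 patch grid is 14x14 with 768 channels:
feats = SimpleFeaturePyramid()(torch.rand(1, 768, 14, 14))
print({k: tuple(v.shape) for k, v in feats.items()})
# {'0': (1, 256, 56, 56), '1': (1, 256, 28, 28), '2': (1, 256, 14, 14), '3': (1, 256, 7, 7)}
```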