Added typing annotations to models/detection [1/n] #4220

Closed
wants to merge 65 commits into from
65 commits
b2f6615 fix (oke-aditya, May 20, 2021)
4fb038d Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 20, 2021)
deda5d7 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
5490821 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
4cfc220 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, May 21, 2021)
6306746 Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 24, 2021)
e8c93cf Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 28, 2021)
6871ccc Merge branch 'master' of https://github.com/pytorch/vision (oke-aditya, Jul 28, 2021)
53fe949 start adding typing (oke-aditya, Jul 28, 2021)
ecc58a7 finish typing frcnn (oke-aditya, Jul 29, 2021)
6a30c92 Type up models (oke-aditya, Jul 29, 2021)
fb3ea88 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Jul 29, 2021)
1f5f715 fix bugs (oke-aditya, Jul 29, 2021)
d13311d fix bugs (oke-aditya, Jul 29, 2021)
5adda72 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Jul 29, 2021)
d16be19 combine PRs (oke-aditya, Jul 29, 2021)
4580541 fix typing (oke-aditya, Jul 29, 2021)
1ba33ce fixup types (oke-aditya, Jul 29, 2021)
254e51b type another file (oke-aditya, Jul 29, 2021)
47f75dc fix bug (oke-aditya, Jul 29, 2021)
6aff88c fix roi (oke-aditya, Jul 30, 2021)
516fb68 try fixing types (oke-aditya, Jul 30, 2021)
b026039 remove unnecessary import (oke-aditya, Jul 30, 2021)
18ec557 fix some types (oke-aditya, Jul 30, 2021)
24e3f74 undo fmt (oke-aditya, Jul 30, 2021)
ab7c671 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Jul 30, 2021)
ff5698a fixup types (oke-aditya, Jul 30, 2021)
50641d5 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 2, 2021)
6a9c0cb undo roi heads and skip for now (oke-aditya, Aug 2, 2021)
700d74d enable mypy (oke-aditya, Aug 2, 2021)
c0e836b fixup mypy (oke-aditya, Aug 2, 2021)
0a93c17 Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 5, 2021)
a0b3b2a fix test (oke-aditya, Aug 5, 2021)
fc8032c small fix for tuple (oke-aditya, Aug 5, 2021)
730a33d commit correctly (oke-aditya, Aug 5, 2021)
4cd5b84 undo (oke-aditya, Aug 5, 2021)
c278189 Merge branch 'add_typing1' of github.com:oke-aditya/vision into add_t… (oke-aditya, Aug 5, 2021)
fe7e289 fix (oke-aditya, Aug 5, 2021)
1fda3e4 fix makefile (oke-aditya, Aug 5, 2021)
0daf0d8 don't modify test (oke-aditya, Aug 5, 2021)
4fa9294 fix tuple bug (oke-aditya, Aug 5, 2021)
f9c7fbe Merge branch 'master' of https://github.com/pytorch/vision into add_t… (oke-aditya, Aug 16, 2021)
f1eeea1 Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Sep 13, 2021)
12e7d7b Merge branch 'main' into add_typing1 (pmeier, Sep 14, 2021)
eae251f Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Sep 22, 2021)
9d0d6cf Try onnx workaround (oke-aditya, Sep 22, 2021)
e437b60 Merge branch 'add_typing1' of github.com:oke-aditya/vision into add_t… (oke-aditya, Sep 22, 2021)
a6bb70c Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Sep 22, 2021)
76b26ba Merge branch 'main' into add_typing1 (datumbox, Sep 22, 2021)
623d5e4 Merge branch 'add_typing1' of github.com:oke-aditya/vision into add_t… (oke-aditya, Sep 22, 2021)
bb73a22 Fix bugs (oke-aditya, Sep 22, 2021)
4e8e9f7 Green up CI (oke-aditya, Sep 22, 2021)
ca5d600 fixes ssd (oke-aditya, Sep 22, 2021)
8976768 Use Union for TS (oke-aditya, Sep 22, 2021)
6fc8812 Fix (oke-aditya, Sep 22, 2021)
7422d3b Fix small type (oke-aditya, Sep 22, 2021)
a10588c Fix type (oke-aditya, Sep 22, 2021)
5f2dcc4 Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Sep 25, 2021)
70e5b6a Handle tuple[tuple[int, int]] case, List[float] (oke-aditya, Sep 25, 2021)
7d4e919 Use nn.ModuleList and nn.ModuleDict (oke-aditya, Sep 25, 2021)
01fdb48 Resolve conflicts (oke-aditya, Oct 5, 2021)
cace5ff Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Oct 6, 2021)
e941e48 Fix JIT issue, final update for CI Check (oke-aditya, Oct 6, 2021)
623017b Merge branch 'main' of https://github.com/pytorch/vision into add_typ… (oke-aditya, Oct 6, 2021)
becbbd6 Fix JIT issue (oke-aditya, Oct 6, 2021)
18 changes: 17 additions & 1 deletion mypy.ini
@@ -17,7 +17,23 @@ ignore_errors = True

 ignore_errors=True

-[mypy-torchvision.models.detection.*]
+[mypy-torchvision.models.detection._utils]
+
+ignore_errors = True
+
+[mypy-torchvision.models.detection.anchor_utils]
+
+ignore_errors = True
+
+[mypy-torchvision.models.detection.backbone_utils]
+
+ignore_errors = True
+
+[mypy-torchvision.models.detection.roi_heads]
+
+ignore_errors = True
+
+[mypy-torchvision.models.detection.ssdlite]

 ignore_errors = True

120 changes: 70 additions & 50 deletions torchvision/models/detection/faster_rcnn.py
@@ -1,3 +1,5 @@
+from typing import Optional, Tuple, Any, cast, List
+
 import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
@@ -55,10 +57,10 @@ class FasterRCNN(GeneralizedRCNN):
             If box_predictor is specified, num_classes should be None.
         min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
         max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+        image_mean (List[float]): mean values used for input normalization.
             They are generally the mean values of the dataset on which the backbone has been trained
             on
-        image_std (Tuple[float, float, float]): std values used for input normalization.
+        image_std (List[float]): std values used for input normalization.
             They are generally the std values of the dataset on which the backbone has been trained on
         rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
             maps.
@@ -143,39 +145,39 @@ class FasterRCNN(GeneralizedRCNN):

     def __init__(
         self,
-        backbone,
-        num_classes=None,
+        backbone: nn.Module,
+        num_classes: Optional[int] = None,
         # transform parameters
-        min_size=800,
-        max_size=1333,
-        image_mean=None,
-        image_std=None,
+        min_size: int = 800,
+        max_size: int = 1333,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
         # RPN parameters
-        rpn_anchor_generator=None,
-        rpn_head=None,
-        rpn_pre_nms_top_n_train=2000,
-        rpn_pre_nms_top_n_test=1000,
-        rpn_post_nms_top_n_train=2000,
-        rpn_post_nms_top_n_test=1000,
-        rpn_nms_thresh=0.7,
-        rpn_fg_iou_thresh=0.7,
-        rpn_bg_iou_thresh=0.3,
-        rpn_batch_size_per_image=256,
-        rpn_positive_fraction=0.5,
-        rpn_score_thresh=0.0,
+        rpn_anchor_generator: Optional[AnchorGenerator] = None,
+        rpn_head: Optional[nn.Module] = None,
+        rpn_pre_nms_top_n_train: int = 2000,
+        rpn_pre_nms_top_n_test: int = 1000,
+        rpn_post_nms_top_n_train: int = 2000,
+        rpn_post_nms_top_n_test: int = 1000,
+        rpn_nms_thresh: float = 0.7,
+        rpn_fg_iou_thresh: float = 0.7,
+        rpn_bg_iou_thresh: float = 0.3,
+        rpn_batch_size_per_image: int = 256,
+        rpn_positive_fraction: float = 0.5,
+        rpn_score_thresh: float = 0.0,
         # Box parameters
-        box_roi_pool=None,
-        box_head=None,
-        box_predictor=None,
-        box_score_thresh=0.05,
-        box_nms_thresh=0.5,
-        box_detections_per_img=100,
-        box_fg_iou_thresh=0.5,
-        box_bg_iou_thresh=0.5,
-        box_batch_size_per_image=512,
-        box_positive_fraction=0.25,
-        bbox_reg_weights=None,
-    ):
+        box_roi_pool: Optional[MultiScaleRoIAlign] = None,
+        box_head: Optional[nn.Module] = None,
Contributor

box_head: Optional[RPNHead] = None I believe

Contributor Author


I was a bit unsure about these. I think users can write their own classes and just pass them to FasterRCNN.
Hence I thought the annotation shouldn't be as specific as a class like RPNHead.

Let me know; otherwise I will change all of these.

Collaborator


If we allow passing custom classes, we need structural duck typing here, meaning we need to define a custom typing.Protocol that describes what these objects need to look like. @oke-aditya, can you make a list of all the attributes we access and methods we call on these objects?

Contributor Author


I'm unsure what you mean, @pmeier. I'm new to type hints.
I think being very strict about typing with custom classes might be too much. If nn.Module satisfies the purpose I think it should be good enough?

Collaborator


> If nn.Module satisfies the purpose I think it should be good enough?

Yes. As a rule of thumb, the input types should be as loose as possible and the output types as strict as possible.
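
To illustrate that rule of thumb, here is a minimal sketch with a hypothetical helper named scale_all (not part of torchvision): the parameter accepts any iterable of floats, while the return annotation promises a concrete list.

```python
from typing import Iterable, List


def scale_all(values: Iterable[float], factor: float) -> List[float]:
    """Loose input (any iterable of floats), strict output (always a list)."""
    return [v * factor for v in values]


# Tuples, lists and generators are all accepted by the loose input type.
from_tuple = scale_all((1.0, 2.0), 2.0)
from_generator = scale_all((x for x in [3.0]), 1.0)
```

Callers can pass whatever container they already have, and code consuming the result can rely on list operations without further checks.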

After some digging, I think

box_features = self.box_head(box_features)

is the only time we do anything with the object passed as box_head. Since it is possible to call an nn.Module, using it as annotation is fine.


In other cases @frgfm's comment might have some merit: for example, if we at some point access box_head.foo, nn.Module is not sufficient as an annotation, since it doesn't define a foo attribute. In such a case we have two options:

  1. Annotate with a custom class such as RPNHead that has this attribute.
  2. Create a typing.Protocol that defines the interface we are looking for. In the example above the object passed to box_head needs to be callable and define an attribute foo.

Especially if users are allowed to pass custom objects, I would always prefer 2. over 1. because it doesn't require any changes on the user side. I don't know if this works with torchscript though.

Let's discuss this if the need for something like this arises.
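
For reference, option 2 could look roughly like the sketch below. The names BoxHeadLike, CustomHead, NoFooHead and apply_head are hypothetical, foo is the placeholder attribute from the comment above, and plain lists stand in for tensors so the snippet runs without torch. Whether this plays well with torchscript is not covered here.

```python
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class BoxHeadLike(Protocol):
    """Structural interface: callable and exposes an attribute foo."""

    foo: int

    def __call__(self, box_features: List[float]) -> List[float]: ...


class CustomHead:
    """User-defined head; it does not inherit from BoxHeadLike."""

    def __init__(self) -> None:
        self.foo = 3

    def __call__(self, box_features: List[float]) -> List[float]:
        return [x * self.foo for x in box_features]


class NoFooHead:
    """Callable, but lacks foo, so it does not satisfy the protocol."""

    def __call__(self, box_features: List[float]) -> List[float]:
        return box_features


def apply_head(head: BoxHeadLike, box_features: List[float]) -> List[float]:
    # A type checker accepts any object with a matching shape here.
    return head(box_features)
```

The user's class needs no changes to be accepted, which is the advantage over option 1. Note that runtime_checkable isinstance checks only test that the members exist, not their types.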

+        box_predictor: Optional[nn.Module] = None,
+        box_score_thresh: float = 0.05,
+        box_nms_thresh: float = 0.5,
+        box_detections_per_img: int = 100,
+        box_fg_iou_thresh: float = 0.5,
+        box_bg_iou_thresh: float = 0.5,
+        box_batch_size_per_image: int = 512,
+        box_positive_fraction: float = 0.25,
+        bbox_reg_weights: Optional[Tuple[float, ...]] = None,
+    ) -> None:

         if not hasattr(backbone, "out_channels"):
             raise ValueError(
@@ -194,7 +196,7 @@ def __init__(
             if box_predictor is None:
                 raise ValueError("num_classes should not be None when box_predictor " "is not specified")

-        out_channels = backbone.out_channels
+        out_channels = cast(int, backbone.out_channels)

         if rpn_anchor_generator is None:
             anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
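
A note on the cast(int, backbone.out_channels) change in this hunk: typing.cast only informs the type checker and does nothing at runtime. The sketch below uses a hypothetical DummyBackbone class (not torchvision code) to show that the value is returned unchanged, with no conversion or validation.

```python
from typing import Any, cast


class DummyBackbone:
    """Hypothetical stand-in whose attribute type is opaque to the checker."""

    def __init__(self, out_channels: Any) -> None:
        self.out_channels = out_channels


backbone = DummyBackbone(256)
# The checker now treats out_channels as int; the runtime value is untouched.
out_channels = cast(int, backbone.out_channels)

# cast does not convert or validate: a str stays a str.
not_really_int = cast(int, DummyBackbone("256").out_channels)
```

So the cast keeps mypy quiet but offers no protection if a backbone stores something other than an int in out_channels.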
@@ -229,7 +231,7 @@ def __init__(

         if box_predictor is None:
             representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)  # type: ignore[arg-type]

         roi_heads = RoIHeads(
             # Box
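
The type: ignore[arg-type] in this hunk silences mypy where num_classes is still Optional[int] at the call site. One alternative, sketched here with hypothetical names and not what this PR does, is to narrow the Optional explicitly before the call so that no ignore comment is needed.

```python
from typing import Optional


def make_predictor(in_channels: int, num_classes: int) -> str:
    # Hypothetical stand-in for a constructor that requires a plain int.
    return f"predictor({in_channels}, {num_classes})"


def build(num_classes: Optional[int]) -> str:
    if num_classes is None:
        raise ValueError("num_classes should not be None when box_predictor is not specified")
    # After the None check, type checkers treat num_classes as int.
    return make_predictor(1024, num_classes)
```

The trade-off is an extra runtime check in exchange for dropping the suppression comment.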
@@ -264,7 +266,7 @@ class TwoMLPHead(nn.Module):
         representation_size (int): size of the intermediate representation
     """

-    def __init__(self, in_channels, representation_size):
+    def __init__(self, in_channels: int, representation_size: int) -> None:
         super(TwoMLPHead, self).__init__()

         self.fc6 = nn.Linear(in_channels, representation_size)
@@ -289,7 +291,7 @@ class FastRCNNPredictor(nn.Module):
         num_classes (int): number of output classes (including background)
     """

-    def __init__(self, in_channels, num_classes):
+    def __init__(self, in_channels: int, num_classes: int) -> None:
         super(FastRCNNPredictor, self).__init__()
         self.cls_score = nn.Linear(in_channels, num_classes)
         self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
@@ -312,8 +314,13 @@ def forward(self, x):


 def fasterrcnn_resnet50_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
     """
     Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

@@ -395,14 +402,15 @@ def fasterrcnn_resnet50_fpn(


 def _fasterrcnn_mobilenet_v3_large_fpn(
-    weights_name,
-    pretrained=False,
-    progress=True,
-    num_classes=91,
-    pretrained_backbone=True,
-    trainable_backbone_layers=None,
-    **kwargs,
-):
+    weights_name: str,
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
     )
@@ -436,8 +444,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(


 def fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     """
     Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases.
     It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
@@ -481,8 +495,14 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(


 def fasterrcnn_mobilenet_v3_large_fpn(
-    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
-):
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     """
     Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
     It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
29 changes: 21 additions & 8 deletions torchvision/models/detection/generalized_rcnn.py
@@ -23,7 +23,13 @@ class GeneralizedRCNN(nn.Module):
             the model
     """

-    def __init__(self, backbone, rpn, roi_heads, transform):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        rpn: nn.Module,
+        roi_heads: nn.Module,
+        transform: nn.Module,
+    ) -> None:
         super(GeneralizedRCNN, self).__init__()
         self.transform = transform
         self.backbone = backbone
@@ -32,16 +38,23 @@ def __init__(self, backbone, rpn, roi_heads, transform):
         # used only on torchscript mode
         self._has_warned = False

-    @torch.jit.unused
-    def eager_outputs(self, losses, detections):
-        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def eager_outputs(
+        self,
+        losses: Dict[str, Tensor],
+        detections: List[Dict[str, Tensor]],
+    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+
         if self.training:
             return losses

         return detections

-    def forward(self, images, targets=None):
-        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+    def forward(
+        self,
+        images: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]:
+
         """
         Args:
             images (list[Tensor]): images to be processed
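
One consequence of the Union return annotations in this hunk is that callers have to narrow the result before using it. A minimal sketch of that pattern, using hypothetical functions run and total_loss and floats in place of Tensors:

```python
from typing import Dict, List, Union

Losses = Dict[str, float]
Detections = List[Dict[str, float]]


def run(training: bool) -> Union[Losses, Detections]:
    # Mirrors the eager_outputs pattern: losses when training, detections otherwise.
    if training:
        return {"loss_classifier": 0.5}
    return [{"score": 0.9}]


def total_loss(output: Union[Losses, Detections]) -> float:
    # isinstance narrows the Union so the dict-only call type-checks.
    if isinstance(output, dict):
        return sum(output.values())
    raise TypeError("expected a loss dict (training mode)")
```

Without the isinstance check, a type checker would reject output.values() because a list has no such method.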
@@ -68,11 +81,11 @@ def forward(self, images, targets=None):
                 else:
                     raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))

-        original_image_sizes: List[Tuple[int, int]] = []
+        original_image_sizes: List[List[int]] = []
         for img in images:
             val = img.shape[-2:]
             assert len(val) == 2
-            original_image_sizes.append((val[0], val[1]))
+            original_image_sizes.append([val[0], val[1]])

         images, targets = self.transform(images, targets)

2 changes: 1 addition & 1 deletion torchvision/models/detection/image_list.py
@@ -12,7 +12,7 @@ class ImageList(object):
     and storing in a field the original sizes of each image
     """

-    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]):
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
         """
         Args:
             tensors (tensor)