Added typing annotations to models/detection [1/n] #4220
Changes from 14 commits
```diff
@@ -1,5 +1,6 @@
 from torch import nn
 import torch.nn.functional as F
+from typing import Optional, Tuple, Any

 from torchvision.ops import MultiScaleRoIAlign

```
```diff
@@ -141,24 +142,41 @@ class FasterRCNN(GeneralizedRCNN):
     >>> predictions = model(x)
     """

-    def __init__(self, backbone, num_classes=None,
-                 # transform parameters
-                 min_size=800, max_size=1333,
-                 image_mean=None, image_std=None,
-                 # RPN parameters
-                 rpn_anchor_generator=None, rpn_head=None,
-                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
-                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
-                 rpn_nms_thresh=0.7,
-                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
-                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
-                 rpn_score_thresh=0.0,
-                 # Box parameters
-                 box_roi_pool=None, box_head=None, box_predictor=None,
-                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
-                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
-                 box_batch_size_per_image=512, box_positive_fraction=0.25,
-                 bbox_reg_weights=None):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        num_classes: Optional[int] = None,
+        # transform parameters
+        min_size: int = 800,
+        max_size: int = 1333,
+        image_mean: Optional[Tuple[float]] = None,
+        image_std: Optional[Tuple[float]] = None,
+        # RPN parameters
+        rpn_anchor_generator: Optional[AnchorGenerator] = None,
+        rpn_head: Optional[nn.Module] = None,
+        rpn_pre_nms_top_n_train: int = 2000,
+        rpn_pre_nms_top_n_test: int = 1000,
+        rpn_post_nms_top_n_train: int = 2000,
+        rpn_post_nms_top_n_test: int = 1000,
+        rpn_nms_thresh: float = 0.7,
+        rpn_fg_iou_thresh: float = 0.7,
+        rpn_bg_iou_thresh: float = 0.3,
+        rpn_batch_size_per_image: int = 256,
+        rpn_positive_fraction: float = 0.5,
+        rpn_score_thresh: float = 0.0,
+        # Box parameters
+        box_roi_pool: Optional[MultiScaleRoIAlign] = None,
+        box_head: Optional[nn.Module] = None,
+        box_predictor: Optional[nn.Module] = None,
+        box_score_thresh: float = 0.05,
+        box_nms_thresh: float = 0.5,
+        box_detections_per_img: int = 100,
+        box_fg_iou_thresh: float = 0.5,
+        box_bg_iou_thresh: float = 0.5,
+        box_batch_size_per_image: int = 512,
+        box_positive_fraction: float = 0.25,
+        bbox_reg_weights: Optional[Tuple[float]] = None
+    ) -> None:

         if not hasattr(backbone, "out_channels"):
             raise ValueError(
```

Review thread on annotating the custom components (`box_head` and friends as `Optional[nn.Module]`):

- oke-aditya: I was a bit unsure on these. I think that users can write their own classes and just pass them. Let me know, I will change all of these otherwise.
- pmeier: If we allow passing custom classes, we need structural duck typing here. Meaning we need to define a custom …
- oke-aditya: I'm unsure of what you mean @pmeier. I'm new to type hints.
- pmeier: Yes. As a rule of thumb, the input types should be as loose as possible and the output types as strict as possible. After some digging, I think … is the only time we do anything with the object passed as … In other cases @frgfm's comment might have some merit: for example, if we at some point access … Especially if users are allowed to pass custom objects, I would always prefer 2. over 1. because it doesn't require any changes on the user side. I don't know if this works with torchscript though. Let's discuss this if the need for something like this arises.
```diff
@@ -245,7 +263,11 @@ class TwoMLPHead(nn.Module):
         representation_size (int): size of the intermediate representation
     """

-    def __init__(self, in_channels, representation_size):
+    def __init__(
+        self,
+        in_channels: int,
+        representation_size: int
+    ) -> None:
         super(TwoMLPHead, self).__init__()

         self.fc6 = nn.Linear(in_channels, representation_size)
```
```diff
@@ -270,7 +292,11 @@ class FastRCNNPredictor(nn.Module):
         num_classes (int): number of output classes (including background)
     """

-    def __init__(self, in_channels, num_classes):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int
+    ) -> None:
         super(FastRCNNPredictor, self).__init__()
         self.cls_score = nn.Linear(in_channels, num_classes)
         self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
```
@@ -295,8 +321,14 @@ def forward(self, x): | |||
} | ||||
|
||||
|
||||
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, | ||||
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): | ||||
def fasterrcnn_resnet50_fpn( | ||||
pretrained: bool = False, | ||||
progress: bool = True, | ||||
num_classes: int = 91, | ||||
pretrained_backbone: bool = True, | ||||
trainable_backbone_layers: Optional[int] = None, | ||||
**kwargs: Any, | ||||
) -> FasterRCNN: | ||||
""" | ||||
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. | ||||
|
||||
|
@@ -374,8 +406,16 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, | |||
return model | ||||
|
||||
|
||||
def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91, | ||||
pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): | ||||
def _fasterrcnn_mobilenet_v3_large_fpn( | ||||
weights_name: str, | ||||
pretrained: bool = False, | ||||
progress: bool = True, | ||||
num_classes: int = 91, | ||||
pretrained_backbone: bool = True, | ||||
trainable_backbone_layers: Optional[int] = None, | ||||
**kwargs: Any, | ||||
) -> FasterRCNN: | ||||
|
||||
trainable_backbone_layers = _validate_trainable_layers( | ||||
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) | ||||
|
||||
|
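The helper's body is not part of this diff, so the following is only a rough sketch of the contract its call site `_validate_trainable_layers(pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)` suggests; the function name and error message here are illustrative, not torchvision's actual implementation.

```python
from typing import Optional


def validate_trainable_layers(
    is_trained: bool,
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    """Sketch: resolve the Optional[int] into a concrete layer count."""
    # If neither the model nor the backbone is pretrained, every layer trains.
    if not is_trained:
        return max_value
    # None means "use the default" once pretrained weights are involved.
    if trainable_backbone_layers is None:
        return default_value
    if not 0 <= trainable_backbone_layers <= max_value:
        raise ValueError(
            f"trainable_backbone_layers should be in [0, {max_value}], "
            f"got {trainable_backbone_layers}"
        )
    return trainable_backbone_layers


print(validate_trainable_layers(False, None, 6, 3))  # 6: nothing pretrained
print(validate_trainable_layers(True, None, 6, 3))   # 3: pretrained, default
print(validate_trainable_layers(True, 2, 6, 3))      # 2: explicit user value
```

This is why `Optional[int]` is the right annotation: `None` is a meaningful sentinel, distinct from `0`.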
```diff
@@ -397,8 +437,15 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=
     return model


-def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                          trainable_backbone_layers=None, **kwargs):
+def fasterrcnn_mobilenet_v3_large_320_fpn(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 91,
+    pretrained_backbone: bool = True,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+
     """
     Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases.
     It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
```
@@ -435,8 +482,15 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c | |||
trainable_backbone_layers=trainable_backbone_layers, **kwargs) | ||||
|
||||
|
||||
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, | ||||
trainable_backbone_layers=None, **kwargs): | ||||
def fasterrcnn_mobilenet_v3_large_fpn( | ||||
pretrained: bool = False, | ||||
progress: bool = True, | ||||
num_classes: int = 91, | ||||
pretrained_backbone: bool = True, | ||||
trainable_backbone_layers: Optional[int] = None, | ||||
**kwargs: Any, | ||||
) -> FasterRCNN: | ||||
|
||||
""" | ||||
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. | ||||
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See | ||||
|
Review thread on the `Tuple[float]` annotations (`image_mean`, `image_std`, `bbox_reg_weights`):

- "I think it should be `List[float]` instead of `Tuple[float]`? Similar to `torchvision/models/detection/ssd.py` line 169 in 3e7653c."
- frgfm: I would argue that the typing in ssd is too specific: it can be an `Optional[List[float]]` indeed, but `Optional[Tuple[float, float, float]]` is accepted since it will be passed down to https://github.com/pytorch/vision/blob/master/torchvision/models/detection/transform.py#L136-L137 for the transformations. Should we use a `Union` or enforce a `List`?
- frgfm: Actually, I just saw your comment later on. Since the backbone can take any number of channels in the input tensor, the second option is actually `Optional[Tuple[float, ...]]`.
- oke-aditya: I have changed to `Optional[Tuple[float, ...]]`, since in the docstring we mention `Tuple[float, float, float]`.
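The distinction driving this thread: `Tuple[float]` describes a tuple of exactly one float, whereas `Tuple[float, ...]` describes a homogeneous tuple of any length. A small hypothetical sketch of the agreed-upon annotation (the function name `normalize_stats` is invented for illustration; the fallback values are the ImageNet statistics torchvision commonly uses as defaults):

```python
from typing import Optional, Tuple


def normalize_stats(
    # Tuple[float, ...]: any length, so backbones with != 3 channels still fit.
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
) -> Tuple[Tuple[float, ...], Tuple[float, ...]]:
    """Resolve None into concrete per-channel statistics."""
    if image_mean is None:
        image_mean = (0.485, 0.456, 0.406)
    if image_std is None:
        image_std = (0.229, 0.224, 0.225)
    return image_mean, image_std


print(normalize_stats())
# ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
print(normalize_stats(image_mean=(0.5,), image_std=(0.5,)))  # 1-channel input
```

Under `Tuple[float]`, a type checker would reject the three-element default, since that annotation admits only a 1-tuple; `Tuple[float, ...]` admits both calls above.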