
Commit ae8b92e

Adding docs and fixing keypoint and mask.
1 parent 60318d9 commit ae8b92e

File tree

5 files changed: 23 additions, 6 deletions

test/test_onnx.py

Lines changed: 7 additions & 4 deletions
@@ -197,7 +197,7 @@ def _init_test_rpn(self):
         rpn_pre_nms_top_n = dict(training=2000, testing=1000)
         rpn_post_nms_top_n = dict(training=2000, testing=1000)
         rpn_nms_thresh = 0.7
-        rpn_score_thresh = 0.05
+        rpn_score_thresh = 0.0

         rpn = RegionProposalNetwork(
             rpn_anchor_generator, rpn_head,
@@ -368,7 +368,8 @@ def get_test_images(self):
     def test_faster_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_image = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300,
+                                                                     rpn_score_thresh=0.0)
         model.eval()
         model(images)
         # Test exported model on images of different size, or dummy input
@@ -419,7 +420,8 @@ def test_paste_mask_in_image(self):
     def test_mask_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_image = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300,
+                                                                 rpn_score_thresh=0.0)
         model.eval()
         model(images)
         # Test exported model on images of different size, or dummy input
@@ -469,7 +471,8 @@ def test_heatmaps_to_keypoints(self):
     def test_keypoint_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_images = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300,
+                                                                         rpn_score_thresh=0.0)
         model.eval()
         model(images)
         self.run_model(model, [(images,), (test_images,), (dummy_images,)],
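
Note: the tests pin rpn_score_thresh=0.0 so that score-based filtering in the RPN cannot change the number of proposals between tracing and later inputs. A minimal sketch of the export pattern these tests exercise (the opset version and the use of io.BytesIO here are assumptions, not taken from this commit):

import io
import torch
from torchvision import models

# Build the detection model with RPN score filtering disabled so the traced
# ONNX graph keeps shape-stable proposal counts (assumption: opset 11).
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
    pretrained=True, min_size=200, max_size=300, rpn_score_thresh=0.0)
model.eval()

dummy_image = [torch.ones(3, 100, 100) * 0.3]
onnx_io = io.BytesIO()
torch.onnx.export(model, (dummy_image,), onnx_io,
                  do_constant_folding=True, opset_version=11)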

torchvision/models/detection/_utils.py

Lines changed: 6 additions & 2 deletions
@@ -173,10 +173,14 @@ def decode(self, rel_codes, boxes):
         box_sum = 0
         for val in boxes_per_image:
             box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
         pred_boxes = self.decode_single(
-            rel_codes.reshape(box_sum, -1), concat_boxes
+            rel_codes, concat_boxes
         )
-        return pred_boxes.reshape(box_sum, -1, 4)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes

     def decode_single(self, rel_codes, boxes):
         """

torchvision/models/detection/faster_rcnn.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ class FasterRCNN(GeneralizedRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input
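
The documented behaviour amounts to dropping low-scoring proposals before they reach the box heads; a simplified illustration of that filtering step (this is not the RegionProposalNetwork implementation, just the semantics described by the docstring):

import torch

def filter_proposals_by_score(proposals, objectness_scores, score_thresh):
    # Keep proposals whose classification score is greater than score_thresh,
    # as stated in the rpn_score_thresh docstring; 0.0 effectively keeps all.
    keep = objectness_scores > score_thresh
    return proposals[keep], objectness_scores[keep]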

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,8 @@ class KeypointRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input
@@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
                  rpn_nms_thresh=0.7,
                  rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                  rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                 rpn_score_thresh=0.05,
                  # Box parameters
                  box_roi_pool=None, box_head=None, box_predictor=None,
                  box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
@@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,

torchvision/models/detection/mask_rcnn.py

Lines changed: 4 additions & 0 deletions
@@ -78,6 +78,8 @@ class MaskRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input
@@ -161,6 +163,7 @@ def __init__(self, backbone, num_classes=None,
                  rpn_nms_thresh=0.7,
                  rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                  rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                 rpn_score_thresh=0.05,
                  # Box parameters
                  box_roi_pool=None, box_head=None, box_predictor=None,
                  box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
@@ -207,6 +210,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,
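
Because KeypointRCNN and MaskRCNN simply forward the new keyword to the FasterRCNN constructor, the pretrained model builders accept it as well, as the updated ONNX tests already show; a short usage sketch (0.05 mirrors the default added above):

from torchvision import models

# Both builders pass rpn_score_thresh through to the underlying FasterRCNN.
mask_model = models.detection.maskrcnn_resnet50_fpn(
    pretrained=True, rpn_score_thresh=0.05)
keypoint_model = models.detection.keypointrcnn_resnet50_fpn(
    pretrained=True, rpn_score_thresh=0.05)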
