
Commit 01dd815

mthrok authored and fmassa committed
Improve speed/accuracy of FasterRCNN by introducing a score threshold on RPN (#3205)
Summary:
* Introduce small score threshold on rpn
* Adding docs and fixing keypoint and mask.
* Making value 0.0 by default for BC.
* Fixing for onnx.
* Update threshold.
* Removing non-default threshold from reference scripts.

Reviewed By: datumbox

Differential Revision: D25954569

fbshipit-source-id: a429a13c7153f0b87c377464d3d062aa28370002

Co-authored-by: Francisco Massa <[email protected]>
1 parent 4720367 commit 01dd815
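
For context, a hedged usage sketch (not part of this commit): the new rpn_score_thresh keyword is forwarded by the detection model builders to FasterRCNN, so it can be set at construction time. The builder name and the 0.05 value below are illustrative; the default stays 0.0 for backwards compatibility.

import torchvision

# Drop RPN proposals whose objectness probability is below 0.05 before NMS.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=True,
    rpn_score_thresh=0.05,
)
model.eval()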

File tree

8 files changed: +46 −9 lines

references/detection/train.py
test/test_models_detection_negative_samples.py
test/test_onnx.py
torchvision/models/detection/_utils.py
torchvision/models/detection/faster_rcnn.py
torchvision/models/detection/keypoint_rcnn.py
torchvision/models/detection/mask_rcnn.py
torchvision/models/detection/rpn.py

references/detection/train.py
Lines changed: 5 additions & 2 deletions

@@ -92,8 +92,11 @@ def main(args):
         collate_fn=utils.collate_fn)
 
     print("Creating model")
-    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
-                                                              pretrained=args.pretrained)
+    kwargs = {}
+    if "rcnn" in args.model:
+        kwargs["rpn_score_thresh"] = 0.0
+    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
+                                                              **kwargs)
     model.to(device)
 
     model_without_ddp = model

test/test_models_detection_negative_samples.py
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def test_targets_to_anchors(self):
             rpn_anchor_generator, rpn_head,
             0.5, 0.3,
             256, 0.5,
-            2000, 2000, 0.7)
+            2000, 2000, 0.7, 0.05)
 
         labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
 
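
The test above now passes the score threshold as the last positional argument of RegionProposalNetwork. A minimal construction sketch under assumed settings (single feature level, 256 channels, illustrative anchor sizes); it is not the test code itself.

from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork

anchor_generator = AnchorGenerator(sizes=((32, 64, 128),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
rpn_head = RPNHead(in_channels=256,
                   num_anchors=anchor_generator.num_anchors_per_location()[0])

rpn = RegionProposalNetwork(
    anchor_generator, rpn_head,
    fg_iou_thresh=0.7, bg_iou_thresh=0.3,
    batch_size_per_image=256, positive_fraction=0.5,
    pre_nms_top_n=dict(training=2000, testing=1000),
    post_nms_top_n=dict(training=2000, testing=1000),
    nms_thresh=0.7,
    score_thresh=0.05,  # new argument; defaults to 0.0
)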

test/test_onnx.py
Lines changed: 6 additions & 1 deletion

@@ -1,3 +1,4 @@
+from common_utils import set_rng_seed
 import io
 import torch
 from torchvision import ops

@@ -197,12 +198,14 @@ def _init_test_rpn(self):
         rpn_pre_nms_top_n = dict(training=2000, testing=1000)
         rpn_post_nms_top_n = dict(training=2000, testing=1000)
         rpn_nms_thresh = 0.7
+        rpn_score_thresh = 0.0
 
         rpn = RegionProposalNetwork(
             rpn_anchor_generator, rpn_head,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
-            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
+            score_thresh=rpn_score_thresh)
         return rpn
 
     def _init_test_roi_heads_faster_rcnn(self):

@@ -255,6 +258,8 @@ def get_features(self, images):
         return features
 
     def test_rpn(self):
+        set_rng_seed(0)
+
         class RPNModule(torch.nn.Module):
             def __init__(self_module):
                 super(RPNModule, self_module).__init__()
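
The RPN ONNX test now seeds the RNG before building its random inputs so the exported and eager outputs stay comparable. A hypothetical sketch of what a helper like set_rng_seed typically does; the actual implementation lives in the test suite's common_utils and may differ.

import random
import torch

def set_rng_seed(seed):
    # Seed both the PyTorch and Python RNGs so test inputs are reproducible.
    torch.manual_seed(seed)
    random.seed(seed)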

torchvision/models/detection/_utils.py
Lines changed: 6 additions & 2 deletions

@@ -173,10 +173,14 @@ def decode(self, rel_codes, boxes):
         box_sum = 0
         for val in boxes_per_image:
             box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
         pred_boxes = self.decode_single(
-            rel_codes.reshape(box_sum, -1), concat_boxes
+            rel_codes, concat_boxes
         )
-        return pred_boxes.reshape(box_sum, -1, 4)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
 
     def decode_single(self, rel_codes, boxes):
         """

torchvision/models/detection/faster_rcnn.py
Lines changed: 5 additions & 1 deletion

@@ -79,6 +79,8 @@ class FasterRCNN(GeneralizedRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -153,6 +155,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -197,7 +200,8 @@ def __init__(self, backbone, num_classes=None,
             rpn_anchor_generator, rpn_head,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
-            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
+            score_thresh=rpn_score_thresh)
 
         if box_roi_pool is None:
             box_roi_pool = MultiScaleRoIAlign(
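
An illustrative sketch (not from this commit) of passing the documented parameter when building FasterRCNN directly from a backbone; the backbone choice and values below are arbitrary examples.

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone('resnet50', pretrained=False)
model = FasterRCNN(backbone, num_classes=91, rpn_score_thresh=0.05)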

torchvision/models/detection/keypoint_rcnn.py
Lines changed: 4 additions & 0 deletions

@@ -74,6 +74,8 @@ class KeypointRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,

torchvision/models/detection/mask_rcnn.py
Lines changed: 4 additions & 0 deletions

@@ -75,6 +75,8 @@ class MaskRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,
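
Since KeypointRCNN and MaskRCNN forward the new argument to FasterRCNN, the corresponding builders accept it as well. An illustrative sketch with arbitrary values, not code from this commit:

import torchvision

mask_model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=True, rpn_score_thresh=0.05)
kp_model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    pretrained=True, rpn_score_thresh=0.05)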

torchvision/models/detection/rpn.py
Lines changed: 15 additions & 2 deletions

@@ -141,7 +141,7 @@ def __init__(self,
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 #
-                pre_nms_top_n, post_nms_top_n, nms_thresh):
+                pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
         super(RegionProposalNetwork, self).__init__()
         self.anchor_generator = anchor_generator
         self.head = head

@@ -163,6 +163,7 @@ def __init__(self,
         self._pre_nms_top_n = pre_nms_top_n
         self._post_nms_top_n = post_nms_top_n
         self.nms_thresh = nms_thresh
+        self.score_thresh = score_thresh
         self.min_size = 1e-3
 
     def pre_nms_top_n(self):

@@ -251,17 +252,29 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
         levels = levels[batch_idx, top_n_idx]
         proposals = proposals[batch_idx, top_n_idx]
 
+        objectness_prob = F.sigmoid(objectness)
+
         final_boxes = []
         final_scores = []
-        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
+        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
             boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
+
+            # remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
             boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+            # remove low scoring boxes
+            # use >= for Backwards compatibility
+            keep = torch.where(scores >= self.score_thresh)[0]
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
+
             # keep only topk scoring predictions
             keep = keep[:self.post_nms_top_n()]
             boxes, scores = boxes[keep], scores[keep]
+
             final_boxes.append(boxes)
             final_scores.append(scores)
         return final_boxes, final_scores
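
To summarize the new step in filter_proposals: objectness logits are converted to probabilities once per batch, and each image's proposals are filtered with >= before NMS, so a threshold of 0.0 leaves behaviour unchanged. A simplified standalone sketch of just that step (box clipping, small-box removal and per-level NMS are omitted):

import torch

def threshold_proposals(boxes, objectness_logits, score_thresh=0.0):
    # Convert raw objectness logits to probabilities, then keep proposals
    # whose score meets the threshold (>= preserves behaviour at 0.0).
    scores = torch.sigmoid(objectness_logits)
    keep = torch.where(scores >= score_thresh)[0]
    return boxes[keep], scores[keep]

boxes = torch.rand(8, 4)
logits = torch.randn(8)
kept_boxes, kept_scores = threshold_proposals(boxes, logits, score_thresh=0.05)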
