Skip to content

Improve speed/accuracy of *RCNN models by introducing a score threshold on RPN #3205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ def main(args):
collate_fn=utils.collate_fn)

print("Creating model")
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
pretrained=args.pretrained)
kwargs = {}
if "rcnn" in args.model:
kwargs["rpn_score_thresh"] = 0.0
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
**kwargs)
model.to(device)

model_without_ddp = model
Expand Down
2 changes: 1 addition & 1 deletion test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_targets_to_anchors(self):
rpn_anchor_generator, rpn_head,
0.5, 0.3,
256, 0.5,
2000, 2000, 0.7)
2000, 2000, 0.7, 0.05)

labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)

Expand Down
7 changes: 6 additions & 1 deletion test/test_onnx.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from common_utils import set_rng_seed
import io
import torch
from torchvision import ops
Expand Down Expand Up @@ -197,12 +198,14 @@ def _init_test_rpn(self):
rpn_pre_nms_top_n = dict(training=2000, testing=1000)
rpn_post_nms_top_n = dict(training=2000, testing=1000)
rpn_nms_thresh = 0.7
rpn_score_thresh = 0.0

rpn = RegionProposalNetwork(
rpn_anchor_generator, rpn_head,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
score_thresh=rpn_score_thresh)
return rpn

def _init_test_roi_heads_faster_rcnn(self):
Expand Down Expand Up @@ -255,6 +258,8 @@ def get_features(self, images):
return features

def test_rpn(self):
set_rng_seed(0)

class RPNModule(torch.nn.Module):
def __init__(self_module):
super(RPNModule, self_module).__init__()
Expand Down
8 changes: 6 additions & 2 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,14 @@ def decode(self, rel_codes, boxes):
box_sum = 0
for val in boxes_per_image:
box_sum += val
if box_sum > 0:
rel_codes = rel_codes.reshape(box_sum, -1)
pred_boxes = self.decode_single(
rel_codes.reshape(box_sum, -1), concat_boxes
rel_codes, concat_boxes
)
return pred_boxes.reshape(box_sum, -1, 4)
if box_sum > 0:
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes

def decode_single(self, rel_codes, boxes):
"""
Expand Down
6 changes: 5 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class FasterRCNN(GeneralizedRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down Expand Up @@ -153,6 +155,7 @@ def __init__(self, backbone, num_classes=None,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
Expand Down Expand Up @@ -197,7 +200,8 @@ def __init__(self, backbone, num_classes=None,
rpn_anchor_generator, rpn_head,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
score_thresh=rpn_score_thresh)

if box_roi_pool is None:
box_roi_pool = MultiScaleRoIAlign(
Expand Down
4 changes: 4 additions & 0 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class KeypointRCNN(FasterRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down Expand Up @@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
Expand Down Expand Up @@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_score_thresh,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
Expand Down
4 changes: 4 additions & 0 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class MaskRCNN(FasterRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down Expand Up @@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
Expand Down Expand Up @@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_score_thresh,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
Expand Down
17 changes: 15 additions & 2 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self,
fg_iou_thresh, bg_iou_thresh,
batch_size_per_image, positive_fraction,
#
pre_nms_top_n, post_nms_top_n, nms_thresh):
pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator
self.head = head
Expand All @@ -163,6 +163,7 @@ def __init__(self,
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self.nms_thresh = nms_thresh
self.score_thresh = score_thresh
self.min_size = 1e-3

def pre_nms_top_n(self):
Expand Down Expand Up @@ -251,17 +252,29 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
levels = levels[batch_idx, top_n_idx]
proposals = proposals[batch_idx, top_n_idx]

objectness_prob = F.sigmoid(objectness)

final_boxes = []
final_scores = []
for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

# remove small boxes
keep = box_ops.remove_small_boxes(boxes, self.min_size)
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# remove low scoring boxes
# use >= for Backwards compatibility
keep = torch.where(scores >= self.score_thresh)[0]
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# non-maximum suppression, independently done per level
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

# keep only topk scoring predictions
keep = keep[:self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep]

final_boxes.append(boxes)
final_scores.append(scores)
return final_boxes, final_scores
Expand Down