diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 0676f7fab9b..771d3d591ab 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -501,6 +501,10 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_ inds = torch.nonzero(scores > self.score_thresh).squeeze(1) boxes, scores, labels = boxes[inds], scores[inds], labels[inds] + # remove empty boxes + keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) + boxes, scores, labels = boxes[keep], scores[keep], labels[keep] + # non-maximum suppression, independently done per class keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) # keep only topk scoring predictions