
Commit 892a2dd

fmassa authored and facebook-github-bot committed
[fbsync] [OPS, IMP] New batched_nms implementation (#3426)
Summary:
* New batched_nms implementation
* flake8
* Hopefully fix torchscript tests
* Use where instead of nonzero
* Use same threshold (4k) for CPU and GPU
* Remove use of argsort
* Use views again
* Remove print
* Trying stuff, I don't know what's going on
* Previous passed onnx checks so the error isn't in the _vanilla func. Trying to return vanilla now
* Add tracing decorators
* Cleanup
* WIP
* Ignore new path with ONNX
* Use vanilla if tracing
* Remove script_if_tracing decorator as it was conflicting with _is_tracing
* flake8
* Improve coverage

Reviewed By: NicolasHug, cpuhrsch

Differential Revision: D26945728

fbshipit-source-id: 118a41e03da2939a726e5bd18f5f77b7c0ce6339

Co-authored-by: Francisco Massa <[email protected]>
1 parent 691e6d8 commit 892a2dd
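
For context, the public op this commit touches is torchvision.ops.batched_nms, which performs NMS independently per class by taking a class-index tensor alongside the boxes and scores. A minimal usage sketch (not part of the commit; the call signature is unchanged here, only the internal dispatch shown in the diff below is new):

    import torch
    from torchvision.ops import batched_nms

    # 100 boxes in (x1, y1, x2, y2) format, with random scores and 4 class labels
    boxes = torch.cat((torch.rand(100, 2), torch.rand(100, 2) + 10), dim=1)
    scores = torch.rand(100)
    idxs = torch.randint(0, 4, size=(100,))

    # Small inputs (boxes.numel() <= 4_000, i.e. up to 1_000 boxes) take the
    # coordinate-trick path; larger inputs dispatch to the per-class ("vanilla") loop.
    keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
    print(keep)  # indices of the kept boxes, sorted by decreasing score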

File tree: 3 files changed, +72 -17 lines changed


test/test_onnx.py

Lines changed: 8 additions & 6 deletions
@@ -1,3 +1,11 @@
+# onnxruntime requires python 3.5 or above
+try:
+    # This import should be before that of torch
+    # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840
+    import onnxruntime
+except ImportError:
+    onnxruntime = None
+
 from common_utils import set_rng_seed
 import io
 import torch
@@ -13,12 +21,6 @@
 
 from collections import OrderedDict
 
-# onnxruntime requires python 3.5 or above
-try:
-    import onnxruntime
-except ImportError:
-    onnxruntime = None
-
 import unittest
 from torchvision.ops._register_onnx_ops import _onnx_opset_version
 

test/test_ops.py

Lines changed: 22 additions & 0 deletions
@@ -461,6 +461,28 @@ def test_nms_cuda_float16(self):
         keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
         self.assertTrue(torch.all(torch.eq(keep32, keep16)))
 
+    def test_batched_nms_implementations(self):
+        """Make sure that both implementations of batched_nms yield identical results"""
+
+        num_boxes = 1000
+        iou_threshold = .9
+
+        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
+        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
+        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2
+
+        scores = torch.rand(num_boxes)
+        idxs = torch.randint(0, 4, size=(num_boxes,))
+        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+
+        err_msg = "The vanilla and the trick implementation yield different nms outputs."
+        self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
+
+        # Also make sure an empty tensor is returned if boxes is empty
+        empty = torch.empty((0,), dtype=torch.int64)
+        self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))
+
 
 class DeformConvTester(OpTester, unittest.TestCase):
     def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
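
The test above only checks that the two code paths agree. The threshold that decides between them came from the benchmarks linked in the diff below (pytorch/vision#1311); a rough sketch of how such a comparison could be reproduced locally, assuming the private helpers _batched_nms_vanilla and _batched_nms_coordinate_trick added by this commit are importable from torchvision.ops.boxes:

    import torch
    import torchvision.ops.boxes as box_ops
    from torch.utils.benchmark import Timer

    num_boxes = 10_000  # well above the dispatch threshold of 4_000 elements
    boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
    scores = torch.rand(num_boxes)
    idxs = torch.randint(0, 4, size=(num_boxes,))

    for name, fn in (("vanilla", box_ops._batched_nms_vanilla),
                     ("coordinate trick", box_ops._batched_nms_coordinate_trick)):
        timer = Timer(stmt="fn(boxes, scores, idxs, 0.9)",
                      globals={"fn": fn, "boxes": boxes, "scores": scores, "idxs": idxs})
        print(name, timer.timeit(10))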

torchvision/ops/boxes.py

Lines changed: 42 additions & 11 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from typing import Tuple
+from typing import List, Tuple
 from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
 import torchvision
 from torchvision.extension import _assert_has_ops
@@ -36,7 +36,6 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
     return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
 
 
-@torch.jit._script_if_tracing
 def batched_nms(
     boxes: Tensor,
     scores: Tensor,
@@ -62,18 +61,50 @@ def batched_nms(
         the elements that have been kept by NMS, sorted
         in decreasing order of scores
     """
-    if boxes.numel() == 0:
-        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
-    # strategy: in order to perform NMS independently per class.
+    # Benchmarks that drove the following thresholds are at
+    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
+    # Ideally for GPU we'd use a higher threshold
+    if boxes.numel() > 4_000 and not torchvision._is_tracing():
+        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+    else:
+        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+
+
+@torch.jit._script_if_tracing
+def _batched_nms_coordinate_trick(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    iou_threshold: float,
+) -> Tensor:
+    # strategy: in order to perform NMS independently per class,
     # we add an offset to all the boxes. The offset is dependent
     # only on the class idx, and is large enough so that boxes
     # from different classes do not overlap
-    else:
-        max_coordinate = boxes.max()
-        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
-        boxes_for_nms = boxes + offsets[:, None]
-        keep = nms(boxes_for_nms, scores, iou_threshold)
-        return keep
+    if boxes.numel() == 0:
+        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+    max_coordinate = boxes.max()
+    offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+    boxes_for_nms = boxes + offsets[:, None]
+    keep = nms(boxes_for_nms, scores, iou_threshold)
+    return keep
+
+
+@torch.jit._script_if_tracing
+def _batched_nms_vanilla(
+    boxes: Tensor,
+    scores: Tensor,
+    idxs: Tensor,
+    iou_threshold: float,
+) -> Tensor:
+    # Based on Detectron2 implementation, just manually call nms() on each class independently
+    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
+    for class_id in torch.unique(idxs):
+        curr_indices = torch.where(idxs == class_id)[0]
+        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
+        keep_mask[curr_indices[curr_keep_indices]] = True
+    keep_indices = torch.where(keep_mask)[0]
+    return keep_indices[scores[keep_indices].sort(descending=True)[1]]
 
 
 def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
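
A small standalone sketch (not from the commit) of why the coordinate trick is equivalent to per-class NMS: each box is shifted by a class-dependent offset larger than any coordinate in the input, so boxes from different classes can never overlap, and a single nms() call only suppresses boxes within the same class.

    import torch
    from torchvision.ops import nms

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.]])  # heavily overlapping boxes...
    scores = torch.tensor([0.9, 0.8])
    idxs = torch.tensor([0, 1])                 # ...but from different classes

    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    keep = nms(boxes + offsets[:, None], scores, iou_threshold=0.5)
    # Without the offsets, box 1 (IoU ~0.68 with box 0) would be suppressed;
    # with them, both boxes survive because the cross-class overlap is gone.
    print(keep)  # tensor([0, 1])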
