Skip to content

Fix NMS and IoU overflows for fp16 #3383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 73 additions & 34 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,18 @@ def test_autocast(self):
with torch.cuda.amp.autocast():
self.test_nms_cuda(dtype=dtype)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()

iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))


class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
Expand Down Expand Up @@ -829,48 +841,75 @@ def test_bbox_convert_jit(self):

class BoxAreaTester(unittest.TestCase):
def test_box_area(self):
# A bounding box of area 10000 and a degenerate case
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
expected = torch.tensor([10000, 0])
calc_area = ops.box_area(box_tensor)
assert calc_area.size() == torch.Size([2])
assert calc_area.dtype == box_tensor.dtype
assert torch.all(torch.eq(calc_area, expected)).item() is True
def area_check(box, expected, tolerance=1e-4):
out = ops.box_area(box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
expected = torch.tensor([10000, 0])
area_check(box_tensor, expected)

# Check for float32 and float64 boxes
for dtype in [torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
area_check(box_tensor, expected, tolerance=0.05)

# Check for float16 box
box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5],
[285.25, 188.75, 1192.0, 851.0],
[279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16)
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
area_check(box_tensor, expected)


class BoxIouTester(unittest.TestCase):
def test_iou(self):
# Boxes to test Iou
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected IoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])

out = ops.box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True
def iou_check(box, expected, tolerance=1e-4):
out = ops.box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
iou_check(box, expected)

# Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)


class GenBoxIouTester(unittest.TestCase):
def test_gen_iou(self):
# Test Generalized IoU
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)

# Expected gIoU matrix for these boxes
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611],
[-0.7778, -0.8611, 1.0]])

out = ops.generalized_box_iou(boxes1, boxes2)

# Check if all elements of tensor are as expected.
assert out.size() == torch.Size([3, 3])
tolerance = 1e-4
assert ((out - expected).abs().max() < tolerance).item() is True
def gen_iou_check(box, expected, tolerance=1e-4):
out = ops.generalized_box_iou(box, box)
assert out.size() == expected.size()
assert ((out - expected).abs().max() < tolerance).item()

# Check for int boxes
for dtype in [torch.int16, torch.int32, torch.int64]:
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
gen_iou_check(box, expected)

# Check for float boxes
for dtype in [torch.float16, torch.float32, torch.float64]:
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)


if __name__ == '__main__':
Expand Down
8 changes: 5 additions & 3 deletions torchvision/csrc/ops/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
Expand All @@ -20,9 +21,10 @@ __device__ inline bool devIoU(
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
acc_T interS = (acc_T)width * height;
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold;
}

Expand Down
13 changes: 11 additions & 2 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes


def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
Expand All @@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor:
Returns:
area (Tensor[N]): area for each box
"""
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


Expand All @@ -194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]

wh = (rb - lt).clamp(min=0) # [N,M,2]
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter
Expand Down Expand Up @@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

whi = (rbi - lti).clamp(min=0) # [N,M,2]
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
areai = whi[:, :, 0] * whi[:, :, 1]

return iou - (areai - union) / areai