From 673fda83b0d7b8574d1a135fa861f441d3663884 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 11 Feb 2021 18:46:55 +0000 Subject: [PATCH 1/2] Replace type T with accumulator. --- test/test_ops.py | 11 +++++++++++ torchvision/csrc/ops/cuda/nms_kernel.cu | 8 +++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 59e8ccc2c65..34039da1bf3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -449,6 +449,17 @@ def test_autocast(self): with torch.cuda.amp.autocast(): self.test_nms_cuda(dtype=dtype) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_nms_cuda_float16(self): + boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() + scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() + + iou_thres = 0.2 + keep32 = ops.nms(boxes, scores, iou_thres) + keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) + self.assertTrue(torch.all(torch.eq(keep32, keep16))) + class DeformConvTester(OpTester, unittest.TestCase): def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu index 10f995f4e99..ce1ed4ae1f5 100644 --- a/torchvision/csrc/ops/cuda/nms_kernel.cu +++ b/torchvision/csrc/ops/cuda/nms_kernel.cu @@ -1,4 +1,5 @@ #include <ATen/ATen.h> +#include <ATen/AccumulateType.h> #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include <torch/library.h> @@ -20,9 +21,10 @@ __device__ inline bool devIoU( T left = max(a[0], b[0]), right = min(a[2], b[2]); T top = max(a[1], b[1]), bottom = min(a[3], b[3]); T width = max(right - left, (T)0), height = max(bottom - top, (T)0); - T interS = width * height; - T Sa = (a[2] - a[0]) * (a[3] - a[1]); - T Sb = (b[2] - b[0]) * (b[3] - b[1]); + using acc_T = at::acc_type<T, /*is_cuda=*/true>; + acc_T interS = (acc_T)width * height; + acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); + acc_T Sb = 
((acc_T)b[2] - b[0]) * (b[3] - b[1]); return (interS / (Sa + Sb - interS)) > threshold; } From b22ef50f44d9ec6f5d8129c0b9a88060b11ac1d2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 11 Feb 2021 20:31:26 +0000 Subject: [PATCH 2/2] Upcast tensors of box ops to avoid overflow in multiplications. --- test/test_ops.py | 100 +++++++++++++++++++++++++-------------- torchvision/ops/boxes.py | 13 ++++- 2 files changed, 75 insertions(+), 38 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 34039da1bf3..244960cebcc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -451,8 +451,9 @@ def test_autocast(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_nms_cuda_float16(self): - boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() + boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() iou_thres = 0.2 @@ -840,48 +841,75 @@ def test_bbox_convert_jit(self): class BoxAreaTester(unittest.TestCase): def test_box_area(self): - # A bounding box of area 10000 and a degenerate case - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) - expected = torch.tensor([10000, 0]) - calc_area = ops.box_area(box_tensor) - assert calc_area.size() == torch.Size([2]) - assert calc_area.dtype == box_tensor.dtype - assert torch.all(torch.eq(calc_area, expected)).item() is True + def area_check(box, expected, tolerance=1e-4): + out = ops.box_area(box) + assert out.size() == expected.size() + assert ((out - expected).abs().max() < tolerance).item() + + # Check for int boxes + for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) + expected = 
torch.tensor([10000, 0]) + area_check(box_tensor, expected) + + # Check for float32 and float64 boxes + for dtype in [torch.float32, torch.float64]: + box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) + area_check(box_tensor, expected, tolerance=0.05) + + # Check for float16 box + box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5], + [285.25, 188.75, 1192.0, 851.0], + [279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16) + expected = torch.tensor([605113.875, 600495.1875, 592247.25]) + area_check(box_tensor, expected) class BoxIouTester(unittest.TestCase): def test_iou(self): - # Boxes to test Iou - boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - - # Expected IoU matrix for these boxes - expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) - - out = ops.box_iou(boxes1, boxes2) - - # Check if all elements of tensor are as expected. 
- assert out.size() == torch.Size([3, 3]) - tolerance = 1e-4 - assert ((out - expected).abs().max() < tolerance).item() is True + def iou_check(box, expected, tolerance=1e-4): + out = ops.box_iou(box, box) + assert out.size() == expected.size() + assert ((out - expected).abs().max() < tolerance).item() + + # Check for int boxes + for dtype in [torch.int16, torch.int32, torch.int64]: + box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) + expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) + iou_check(box, expected) + + # Check for float boxes + for dtype in [torch.float16, torch.float32, torch.float64]: + box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) + iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) class GenBoxIouTester(unittest.TestCase): def test_gen_iou(self): - # Test Generalized IoU - boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) - - # Expected gIoU matrix for these boxes - expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], - [-0.7778, -0.8611, 1.0]]) - - out = ops.generalized_box_iou(boxes1, boxes2) - - # Check if all elements of tensor are as expected. 
- assert out.size() == torch.Size([3, 3]) - tolerance = 1e-4 - assert ((out - expected).abs().max() < tolerance).item() is True + def gen_iou_check(box, expected, tolerance=1e-4): + out = ops.generalized_box_iou(box, box) + assert out.size() == expected.size() + assert ((out - expected).abs().max() < tolerance).item() + + # Check for int boxes + for dtype in [torch.int16, torch.int32, torch.int64]: + box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) + expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) + gen_iou_check(box, expected) + + # Check for float boxes + for dtype in [torch.float16, torch.float32, torch.float64]: + box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) + gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) if __name__ == '__main__': diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 1a9657f41ab..2cb1be93168 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: return boxes +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + def box_area(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by its @@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor: Returns: area (Tensor[N]): area for each box """ + boxes = _upcast(boxes) return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) @@ 
-194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]: lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - wh = (rb - lt).clamp(min=0) # [N,M,2] + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] union = area1[:, None] + area2 - inter @@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - whi = (rbi - lti).clamp(min=0) # [N,M,2] + whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] areai = whi[:, :, 0] * whi[:, :, 1] return iou - (areai - union) / areai