Skip to content

Commit 221d615

Browse files
authored
Merge branch 'master' into densenet-typing
2 parents 9e33fc8 + f04e9cb commit 221d615

File tree

3 files changed

+89
-39
lines changed

3 files changed

+89
-39
lines changed

test/test_ops.py

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,18 @@ def test_autocast(self):
449449
with torch.cuda.amp.autocast():
450450
self.test_nms_cuda(dtype=dtype)
451451

452+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
453+
def test_nms_cuda_float16(self):
454+
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
455+
[285.1472, 188.7374, 1192.4984, 851.0669],
456+
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
457+
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
458+
459+
iou_thres = 0.2
460+
keep32 = ops.nms(boxes, scores, iou_thres)
461+
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
462+
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
463+
452464

453465
class DeformConvTester(OpTester, unittest.TestCase):
454466
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
@@ -829,48 +841,75 @@ def test_bbox_convert_jit(self):
829841

830842
class BoxAreaTester(unittest.TestCase):
831843
def test_box_area(self):
832-
# A bounding box of area 10000 and a degenerate case
833-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
834-
expected = torch.tensor([10000, 0])
835-
calc_area = ops.box_area(box_tensor)
836-
assert calc_area.size() == torch.Size([2])
837-
assert calc_area.dtype == box_tensor.dtype
838-
assert torch.all(torch.eq(calc_area, expected)).item() is True
844+
def area_check(box, expected, tolerance=1e-4):
845+
out = ops.box_area(box)
846+
assert out.size() == expected.size()
847+
assert ((out - expected).abs().max() < tolerance).item()
848+
849+
# Check for int boxes
850+
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
851+
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
852+
expected = torch.tensor([10000, 0])
853+
area_check(box_tensor, expected)
854+
855+
# Check for float32 and float64 boxes
856+
for dtype in [torch.float32, torch.float64]:
857+
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
858+
[285.1472, 188.7374, 1192.4984, 851.0669],
859+
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
860+
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
861+
area_check(box_tensor, expected, tolerance=0.05)
862+
863+
# Check for float16 box
864+
box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5],
865+
[285.25, 188.75, 1192.0, 851.0],
866+
[279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16)
867+
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
868+
area_check(box_tensor, expected)
839869

840870

841871
class BoxIouTester(unittest.TestCase):
842872
def test_iou(self):
843-
# Boxes to test Iou
844-
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
845-
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
846-
847-
# Expected IoU matrix for these boxes
848-
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
849-
850-
out = ops.box_iou(boxes1, boxes2)
851-
852-
# Check if all elements of tensor are as expected.
853-
assert out.size() == torch.Size([3, 3])
854-
tolerance = 1e-4
855-
assert ((out - expected).abs().max() < tolerance).item() is True
873+
def iou_check(box, expected, tolerance=1e-4):
874+
out = ops.box_iou(box, box)
875+
assert out.size() == expected.size()
876+
assert ((out - expected).abs().max() < tolerance).item()
877+
878+
# Check for int boxes
879+
for dtype in [torch.int16, torch.int32, torch.int64]:
880+
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
881+
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
882+
iou_check(box, expected)
883+
884+
# Check for float boxes
885+
for dtype in [torch.float16, torch.float32, torch.float64]:
886+
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
887+
[285.1472, 188.7374, 1192.4984, 851.0669],
888+
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
889+
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
890+
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)
856891

857892

858893
class GenBoxIouTester(unittest.TestCase):
859894
def test_gen_iou(self):
860-
# Test Generalized IoU
861-
boxes1 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
862-
boxes2 = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
863-
864-
# Expected gIoU matrix for these boxes
865-
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611],
866-
[-0.7778, -0.8611, 1.0]])
867-
868-
out = ops.generalized_box_iou(boxes1, boxes2)
869-
870-
# Check if all elements of tensor are as expected.
871-
assert out.size() == torch.Size([3, 3])
872-
tolerance = 1e-4
873-
assert ((out - expected).abs().max() < tolerance).item() is True
895+
def gen_iou_check(box, expected, tolerance=1e-4):
896+
out = ops.generalized_box_iou(box, box)
897+
assert out.size() == expected.size()
898+
assert ((out - expected).abs().max() < tolerance).item()
899+
900+
# Check for int boxes
901+
for dtype in [torch.int16, torch.int32, torch.int64]:
902+
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
903+
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
904+
gen_iou_check(box, expected)
905+
906+
# Check for float boxes
907+
for dtype in [torch.float16, torch.float32, torch.float64]:
908+
box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
909+
[285.1472, 188.7374, 1192.4984, 851.0669],
910+
[279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype)
911+
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
912+
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
874913

875914

876915
if __name__ == '__main__':

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/ATen.h>
2+
#include <ATen/AccumulateType.h>
23
#include <ATen/cuda/CUDAContext.h>
34
#include <c10/cuda/CUDAGuard.h>
45
#include <torch/library.h>
@@ -20,9 +21,10 @@ __device__ inline bool devIoU(
2021
T left = max(a[0], b[0]), right = min(a[2], b[2]);
2122
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
2223
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
23-
T interS = width * height;
24-
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
25-
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
24+
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
25+
acc_T interS = (acc_T)width * height;
26+
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
27+
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
2628
return (interS / (Sa + Sb - interS)) > threshold;
2729
}
2830

torchvision/ops/boxes.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
170170
return boxes
171171

172172

173+
def _upcast(t: Tensor) -> Tensor:
174+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
175+
if t.is_floating_point():
176+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
177+
else:
178+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
179+
180+
173181
def box_area(boxes: Tensor) -> Tensor:
174182
"""
175183
Computes the area of a set of bounding boxes, which are specified by its
@@ -182,6 +190,7 @@ def box_area(boxes: Tensor) -> Tensor:
182190
Returns:
183191
area (Tensor[N]): area for each box
184192
"""
193+
boxes = _upcast(boxes)
185194
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
186195

187196

@@ -194,7 +203,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
194203
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
195204
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
196205

197-
wh = (rb - lt).clamp(min=0) # [N,M,2]
206+
wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
198207
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
199208

200209
union = area1[:, None] + area2 - inter
@@ -247,7 +256,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
247256
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
248257
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
249258

250-
whi = (rbi - lti).clamp(min=0) # [N,M,2]
259+
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
251260
areai = whi[:, :, 0] * whi[:, :, 1]
252261

253262
return iou - (areai - union) / areai

0 commit comments

Comments
 (0)