Skip to content

Commit 4807308

Browse files
committed
Replace type T with accumulator.
1 parent 51500c7 commit 4807308

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

test/test_ops.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -449,6 +449,17 @@ def test_autocast(self):
449 449
with torch.cuda.amp.autocast():
450 450
self.test_nms_cuda(dtype=dtype)
451 451

452+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
453+
def test_nms_cuda_float16(self):
454+
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669],
455+
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
456+
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
457+
458+
iou_thres = 0.2
459+
keep32 = ops.nms(boxes, scores, iou_thres)
460+
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
461+
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
462+
452 463

453 464
class DeformConvTester(OpTester, unittest.TestCase):
454 465
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,12 +17,13 @@ __device__ inline bool devIoU(
17 17
T const* const a,
18 18
T const* const b,
19 19
const float threshold) {
20-
T left = max(a[0], b[0]), right = min(a[2], b[2]);
21-
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
22-
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
23-
T interS = width * height;
24-
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
25-
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
20+
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
21+
acc_T left = max(a[0], b[0]), right = min(a[2], b[2]);
22+
acc_T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
23+
acc_T width = max(right - left, (acc_T)0), height = max(bottom - top, (acc_T)0);
24+
acc_T interS = width * height;
25+
acc_T Sa = (a[2] - a[0]) * (a[3] - a[1]);
26+
acc_T Sb = (b[2] - b[0]) * (b[3] - b[1]);
26 27
return (interS / (Sa + Sb - interS)) > threshold;
27 28
}
28 29

0 commit comments

Comments
 (0)