Skip to content

Commit 673fda8

Browse files
committed
Replace type T with its accumulator type (at::acc_type) for the area and intersection computation in devIoU.
1 parent 51500c7 commit 673fda8

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

test/test_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,17 @@ def test_autocast(self):
449449
with torch.cuda.amp.autocast():
450450
self.test_nms_cuda(dtype=dtype)
451451

452+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
453+
def test_nms_cuda_float16(self):
454+
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669],
455+
[279.2440, 197.9812, 1189.4746, 849.2019]]).cuda()
456+
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
457+
458+
iou_thres = 0.2
459+
keep32 = ops.nms(boxes, scores, iou_thres)
460+
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
461+
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
462+
452463

453464
class DeformConvTester(OpTester, unittest.TestCase):
454465
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):

torchvision/csrc/ops/cuda/nms_kernel.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/ATen.h>
2+
#include <ATen/AccumulateType.h>
23
#include <ATen/cuda/CUDAContext.h>
34
#include <c10/cuda/CUDAGuard.h>
45
#include <torch/library.h>
@@ -20,9 +21,10 @@ __device__ inline bool devIoU(
2021
T left = max(a[0], b[0]), right = min(a[2], b[2]);
2122
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
2223
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
23-
T interS = width * height;
24-
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
25-
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
24+
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
25+
acc_T interS = (acc_T)width * height;
26+
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
27+
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
2628
return (interS / (Sa + Sb - interS)) > threshold;
2729
}
2830

0 commit comments

Comments
 (0)