Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):


def cpu_and_gpu():
# TODO: make this properly handle CircleCI
import pytest # noqa

# ignore CPU tests in RE as they're already covered by another contbuild
Expand All @@ -430,6 +431,7 @@ def cpu_and_gpu():


def needs_cuda(test_func):
# TODO: make this properly handle CircleCI
import pytest # noqa

if IN_FBCODE and not IN_RE_WORKER:
Expand All @@ -441,3 +443,14 @@ def needs_cuda(test_func):
return test_func
else:
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)


def cpu_only(test_func):
    """Decorator for tests that are meant to run on CPU only.

    When ``IN_RE_WORKER`` is truthy the test is tagged with the
    ``dont_collect`` pytest mark; otherwise ``test_func`` is returned
    unchanged.

    NOTE(review): the ``dont_collect`` mark is presumably honoured by a
    collection hook defined elsewhere (e.g. a conftest) -- confirm.
    """
    # TODO: make this properly handle CircleCI
    import pytest  # noqa

    if not IN_RE_WORKER:
        # Outside of RE workers the test runs as-is.
        return test_func

    # The assumption is that all RE workers have GPUs.
    return pytest.mark.dont_collect(test_func)
114 changes: 62 additions & 52 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from common_utils import set_rng_seed
from common_utils import needs_cuda, cpu_only
import math
import unittest
import pytest

import numpy as np

Expand Down Expand Up @@ -437,8 +438,8 @@ def test_msroialign_repr(self):
self.assertEqual(t.__repr__(), expected_string)


class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold):
class TestNMS:
def _reference_nms(self, boxes, scores, iou_threshold):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
Expand Down Expand Up @@ -478,65 +479,73 @@ def _create_tensors_with_iou(self, N, iou_thresh):
scores = torch.rand(N)
return boxes, scores

def test_nms(self):
@cpu_only
@pytest.mark.parametrize("iou", (.2, .5, .8))
def test_nms_ref(self, iou):
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self.reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))
self.assertRaises(RuntimeError, ops.nms, torch.rand(4), torch.rand(3), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 5), torch.rand(3), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)

def test_qnms(self):
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
assert torch.allclose(keep, keep_ref), err_msg.format(iou)

@cpu_only
def test_nms_input_errors(self):
    """Check that ``ops.nms`` raises ``RuntimeError`` on mismatched inputs.

    Each entry pairs a ``boxes`` shape with a ``scores`` shape that the
    op is expected to reject.
    NOTE(review): presumably nms wants boxes of shape (N, 4) and scores
    of shape (N,) -- confirm against the torchvision documentation.
    """
    invalid_shape_pairs = (
        ((4,), (3,)),
        ((3, 5), (3,)),
        ((3, 4), (3, 2)),
        ((3, 4), (4,)),
    )
    for boxes_shape, scores_shape in invalid_shape_pairs:
        with pytest.raises(RuntimeError):
            ops.nms(torch.rand(*boxes_shape), torch.rand(*scores_shape), 0.5)

@cpu_only
@pytest.mark.parametrize("iou", (.2, .5, .8))
@pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
def test_qnms(self, iou, scale, zero_point):
# Note: we compare qnms vs nms instead of qnms vs reference implementation.
# This is because with the int conversion, the trick used in _create_tensors_with_iou
# doesn't really work (in fact, nms vs reference implementation will also fail with ints)
err_msg = 'NMS and QNMS give different results for IoU={}'
for iou in [0.2, 0.5, 0.8]:
for scale, zero_point in ((1, 0), (2, 50), (3, 10)):
boxes, scores = self._create_tensors_with_iou(1000, iou)
scores *= 100 # otherwise most scores would be 0 or 1 after int conversion
boxes, scores = self._create_tensors_with_iou(1000, iou)
scores *= 100 # otherwise most scores would be 0 or 1 after int conversion

qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point,
dtype=torch.quint8)
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point,
dtype=torch.quint8)
qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)

boxes = qboxes.dequantize()
scores = qscores.dequantize()
boxes = qboxes.dequantize()
scores = qscores.dequantize()

keep = ops.nms(boxes, scores, iou)
qkeep = ops.nms(qboxes, qscores, iou)
keep = ops.nms(boxes, scores, iou)
qkeep = ops.nms(qboxes, qscores, iou)

self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou))
assert torch.allclose(qkeep, keep), err_msg.format(iou)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda(self, dtype=torch.float64):
@needs_cuda
@pytest.mark.parametrize("iou", (.2, .5, .8))
def test_nms_cuda(self, iou, dtype=torch.float64):
tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
self.assertTrue(is_eq, err_msg.format(iou))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_autocast(self):
for dtype in (torch.float, torch.half):
with torch.cuda.amp.autocast():
self.test_nms_cuda(dtype=dtype)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_cuda
@pytest.mark.parametrize("iou", (.2, .5, .8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype):
    """Re-run the CPU-vs-CUDA NMS comparison inside a CUDA autocast context.

    ``iou`` and ``dtype`` are forwarded unchanged to ``test_nms_cuda``.
    """
    autocast_ctx = torch.cuda.amp.autocast()
    with autocast_ctx:
        self.test_nms_cuda(iou=iou, dtype=dtype)

@needs_cuda
def test_nms_cuda_float16(self):
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
Expand All @@ -546,8 +555,9 @@ def test_nms_cuda_float16(self):
iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
assert torch.all(torch.eq(keep32, keep16))

@cpu_only
def test_batched_nms_implementations(self):
"""Make sure that both implementations of batched_nms yield identical results"""

Expand All @@ -564,11 +574,11 @@ def test_batched_nms_implementations(self):
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

err_msg = "The vanilla and the trick implementation yield different nms outputs."
self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
assert torch.allclose(keep_vanilla, keep_trick), err_msg

# Also make sure an empty tensor is returned if boxes is empty
empty = torch.empty((0,), dtype=torch.int64)
self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))
assert torch.allclose(empty, ops.batched_nms(empty, None, None, None))


class DeformConvTester(OpTester, unittest.TestCase):
Expand Down