Skip to content

Commit 8fb76e8

Browse files
authored
Port NMS tests to use pytest and introduce the cpu_only decorator (#3852)
1 parent e68d6b0 commit 8fb76e8

File tree

2 files changed

+75
-52
lines changed

2 files changed

+75
-52
lines changed

test/common_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
409409

410410

411411
def cpu_and_gpu():
412+
# TODO: make this properly handle CircleCI
412413
import pytest # noqa
413414

414415
# ignore CPU tests in RE as they're already covered by another contbuild
@@ -430,6 +431,7 @@ def cpu_and_gpu():
430431

431432

432433
def needs_cuda(test_func):
434+
# TODO: make this properly handle CircleCI
433435
import pytest # noqa
434436

435437
if IN_FBCODE and not IN_RE_WORKER:
@@ -441,3 +443,14 @@ def needs_cuda(test_func):
441443
return test_func
442444
else:
443445
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)
446+
447+
448+
def cpu_only(test_func):
449+
# TODO: make this properly handle CircleCI
450+
import pytest # noqa
451+
452+
if IN_RE_WORKER:
453+
# The assumption is that all RE workers have GPUs.
454+
return pytest.mark.dont_collect(test_func)
455+
else:
456+
return test_func

test/test_ops.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from common_utils import set_rng_seed
1+
from common_utils import needs_cuda, cpu_only
22
import math
33
import unittest
4+
import pytest
45

56
import numpy as np
67

@@ -437,8 +438,8 @@ def test_msroialign_repr(self):
437438
self.assertEqual(t.__repr__(), expected_string)
438439

439440

440-
class NMSTester(unittest.TestCase):
441-
def reference_nms(self, boxes, scores, iou_threshold):
441+
class TestNMS:
442+
def _reference_nms(self, boxes, scores, iou_threshold):
442443
"""
443444
Args:
444445
box_scores (N, 5): boxes in corner-form and probabilities.
@@ -478,65 +479,73 @@ def _create_tensors_with_iou(self, N, iou_thresh):
478479
scores = torch.rand(N)
479480
return boxes, scores
480481

481-
def test_nms(self):
482+
@cpu_only
483+
@pytest.mark.parametrize("iou", (.2, .5, .8))
484+
def test_nms_ref(self, iou):
482485
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
483-
for iou in [0.2, 0.5, 0.8]:
484-
boxes, scores = self._create_tensors_with_iou(1000, iou)
485-
keep_ref = self.reference_nms(boxes, scores, iou)
486-
keep = ops.nms(boxes, scores, iou)
487-
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))
488-
self.assertRaises(RuntimeError, ops.nms, torch.rand(4), torch.rand(3), 0.5)
489-
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 5), torch.rand(3), 0.5)
490-
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5)
491-
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)
492-
493-
def test_qnms(self):
486+
boxes, scores = self._create_tensors_with_iou(1000, iou)
487+
keep_ref = self._reference_nms(boxes, scores, iou)
488+
keep = ops.nms(boxes, scores, iou)
489+
assert torch.allclose(keep, keep_ref), err_msg.format(iou)
490+
491+
@cpu_only
492+
def test_nms_input_errors(self):
493+
with pytest.raises(RuntimeError):
494+
ops.nms(torch.rand(4), torch.rand(3), 0.5)
495+
with pytest.raises(RuntimeError):
496+
ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
497+
with pytest.raises(RuntimeError):
498+
ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
499+
with pytest.raises(RuntimeError):
500+
ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)
501+
502+
@cpu_only
503+
@pytest.mark.parametrize("iou", (.2, .5, .8))
504+
@pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
505+
def test_qnms(self, iou, scale, zero_point):
494506
# Note: we compare qnms vs nms instead of qnms vs reference implementation.
495507
# This is because with the int convertion, the trick used in _create_tensors_with_iou
496508
# doesn't really work (in fact, nms vs reference implem will also fail with ints)
497509
err_msg = 'NMS and QNMS give different results for IoU={}'
498-
for iou in [0.2, 0.5, 0.8]:
499-
for scale, zero_point in ((1, 0), (2, 50), (3, 10)):
500-
boxes, scores = self._create_tensors_with_iou(1000, iou)
501-
scores *= 100 # otherwise most scores would be 0 or 1 after int convertion
510+
boxes, scores = self._create_tensors_with_iou(1000, iou)
511+
scores *= 100 # otherwise most scores would be 0 or 1 after int convertion
502512

503-
qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point,
504-
dtype=torch.quint8)
505-
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point,
506-
dtype=torch.quint8)
513+
qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
514+
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)
507515

508-
boxes = qboxes.dequantize()
509-
scores = qscores.dequantize()
516+
boxes = qboxes.dequantize()
517+
scores = qscores.dequantize()
510518

511-
keep = ops.nms(boxes, scores, iou)
512-
qkeep = ops.nms(qboxes, qscores, iou)
519+
keep = ops.nms(boxes, scores, iou)
520+
qkeep = ops.nms(qboxes, qscores, iou)
513521

514-
self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou))
522+
assert torch.allclose(qkeep, keep), err_msg.format(iou)
515523

516-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
517-
def test_nms_cuda(self, dtype=torch.float64):
524+
@needs_cuda
525+
@pytest.mark.parametrize("iou", (.2, .5, .8))
526+
def test_nms_cuda(self, iou, dtype=torch.float64):
518527
tol = 1e-3 if dtype is torch.half else 1e-5
519528
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'
520529

521-
for iou in [0.2, 0.5, 0.8]:
522-
boxes, scores = self._create_tensors_with_iou(1000, iou)
523-
r_cpu = ops.nms(boxes, scores, iou)
524-
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
525-
526-
is_eq = torch.allclose(r_cpu, r_cuda.cpu())
527-
if not is_eq:
528-
# if the indices are not the same, ensure that it's because the scores
529-
# are duplicate
530-
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
531-
self.assertTrue(is_eq, err_msg.format(iou))
532-
533-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
534-
def test_autocast(self):
535-
for dtype in (torch.float, torch.half):
536-
with torch.cuda.amp.autocast():
537-
self.test_nms_cuda(dtype=dtype)
538-
539-
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
530+
boxes, scores = self._create_tensors_with_iou(1000, iou)
531+
r_cpu = ops.nms(boxes, scores, iou)
532+
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
533+
534+
is_eq = torch.allclose(r_cpu, r_cuda.cpu())
535+
if not is_eq:
536+
# if the indices are not the same, ensure that it's because the scores
537+
# are duplicate
538+
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
539+
assert is_eq, err_msg.format(iou)
540+
541+
@needs_cuda
542+
@pytest.mark.parametrize("iou", (.2, .5, .8))
543+
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
544+
def test_autocast(self, iou, dtype):
545+
with torch.cuda.amp.autocast():
546+
self.test_nms_cuda(iou=iou, dtype=dtype)
547+
548+
@needs_cuda
540549
def test_nms_cuda_float16(self):
541550
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
542551
[285.1472, 188.7374, 1192.4984, 851.0669],
@@ -546,8 +555,9 @@ def test_nms_cuda_float16(self):
546555
iou_thres = 0.2
547556
keep32 = ops.nms(boxes, scores, iou_thres)
548557
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
549-
self.assertTrue(torch.all(torch.eq(keep32, keep16)))
558+
assert torch.all(torch.eq(keep32, keep16))
550559

560+
@cpu_only
551561
def test_batched_nms_implementations(self):
552562
"""Make sure that both implementations of batched_nms yield identical results"""
553563

@@ -564,11 +574,11 @@ def test_batched_nms_implementations(self):
564574
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
565575

566576
err_msg = "The vanilla and the trick implementation yield different nms outputs."
567-
self.assertTrue(torch.allclose(keep_vanilla, keep_trick), err_msg)
577+
assert torch.allclose(keep_vanilla, keep_trick), err_msg
568578

569579
# Also make sure an empty tensor is returned if boxes is empty
570580
empty = torch.empty((0,), dtype=torch.int64)
571-
self.assertTrue(torch.allclose(empty, ops.batched_nms(empty, None, None, None)))
581+
assert torch.allclose(empty, ops.batched_nms(empty, None, None, None))
572582

573583

574584
class DeformConvTester(OpTester, unittest.TestCase):

0 commit comments

Comments
 (0)