diff --git a/test/test_ops.py b/test/test_ops.py index 2e9fac8bc42..964199edc66 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,4 +1,5 @@ from common_utils import needs_cuda, cpu_only +from _assert_utils import assert_equal import math import unittest import pytest @@ -78,7 +79,7 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwa sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs) tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 - self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol)) + torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) def _test_backward(self, device, contiguous): pool_size = 2 @@ -363,7 +364,7 @@ def make_rois(num_rois=1000): abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize()) t_scale = torch.full_like(abs_diff, fill_value=scale) - self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5)) + torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5) x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype) qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8) @@ -555,7 +556,7 @@ def test_nms_cuda_float16(self): iou_thres = 0.2 keep32 = ops.nms(boxes, scores, iou_thres) keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) - assert torch.all(torch.eq(keep32, keep16)) + assert_equal(keep32, keep16) @cpu_only def test_batched_nms_implementations(self): @@ -573,12 +574,13 @@ def test_batched_nms_implementations(self): keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold) keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) - err_msg = "The vanilla and the trick implementation yield different nms outputs." - assert torch.allclose(keep_vanilla, keep_trick), err_msg + torch.testing.assert_close( + keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs." + ) # Also make sure an empty tensor is returned if boxes is empty empty = torch.empty((0,), dtype=torch.int64) - assert torch.allclose(empty, ops.batched_nms(empty, None, None, None)) + torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) class DeformConvTester(OpTester, unittest.TestCase): @@ -690,15 +692,17 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype): bias = layer.bias.data expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) - self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol), - '\nres:\n{}\nexpected:\n{}'.format(res, expected)) + torch.testing.assert_close( + res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + ) # no modulation test res = layer(x, offset) expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) - self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol), - '\nres:\n{}\nexpected:\n{}'.format(res, expected)) + torch.testing.assert_close( + res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + ) # test for wrong sizes with self.assertRaises(RuntimeError): @@ -778,7 +782,7 @@ def test_compare_cpu_cuda_grads(self): else: self.assertTrue(init_weight.grad is not None) res_grads = init_weight.grad.to("cpu") - self.assertTrue(true_cpu_grads.allclose(res_grads)) + torch.testing.assert_close(true_cpu_grads, res_grads) @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_autocast(self): @@ -812,14 +816,14 @@ def test_frozenbatchnorm2d_eps(self): bn = torch.nn.BatchNorm2d(sample_size[1]).eval() bn.load_state_dict(state_dict) # Difference is expected to fall in an acceptable range - self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6) # Check computation for eps > 0 fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5) fbn.load_state_dict(state_dict, strict=False) bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval() bn.load_state_dict(state_dict) - self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) + torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6) def test_frozenbatchnorm2d_n_arg(self): """Ensure a warning is thrown when passing `n` kwarg @@ -860,20 +864,10 @@ def test_bbox_same(self): exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - box_same = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy") - self.assertEqual(exp_xyxy.size(), torch.Size([4, 4])) - self.assertEqual(exp_xyxy.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_same, exp_xyxy)).item() - - box_same = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh") - self.assertEqual(exp_xyxy.size(), torch.Size([4, 4])) - self.assertEqual(exp_xyxy.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_same, exp_xyxy)).item() - - box_same = ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh") - self.assertEqual(exp_xyxy.size(), torch.Size([4, 4])) - self.assertEqual(exp_xyxy.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_same, exp_xyxy)).item() + assert exp_xyxy.size() == torch.Size([4, 4]) + assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy) + assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy) + assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy) def test_bbox_xyxy_xywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. @@ -883,16 +877,13 @@ def test_bbox_xyxy_xywh(self): exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + assert exp_xywh.size() == torch.Size([4, 4]) box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") - self.assertEqual(exp_xywh.size(), torch.Size([4, 4])) - self.assertEqual(exp_xywh.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_xywh, exp_xywh)).item() + assert_equal(box_xywh, exp_xywh) # Reverse conversion box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy") - self.assertEqual(box_xyxy.size(), torch.Size([4, 4])) - self.assertEqual(box_xyxy.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_xyxy, box_tensor)).item() + assert_equal(box_xyxy, box_tensor) def test_bbox_xyxy_cxcywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. @@ -902,16 +893,13 @@ def test_bbox_xyxy_cxcywh(self): exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") - self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4])) - self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item() + assert_equal(box_cxcywh, exp_cxcywh) # Reverse conversion box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy") - self.assertEqual(box_xyxy.size(), torch.Size([4, 4])) - self.assertEqual(box_xyxy.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_xyxy, box_tensor)).item() + assert_equal(box_xyxy, box_tensor) def test_bbox_xywh_cxcywh(self): box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], @@ -921,16 +909,13 @@ def test_bbox_xywh_cxcywh(self): exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh") - self.assertEqual(exp_cxcywh.size(), torch.Size([4, 4])) - self.assertEqual(exp_cxcywh.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_cxcywh, exp_cxcywh)).item() + assert_equal(box_cxcywh, exp_cxcywh) # Reverse conversion box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh") - self.assertEqual(box_xywh.size(), torch.Size([4, 4])) - self.assertEqual(box_xywh.dtype, box_tensor.dtype) - assert torch.all(torch.eq(box_xywh, box_tensor)).item() + assert_equal(box_xywh, box_tensor) def test_bbox_invalid(self): box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], @@ -951,19 +936,18 @@ def test_bbox_convert_jit(self): box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh') - self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE) + torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh') - self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE) + torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) class BoxAreaTester(unittest.TestCase): def test_box_area(self): def area_check(box, expected, tolerance=1e-4): out = ops.box_area(box) - assert out.size() == expected.size() - assert ((out - expected).abs().max() < tolerance).item() + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) # Check for int boxes for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: @@ -991,8 +975,7 @@ class BoxIouTester(unittest.TestCase): def test_iou(self): def iou_check(box, expected, tolerance=1e-4): out = ops.box_iou(box, box) - assert out.size() == expected.size() - assert ((out - expected).abs().max() < tolerance).item() + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) # Check for int boxes for dtype in [torch.int16, torch.int32, torch.int64]: @@ -1013,8 +996,7 @@ class GenBoxIouTester(unittest.TestCase): def test_gen_iou(self): def gen_iou_check(box, expected, tolerance=1e-4): out = ops.generalized_box_iou(box, box) - assert out.size() == expected.size() - assert ((out - expected).abs().max() < tolerance).item() + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) # Check for int boxes for dtype in [torch.int16, torch.int32, torch.int64]: