Skip to content

Commit c307db4

Browse files
NicolasHugpmeier
andauthored
Use torch.testing.assert_close in test_detection_utils.py (#3881)
Co-authored-by: Philip Meier <[email protected]>
1 parent 86d4541 commit c307db4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/test_models_detection_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torchvision.models.detection.transform import GeneralizedRCNNTransform
55
import unittest
66
from torchvision.models.detection import backbone_utils
7+
from _assert_utils import assert_equal
78

89

910
class Tester(unittest.TestCase):
@@ -55,8 +56,8 @@ def test_transform_copy_targets(self):
5556
targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}]
5657
targets_copy = copy.deepcopy(targets)
5758
out = transform(image, targets) # noqa: F841
58-
self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes']))
59-
self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes']))
59+
assert_equal(targets[0]['boxes'], targets_copy[0]['boxes'])
60+
assert_equal(targets[1]['boxes'], targets_copy[1]['boxes'])
6061

6162
def test_not_float_normalize(self):
6263
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))

0 commit comments

Comments
 (0)