@@ -630,7 +630,7 @@ def test_nms_ref(self, iou, seed):
630
630
boxes , scores = self ._create_tensors_with_iou (1000 , iou )
631
631
keep_ref = self ._reference_nms (boxes , scores , iou )
632
632
keep = ops .nms (boxes , scores , iou )
633
- assert torch .allclose (keep , keep_ref ), err_msg .format (iou )
633
+ torch .testing . assert_close (keep , keep_ref , msg = err_msg .format (iou ) )
634
634
635
635
def test_nms_input_errors (self ):
636
636
with pytest .raises (RuntimeError ):
@@ -661,7 +661,7 @@ def test_qnms(self, iou, scale, zero_point):
661
661
keep = ops .nms (boxes , scores , iou )
662
662
qkeep = ops .nms (qboxes , qscores , iou )
663
663
664
- assert torch .allclose (qkeep , keep ), err_msg .format (iou )
664
+ torch .testing . assert_close (qkeep , keep , msg = err_msg .format (iou ) )
665
665
666
666
@needs_cuda
667
667
@pytest .mark .parametrize ("iou" , (0.2 , 0.5 , 0.8 ))
@@ -1237,7 +1237,7 @@ def _run_cartesian_test(target_fn: Callable):
1237
1237
boxes2 = gen_box (7 )
1238
1238
a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1239
1239
b = target_fn (boxes1 , boxes2 )
1240
- assert torch .allclose (a , b )
1240
+ torch .testing . assert_close (a , b )
1241
1241
1242
1242
1243
1243
class TestBoxIou (TestIouBase ):
0 commit comments