Skip to content

Commit a2151b9

Browse files
authored
replace assert torch.allclose with torch.testing.assert_allclose (#6895)
1 parent 79ca506 commit a2151b9

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/test_architecture_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_maxvit_window_partition(self):
2020
x_hat = partition(x, partition_size)
2121
x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)
2222

23-
assert torch.allclose(x, x_hat)
23+
torch.testing.assert_close(x, x_hat)
2424

2525
def test_maxvit_grid_partition(self):
2626
input_shape = (1, 3, 224, 224)
@@ -39,7 +39,7 @@ def test_maxvit_grid_partition(self):
3939
x_hat = post_swap(x_hat)
4040
x_hat = departition(x_hat, n_partitions, partition_size, partition_size)
4141

42-
assert torch.allclose(x, x_hat)
42+
torch.testing.assert_close(x, x_hat)
4343

4444

4545
if __name__ == "__main__":

test/test_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def test_nms_ref(self, iou, seed):
630630
boxes, scores = self._create_tensors_with_iou(1000, iou)
631631
keep_ref = self._reference_nms(boxes, scores, iou)
632632
keep = ops.nms(boxes, scores, iou)
633-
assert torch.allclose(keep, keep_ref), err_msg.format(iou)
633+
torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
634634

635635
def test_nms_input_errors(self):
636636
with pytest.raises(RuntimeError):
@@ -661,7 +661,7 @@ def test_qnms(self, iou, scale, zero_point):
661661
keep = ops.nms(boxes, scores, iou)
662662
qkeep = ops.nms(qboxes, qscores, iou)
663663

664-
assert torch.allclose(qkeep, keep), err_msg.format(iou)
664+
torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
665665

666666
@needs_cuda
667667
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@@ -1237,7 +1237,7 @@ def _run_cartesian_test(target_fn: Callable):
12371237
boxes2 = gen_box(7)
12381238
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
12391239
b = target_fn(boxes1, boxes2)
1240-
assert torch.allclose(a, b)
1240+
torch.testing.assert_close(a, b)
12411241

12421242

12431243
class TestBoxIou(TestIouBase):

0 commit comments

Comments
 (0)