diff --git a/test/test_onnx.py b/test/test_onnx.py index 63f182004b8..d0140c79dfc 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -7,6 +7,7 @@ onnxruntime = None from common_utils import set_rng_seed +from _assert_utils import assert_equal import io import torch from torchvision import ops @@ -483,8 +484,8 @@ def test_heatmaps_to_keypoints(self): jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) out_trace = jit_trace(maps, rois) - assert torch.all(out[0].eq(out_trace[0])) - assert torch.all(out[1].eq(out_trace[1])) + assert_equal(out[0], out_trace[0]) + assert_equal(out[1], out_trace[1]) maps2 = torch.rand(20, 2, 21, 21) rois2 = torch.rand(20, 4) @@ -492,8 +493,8 @@ def test_heatmaps_to_keypoints(self): out2 = heatmaps_to_keypoints(maps2, rois2) out_trace2 = jit_trace(maps2, rois2) - assert torch.all(out2[0].eq(out_trace2[0])) - assert torch.all(out2[1].eq(out_trace2[1])) + assert_equal(out2[0], out_trace2[0]) + assert_equal(out2[1], out_trace2[1]) def test_keypoint_rcnn(self): images, test_images = self.get_test_images()