diff --git a/test/test_onnx.py b/test/test_onnx.py index a4170b5242f..63f182004b8 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -129,6 +129,11 @@ def test_roi_align(self): model = ops.RoIAlign((5, 5), 1, 2) self.run_model(model, [(x, single_roi)]) + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) + model = ops.RoIAlign((5, 5), 1, -1) + self.run_model(model, [(x, single_roi)]) + def test_roi_align_aligned(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32) @@ -150,6 +155,11 @@ def test_roi_align_aligned(self): model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True) self.run_model(model, [(x, single_roi)]) + x = torch.rand(1, 1, 10, 10, dtype=torch.float32) + single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32) + model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True) + self.run_model(model, [(x, single_roi)]) + @unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes def test_roi_align_malformed_boxes(self): x = torch.randn(1, 1, 10, 10, dtype=torch.float32) diff --git a/torchvision/ops/_register_onnx_ops.py b/torchvision/ops/_register_onnx_ops.py index 02013844aac..8e8ed331803 100644 --- a/torchvision/ops/_register_onnx_ops.py +++ b/torchvision/ops/_register_onnx_ops.py @@ -29,6 +29,12 @@ def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampli " ONNX forces ROIs to be 1x1 or larger.") scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float) rois = g.op("Sub", rois, scale) + + # ONNX doesn't support negative sampling_ratio + if sampling_ratio < 0: + warnings.warn("ONNX doesn't support negative sampling ratio," + "therefore is is set to 0 in order to be exported.") + sampling_ratio = 0 return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale, output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)