Skip to content

[ONNX] Fix roi_align ONNX export #3355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 12, 2021
10 changes: 10 additions & 0 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, 2)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])

def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
Expand All @@ -150,6 +155,11 @@ def test_roi_align_aligned(self):
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])

@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
def test_roi_align_malformed_boxes(self):
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
Expand Down
6 changes: 6 additions & 0 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampli
" ONNX forces ROIs to be 1x1 or larger.")
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
rois = g.op("Sub", rois, scale)

# ONNX doesn't support negative sampling_ratio
if sampling_ratio < 0:
warnings.warn("ONNX doesn't support negative sampling ratio,"
"therefore is is set to 0 in order to be exported.")
sampling_ratio = 0
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)

Expand Down