Skip to content

Commit 30bb1ce

Browse files
justinchubymalfet
andauthored
[ONNX] misc improvements (#7249)
Co-authored-by: Nikita Shulga <[email protected]>
1 parent d805aea commit 30bb1ce

File tree

2 files changed

+97
-87
lines changed

2 files changed

+97
-87
lines changed

test/test_onnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def run_model(
3434
opset_version: Optional[int] = None,
3535
):
3636
if opset_version is None:
37-
opset_version = _register_onnx_ops.base_onnx_opset_version
37+
opset_version = _register_onnx_ops.BASE_ONNX_OPSET_VERSION
3838

3939
model.eval()
4040

@@ -139,7 +139,7 @@ def test_roi_align(self):
139139
self.run_model(model, [(x, single_roi)])
140140

141141
def test_roi_align_aligned(self):
142-
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
142+
supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
143143
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
144144
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
145145
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
@@ -166,7 +166,7 @@ def test_roi_align_aligned(self):
166166
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
167167

168168
def test_roi_align_malformed_boxes(self):
169-
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
169+
supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
170170
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
171171
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
172172
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)

torchvision/ops/_register_onnx_ops.py

Lines changed: 94 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,96 +2,106 @@
22
import warnings
33

44
import torch
5+
from torch.onnx import symbolic_opset11 as opset11
6+
from torch.onnx.symbolic_helper import parse_args
57

6-
_onnx_opset_version_11 = 11
7-
_onnx_opset_version_16 = 16
8-
base_onnx_opset_version = _onnx_opset_version_11
8+
_ONNX_OPSET_VERSION_11 = 11
9+
_ONNX_OPSET_VERSION_16 = 16
10+
BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11
911

1012

11-
def _register_custom_op():
12-
from torch.onnx.symbolic_helper import parse_args
13-
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
14-
15-
@parse_args("v", "v", "f")
16-
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
17-
boxes = unsqueeze(g, boxes, 0)
18-
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
19-
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
20-
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
21-
nms_out = g.op(
22-
"NonMaxSuppression",
23-
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
24-
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
25-
max_output_per_class,
26-
iou_threshold,
27-
)
28-
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
29-
30-
def _process_batch_indices_for_roi_align(g, rois):
31-
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
32-
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
33-
34-
def _process_rois_for_roi_align(g, rois):
35-
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
36-
37-
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
38-
if sampling_ratio < 0:
39-
warnings.warn(
40-
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
41-
"The model will be exported with a sampling_ratio of 0."
42-
)
43-
sampling_ratio = 0
44-
return sampling_ratio
45-
46-
@parse_args("v", "v", "f", "i", "i", "i", "i")
47-
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
48-
batch_indices = _process_batch_indices_for_roi_align(g, rois)
49-
rois = _process_rois_for_roi_align(g, rois)
50-
if aligned:
51-
warnings.warn(
52-
"ROIAlign with aligned=True is only supported in opset >= 16. "
53-
"Please export with opset 16 or higher, or use aligned=False."
54-
)
55-
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
56-
return g.op(
57-
"RoiAlign",
58-
input,
59-
rois,
60-
batch_indices,
61-
spatial_scale_f=spatial_scale,
62-
output_height_i=pooled_height,
63-
output_width_i=pooled_width,
64-
sampling_ratio_i=sampling_ratio,
65-
)
13+
@parse_args("v", "v", "f")
14+
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
15+
boxes = opset11.unsqueeze(g, boxes, 0)
16+
scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
17+
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
18+
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
19+
20+
# Cast boxes and scores to float32 in case they are float64 inputs
21+
nms_out = g.op(
22+
"NonMaxSuppression",
23+
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
24+
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
25+
max_output_per_class,
26+
iou_threshold,
27+
)
28+
return opset11.squeeze(
29+
g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
30+
)
31+
32+
33+
def _process_batch_indices_for_roi_align(g, rois):
34+
indices = opset11.squeeze(
35+
g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
36+
)
37+
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
38+
39+
40+
def _process_rois_for_roi_align(g, rois):
41+
return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
6642

67-
@parse_args("v", "v", "f", "i", "i", "i", "i")
68-
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
69-
batch_indices = _process_batch_indices_for_roi_align(g, rois)
70-
rois = _process_rois_for_roi_align(g, rois)
71-
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
72-
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
73-
return g.op(
74-
"RoiAlign",
75-
input,
76-
rois,
77-
batch_indices,
78-
coordinate_transformation_mode_s=coordinate_transformation_mode,
79-
spatial_scale_f=spatial_scale,
80-
output_height_i=pooled_height,
81-
output_width_i=pooled_width,
82-
sampling_ratio_i=sampling_ratio,
43+
44+
def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
45+
if sampling_ratio < 0:
46+
warnings.warn(
47+
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
48+
"The model will be exported with a sampling_ratio of 0."
8349
)
50+
sampling_ratio = 0
51+
return sampling_ratio
52+
8453

85-
@parse_args("v", "v", "f", "i", "i")
86-
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
87-
roi_pool = g.op(
88-
"MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
54+
@parse_args("v", "v", "f", "i", "i", "i", "i")
55+
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
56+
batch_indices = _process_batch_indices_for_roi_align(g, rois)
57+
rois = _process_rois_for_roi_align(g, rois)
58+
if aligned:
59+
warnings.warn(
60+
"ROIAlign with aligned=True is only supported in opset >= 16. "
61+
"Please export with opset 16 or higher, or use aligned=False."
8962
)
90-
return roi_pool, None
63+
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
64+
return g.op(
65+
"RoiAlign",
66+
input,
67+
rois,
68+
batch_indices,
69+
spatial_scale_f=spatial_scale,
70+
output_height_i=pooled_height,
71+
output_width_i=pooled_width,
72+
sampling_ratio_i=sampling_ratio,
73+
)
74+
75+
76+
@parse_args("v", "v", "f", "i", "i", "i", "i")
77+
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
78+
batch_indices = _process_batch_indices_for_roi_align(g, rois)
79+
rois = _process_rois_for_roi_align(g, rois)
80+
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
81+
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
82+
return g.op(
83+
"RoiAlign",
84+
input,
85+
rois,
86+
batch_indices,
87+
coordinate_transformation_mode_s=coordinate_transformation_mode,
88+
spatial_scale_f=spatial_scale,
89+
output_height_i=pooled_height,
90+
output_width_i=pooled_width,
91+
sampling_ratio_i=sampling_ratio,
92+
)
9193

92-
from torch.onnx import register_custom_op_symbolic
9394

94-
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
95-
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
96-
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
97-
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)
95+
@parse_args("v", "v", "f", "i", "i")
96+
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
97+
roi_pool = g.op(
98+
"MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
99+
)
100+
return roi_pool, None
101+
102+
103+
def _register_custom_op():
104+
torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
105+
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
106+
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
107+
torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)

0 commit comments

Comments
 (0)