Skip to content

Commit ec35b40

Browse files
authored
fix roi align symbolic function in onnx opset>=16 (#2428)
1 parent bb031c6 commit ec35b40

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

mmdeploy/mmcv/ops/roi_align.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,38 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
5858
else:
5959
from torch.onnx.symbolic_opset9 import _cast_Long
6060
from torch.onnx.symbolic_opset11 import add, select
61-
batch_indices = _cast_Long(
62-
g,
63-
g.op(
64-
'Squeeze',
65-
select(
66-
g, rois, 1,
67-
g.op(
68-
'Constant',
69-
value_t=torch.tensor([0], dtype=torch.long))),
70-
axes_i=[1]), False)
61+
ir_cfg = get_ir_config(ctx.cfg)
62+
opset_version = ir_cfg.get('opset_version', 11)
63+
if opset_version < 13:
64+
batch_indices = _cast_Long(
65+
g,
66+
g.op(
67+
'Squeeze',
68+
select(
69+
g, rois, 1,
70+
g.op(
71+
'Constant',
72+
value_t=torch.tensor([0], dtype=torch.long))),
73+
axes_i=[1]), False)
74+
else:
75+
axes = g.op(
76+
'Constant', value_t=torch.tensor([1], dtype=torch.long))
77+
batch_indices = _cast_Long(
78+
g,
79+
g.op(
80+
'Squeeze',
81+
select(
82+
g, rois, 1,
83+
g.op(
84+
'Constant',
85+
value_t=torch.tensor([0], dtype=torch.long))),
86+
axes), False)
7187
rois = select(
7288
g, rois, 1,
7389
g.op(
7490
'Constant',
7591
value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
76-
ir_cfg = get_ir_config(ctx.cfg)
77-
opset_version = ir_cfg.get('opset_version', 11)
92+
7893
if opset_version < 16:
7994
# preprocess rois to make compatible with opset 16-
8095
# as for opset 16+, `aligned` get implemented inside onnxruntime.
@@ -96,6 +111,10 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
96111
sampling_ratio_i=sampling_ratio,
97112
mode_s=pool_mode)
98113
else:
114+
if aligned:
115+
coordinate_transformation_mode = 'half_pixel'
116+
else:
117+
coordinate_transformation_mode = 'output_half_pixel'
99118
return g.op(
100119
'RoiAlign',
101120
input,
@@ -106,4 +125,5 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
106125
spatial_scale_f=spatial_scale,
107126
sampling_ratio_i=sampling_ratio,
108127
mode_s=pool_mode,
109-
aligned_i=aligned)
128+
coordinate_transformation_mode_s=coordinate_transformation_mode
129+
)

0 commit comments

Comments
 (0)