@@ -58,23 +58,38 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
58
58
else :
59
59
from torch .onnx .symbolic_opset9 import _cast_Long
60
60
from torch .onnx .symbolic_opset11 import add , select
61
- batch_indices = _cast_Long (
62
- g ,
63
- g .op (
64
- 'Squeeze' ,
65
- select (
66
- g , rois , 1 ,
67
- g .op (
68
- 'Constant' ,
69
- value_t = torch .tensor ([0 ], dtype = torch .long ))),
70
- axes_i = [1 ]), False )
61
+ ir_cfg = get_ir_config (ctx .cfg )
62
+ opset_version = ir_cfg .get ('opset_version' , 11 )
63
+ if opset_version < 13 :
64
+ batch_indices = _cast_Long (
65
+ g ,
66
+ g .op (
67
+ 'Squeeze' ,
68
+ select (
69
+ g , rois , 1 ,
70
+ g .op (
71
+ 'Constant' ,
72
+ value_t = torch .tensor ([0 ], dtype = torch .long ))),
73
+ axes_i = [1 ]), False )
74
+ else :
75
+ axes = g .op (
76
+ 'Constant' , value_t = torch .tensor ([1 ], dtype = torch .long ))
77
+ batch_indices = _cast_Long (
78
+ g ,
79
+ g .op (
80
+ 'Squeeze' ,
81
+ select (
82
+ g , rois , 1 ,
83
+ g .op (
84
+ 'Constant' ,
85
+ value_t = torch .tensor ([0 ], dtype = torch .long ))),
86
+ axes ), False )
71
87
rois = select (
72
88
g , rois , 1 ,
73
89
g .op (
74
90
'Constant' ,
75
91
value_t = torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .long )))
76
- ir_cfg = get_ir_config (ctx .cfg )
77
- opset_version = ir_cfg .get ('opset_version' , 11 )
92
+
78
93
if opset_version < 16 :
79
94
# preprocess rois to make compatible with opset 16-
80
95
# as for opset 16+, `aligned` get implemented inside onnxruntime.
@@ -96,6 +111,10 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
96
111
sampling_ratio_i = sampling_ratio ,
97
112
mode_s = pool_mode )
98
113
else :
114
+ if aligned :
115
+ coordinate_transformation_mode = 'half_pixel'
116
+ else :
117
+ coordinate_transformation_mode = 'output_half_pixel'
99
118
return g .op (
100
119
'RoiAlign' ,
101
120
input ,
@@ -106,4 +125,5 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
106
125
spatial_scale_f = spatial_scale ,
107
126
sampling_ratio_i = sampling_ratio ,
108
127
mode_s = pool_mode ,
109
- aligned_i = aligned )
128
+ coordinate_transformation_mode_s = coordinate_transformation_mode
129
+ )
0 commit comments