|
2 | 2 | import warnings
|
3 | 3 |
|
4 | 4 | import torch
|
| 5 | +from torch.onnx import symbolic_opset11 as opset11 |
| 6 | +from torch.onnx.symbolic_helper import parse_args |
5 | 7 |
|
6 |
| -_onnx_opset_version_11 = 11 |
7 |
| -_onnx_opset_version_16 = 16 |
8 |
| -base_onnx_opset_version = _onnx_opset_version_11 |
| 8 | +_ONNX_OPSET_VERSION_11 = 11 |
| 9 | +_ONNX_OPSET_VERSION_16 = 16 |
| 10 | +BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11 |
9 | 11 |
|
10 | 12 |
|
11 |
| -def _register_custom_op(): |
12 |
| - from torch.onnx.symbolic_helper import parse_args |
13 |
| - from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze |
14 |
| - |
15 |
| - @parse_args("v", "v", "f") |
16 |
| - def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): |
17 |
| - boxes = unsqueeze(g, boxes, 0) |
18 |
| - scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) |
19 |
| - max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) |
20 |
| - iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) |
21 |
| - nms_out = g.op( |
22 |
| - "NonMaxSuppression", |
23 |
| - g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
24 |
| - g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
25 |
| - max_output_per_class, |
26 |
| - iou_threshold, |
27 |
| - ) |
28 |
| - return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) |
29 |
| - |
30 |
| - def _process_batch_indices_for_roi_align(g, rois): |
31 |
| - indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1) |
32 |
| - return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64) |
33 |
| - |
34 |
| - def _process_rois_for_roi_align(g, rois): |
35 |
| - return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) |
36 |
| - |
37 |
| - def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int): |
38 |
| - if sampling_ratio < 0: |
39 |
| - warnings.warn( |
40 |
| - "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. " |
41 |
| - "The model will be exported with a sampling_ratio of 0." |
42 |
| - ) |
43 |
| - sampling_ratio = 0 |
44 |
| - return sampling_ratio |
45 |
| - |
46 |
| - @parse_args("v", "v", "f", "i", "i", "i", "i") |
47 |
| - def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): |
48 |
| - batch_indices = _process_batch_indices_for_roi_align(g, rois) |
49 |
| - rois = _process_rois_for_roi_align(g, rois) |
50 |
| - if aligned: |
51 |
| - warnings.warn( |
52 |
| - "ROIAlign with aligned=True is only supported in opset >= 16. " |
53 |
| - "Please export with opset 16 or higher, or use aligned=False." |
54 |
| - ) |
55 |
| - sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio) |
56 |
| - return g.op( |
57 |
| - "RoiAlign", |
58 |
| - input, |
59 |
| - rois, |
60 |
| - batch_indices, |
61 |
| - spatial_scale_f=spatial_scale, |
62 |
| - output_height_i=pooled_height, |
63 |
| - output_width_i=pooled_width, |
64 |
| - sampling_ratio_i=sampling_ratio, |
65 |
| - ) |
| 13 | +@parse_args("v", "v", "f") |
| 14 | +def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): |
| 15 | + boxes = opset11.unsqueeze(g, boxes, 0) |
| 16 | + scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0) |
| 17 | + max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) |
| 18 | + iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) |
| 19 | + |
| 20 | + # Cast boxes and scores to float32 in case they are float64 inputs |
| 21 | + nms_out = g.op( |
| 22 | + "NonMaxSuppression", |
| 23 | + g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
| 24 | + g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT), |
| 25 | + max_output_per_class, |
| 26 | + iou_threshold, |
| 27 | + ) |
| 28 | + return opset11.squeeze( |
| 29 | + g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1 |
| 30 | + ) |
| 31 | + |
| 32 | + |
| 33 | +def _process_batch_indices_for_roi_align(g, rois): |
| 34 | + indices = opset11.squeeze( |
| 35 | + g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1 |
| 36 | + ) |
| 37 | + return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64) |
| 38 | + |
| 39 | + |
| 40 | +def _process_rois_for_roi_align(g, rois): |
| 41 | + return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) |
66 | 42 |
|
67 |
| - @parse_args("v", "v", "f", "i", "i", "i", "i") |
68 |
| - def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): |
69 |
| - batch_indices = _process_batch_indices_for_roi_align(g, rois) |
70 |
| - rois = _process_rois_for_roi_align(g, rois) |
71 |
| - coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" |
72 |
| - sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio) |
73 |
| - return g.op( |
74 |
| - "RoiAlign", |
75 |
| - input, |
76 |
| - rois, |
77 |
| - batch_indices, |
78 |
| - coordinate_transformation_mode_s=coordinate_transformation_mode, |
79 |
| - spatial_scale_f=spatial_scale, |
80 |
| - output_height_i=pooled_height, |
81 |
| - output_width_i=pooled_width, |
82 |
| - sampling_ratio_i=sampling_ratio, |
| 43 | + |
| 44 | +def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int): |
| 45 | + if sampling_ratio < 0: |
| 46 | + warnings.warn( |
| 47 | + "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. " |
| 48 | + "The model will be exported with a sampling_ratio of 0." |
83 | 49 | )
|
| 50 | + sampling_ratio = 0 |
| 51 | + return sampling_ratio |
| 52 | + |
84 | 53 |
|
85 |
| - @parse_args("v", "v", "f", "i", "i") |
86 |
| - def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): |
87 |
| - roi_pool = g.op( |
88 |
| - "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale |
| 54 | +@parse_args("v", "v", "f", "i", "i", "i", "i") |
| 55 | +def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): |
| 56 | + batch_indices = _process_batch_indices_for_roi_align(g, rois) |
| 57 | + rois = _process_rois_for_roi_align(g, rois) |
| 58 | + if aligned: |
| 59 | + warnings.warn( |
| 60 | + "ROIAlign with aligned=True is only supported in opset >= 16. " |
| 61 | + "Please export with opset 16 or higher, or use aligned=False." |
89 | 62 | )
|
90 |
| - return roi_pool, None |
| 63 | + sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio) |
| 64 | + return g.op( |
| 65 | + "RoiAlign", |
| 66 | + input, |
| 67 | + rois, |
| 68 | + batch_indices, |
| 69 | + spatial_scale_f=spatial_scale, |
| 70 | + output_height_i=pooled_height, |
| 71 | + output_width_i=pooled_width, |
| 72 | + sampling_ratio_i=sampling_ratio, |
| 73 | + ) |
| 74 | + |
| 75 | + |
| 76 | +@parse_args("v", "v", "f", "i", "i", "i", "i") |
| 77 | +def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): |
| 78 | + batch_indices = _process_batch_indices_for_roi_align(g, rois) |
| 79 | + rois = _process_rois_for_roi_align(g, rois) |
| 80 | + coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" |
| 81 | + sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio) |
| 82 | + return g.op( |
| 83 | + "RoiAlign", |
| 84 | + input, |
| 85 | + rois, |
| 86 | + batch_indices, |
| 87 | + coordinate_transformation_mode_s=coordinate_transformation_mode, |
| 88 | + spatial_scale_f=spatial_scale, |
| 89 | + output_height_i=pooled_height, |
| 90 | + output_width_i=pooled_width, |
| 91 | + sampling_ratio_i=sampling_ratio, |
| 92 | + ) |
91 | 93 |
|
92 |
| - from torch.onnx import register_custom_op_symbolic |
93 | 94 |
|
94 |
| - register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11) |
95 |
| - register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11) |
96 |
| - register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16) |
97 |
| - register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11) |
| 95 | +@parse_args("v", "v", "f", "i", "i") |
| 96 | +def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): |
| 97 | + roi_pool = g.op( |
| 98 | + "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale |
| 99 | + ) |
| 100 | + return roi_pool, None |
| 101 | + |
| 102 | + |
| 103 | +def _register_custom_op(): |
| 104 | + torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11) |
| 105 | + torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11) |
| 106 | + torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16) |
| 107 | + torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11) |
0 commit comments