4 changes: 2 additions & 2 deletions .github/workflows/regression-test.yml
@@ -111,7 +111,7 @@ jobs:
apt update && apt install unzip
python -V
python -m pip install --upgrade pip
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable onnxconverter-common
python -m pip install opencv-python==4.5.4.60 opencv-python-headless==4.5.4.60 opencv-contrib-python==4.5.4.60
python .github/scripts/prepare_reg_test.py --torch-version ${{ matrix.torch_version }} --codebases ${{ matrix.codebase}}
python -m pip install -r requirements.txt
@@ -221,7 +221,7 @@ jobs:
conda activate $env:TEMP_ENV
python -V
python -m pip install --upgrade pip
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable onnxconverter-common
python -m pip install opencv-python==4.5.4.60 opencv-python-headless==4.5.4.60 opencv-contrib-python==4.5.4.60
python .github/scripts/prepare_reg_test.py --torch-version ${{ matrix.torch_version }} --codebases ${{ matrix.codebase}}
python -m pip install -r requirements.txt
10 changes: 10 additions & 0 deletions configs/_base_/backends/onnxruntime-fp16.py
@@ -0,0 +1,10 @@
backend_config = dict(
type='onnxruntime',
precision='fp16',
common_config=dict(
min_positive_val=1e-7,
max_finite_val=1e4,
keep_io_types=False,
disable_shape_infer=False,
op_block_list=None,
node_block_list=None))
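These `common_config` entries mirror the keyword arguments of `onnxconverter_common.float16.convert_float_to_float16`, and the backend manager change below forwards them verbatim. A minimal sketch of what the conversion step amounts to, with a placeholder model path:

```python
# Rough sketch of the conversion this config drives; 'end2end.onnx' is a
# placeholder for the exported model path.
import onnx
from onnxconverter_common import float16

common_config = dict(
    min_positive_val=1e-7,
    max_finite_val=1e4,
    keep_io_types=False,
    disable_shape_infer=False,
    op_block_list=None,
    node_block_list=None)

model = onnx.load('end2end.onnx')
model_fp16 = float16.convert_float_to_float16(model, **common_config)
onnx.save(model_fp16, 'end2end.onnx')
```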
6 changes: 6 additions & 0 deletions configs/mmaction/video-recognition/video-recognition_onnxruntime-fp16_static.py
@@ -0,0 +1,6 @@
_base_ = [
'./video-recognition_static.py',
'../../_base_/backends/onnxruntime-fp16.py'
]

onnx_config = dict(input_shape=None)
4 changes: 4 additions & 0 deletions configs/mmagic/super-resolution/super-resolution_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,4 @@
_base_ = [
'./super-resolution_dynamic.py',
'../../_base_/backends/onnxruntime-fp16.py'
]
3 changes: 3 additions & 0 deletions configs/mmdet/detection/detection_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,3 @@
_base_ = [
'../_base_/base_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
]
4 changes: 4 additions & 0 deletions configs/mmdet/instance-seg/instance-seg_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/base_instance-seg_dynamic.py',
'../../_base_/backends/onnxruntime-fp16.py'
]
3 changes: 3 additions & 0 deletions configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,3 @@
_base_ = [
'./voxel-detection_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
]
3 changes: 3 additions & 0 deletions configs/mmocr/text-detection/text-detection_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,3 @@
_base_ = [
'./text-detection_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
]
4 changes: 4 additions & 0 deletions configs/mmocr/text-recognition/text-recognition_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,4 @@
_base_ = [
'./text-recognition_dynamic.py',
'../../_base_/backends/onnxruntime-fp16.py'
]
3 changes: 3 additions & 0 deletions configs/mmpose/pose-detection_onnxruntime-fp16_static.py
@@ -0,0 +1,3 @@
_base_ = [
'./pose-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
]
18 changes: 18 additions & 0 deletions configs/mmpose/pose-detection_simcc_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,18 @@
_base_ = [
'./pose-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
]

onnx_config = dict(
input_shape=[192, 256],
output_names=['simcc_x', 'simcc_y'],
dynamic_axes={
'input': {
0: 'batch',
},
'simcc_x': {
0: 'batch'
},
'simcc_y': {
0: 'batch'
}
})
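The `dynamic_axes` above only mark the batch dimension as dynamic; height and width stay fixed by `input_shape`. A quick check of the exported graph (placeholder path) might look like:

```python
# Verify that 'input', 'simcc_x' and 'simcc_y' carry the symbolic 'batch' axis.
import onnx

model = onnx.load('work_dir/end2end.onnx')  # placeholder path
for value in list(model.graph.input) + list(model.graph.output):
    dim0 = value.type.tensor_type.shape.dim[0]
    print(value.name, dim0.dim_param or dim0.dim_value)
```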
3 changes: 3 additions & 0 deletions configs/mmpretrain/classification_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,3 @@
_base_ = [
'./classification_dynamic.py', '../_base_/backends/onnxruntime-fp16.py'
]
25 changes: 25 additions & 0 deletions configs/mmrotate/rotated-detection_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,25 @@
_base_ = [
'./rotated-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
]

onnx_config = dict(
output_names=['dets', 'labels'],
input_shape=[1024, 1024],
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'dets': {
0: 'batch',
1: 'num_dets',
},
'labels': {
0: 'batch',
1: 'num_dets',
},
})

backend_config = dict(
common_config=dict(op_block_list=['NMSRotated', 'Resize']))
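The `op_block_list` override keeps `NMSRotated` and `Resize` in float32 during the fp16 conversion; the converter leaves blocked operators untouched and inserts `Cast` nodes around them. A small inspection sketch, assuming a placeholder output path:

```python
# After conversion, blocked ops should still be present with float32 tensors,
# with Cast nodes inserted at their boundaries.
import onnx
from collections import Counter

model = onnx.load('work_dir/end2end.onnx')  # placeholder path
ops = Counter(node.op_type for node in model.graph.node)
print('Cast nodes inserted:', ops.get('Cast', 0))
print('Blocked ops kept:', ops.get('NMSRotated', 0) + ops.get('Resize', 0))
```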
3 changes: 3 additions & 0 deletions configs/mmseg/segmentation_onnxruntime-fp16_dynamic.py
@@ -0,0 +1,3 @@
_base_ = [
'./segmentation_dynamic.py', '../_base_/backends/onnxruntime-fp16.py'
]
8 changes: 8 additions & 0 deletions docs/en/05-supported-backends/onnxruntime.md
@@ -22,6 +22,14 @@ pip install onnxruntime==1.8.1 # if you want to use cpu version
pip install onnxruntime-gpu==1.8.1 # if you want to use gpu version
```

### Install the float16 conversion tool (optional)

If you want to use float16 precision, install the conversion tool with the following command:

```bash
pip install onnx onnxconverter-common
```
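As a rough sanity check, assuming an already-exported fp32 model named `end2end.onnx` with an input called `input` and a placeholder input shape, one could convert the model and compare outputs like this:

```python
# Sanity-check sketch, not part of MMDeploy: file names, input name and input
# shape are placeholders for your own model.
import numpy as np
import onnx
import onnxruntime as ort
from onnxconverter_common import float16

onnx.save(float16.convert_float_to_float16(onnx.load('end2end.onnx')),
          'end2end_fp16.onnx')

dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
ref = ort.InferenceSession(
    'end2end.onnx', providers=['CPUExecutionProvider']).run(None, {'input': dummy})
out = ort.InferenceSession(
    'end2end_fp16.onnx', providers=['CPUExecutionProvider']).run(
        None, {'input': dummy.astype(np.float16)})
print('max abs diff:', np.abs(ref[0] - out[0].astype(np.float32)).max())
```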

## Build custom ops

### Download ONNXRuntime Library
2 changes: 1 addition & 1 deletion docs/en/05-supported-backends/openvino.md
@@ -83,7 +83,7 @@ Notes:

- Custom operations from OpenVINO use the domain `org.openvinotoolkit`.
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvino.ai/2022.3/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.

## Deployment config
8 changes: 8 additions & 0 deletions docs/zh_cn/05-supported-backends/onnxruntime.md
@@ -22,6 +22,14 @@ pip install onnxruntime==1.8.1 # if you want to use cpu version
pip install onnxruntime-gpu==1.8.1 # if you want to use gpu version
```

### Install the float16 conversion tool (optional)

If you want to use float16 precision, install the conversion tool with the following command:

```bash
pip install onnx onnxconverter-common
```

## Build custom ops

### Download ONNXRuntime Library
2 changes: 1 addition & 1 deletion docs/zh_cn/05-supported-backends/openvino.md
@@ -83,7 +83,7 @@ Notes:

- Custom operations from OpenVINO use the domain `org.openvinotoolkit`.
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvino.ai/2022.3/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.

## Deployment config
14 changes: 14 additions & 0 deletions mmdeploy/backend/onnxruntime/backend_manager.py
@@ -3,6 +3,7 @@
import os.path as osp
from typing import Any, Callable, Optional, Sequence

from mmdeploy.utils import get_backend_config, get_common_config
from ..base import BACKEND_MANAGERS, BaseBackendManager


@@ -125,6 +126,7 @@ def check_env(cls, log_callback: Callable = lambda _: _) -> str:
def to_backend(cls,
ir_files: Sequence[str],
work_dir: str,
deploy_cfg: Any,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
@@ -134,9 +136,21 @@ def to_backend(cls,
ir_files (Sequence[str]): The intermediate representation files.
work_dir (str): The work directory, backend files and logs should
be saved in this directory.
deploy_cfg (Any): The deploy config.
log_level (int, optional): The log level. Defaults to logging.INFO.
device (str, optional): The device type. Defaults to 'cpu'.
Returns:
Sequence[str]: Backend files.
"""
backend_cfg = get_backend_config(deploy_cfg)

precision = backend_cfg.get('precision', 'fp32')
if precision == 'fp16':
import onnx
from onnxconverter_common import float16

common_cfg = get_common_config(deploy_cfg)
model = onnx.load(ir_files[0])
model_fp16 = float16.convert_float_to_float16(model, **common_cfg)
onnx.save(model_fp16, ir_files[0])
return ir_files
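A rough illustration of how a caller might supply the new `deploy_cfg` argument, assuming mmengine's `Config` and the manager class exported from this package; paths are placeholders and this is not the usual user-facing entry point:

```python
# Illustration only: the fp16 backend config above provides both the
# `precision` switch and the kwargs for the converter.
from mmengine import Config

from mmdeploy.backend.onnxruntime import ONNXRuntimeManager

deploy_cfg = Config.fromfile(
    'configs/mmpretrain/classification_onnxruntime-fp16_dynamic.py')
backend_files = ONNXRuntimeManager.to_backend(
    ir_files=['work_dir/end2end.onnx'],  # placeholder path
    work_dir='work_dir',
    deploy_cfg=deploy_cfg)
```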
7 changes: 7 additions & 0 deletions mmdeploy/backend/onnxruntime/wrapper.py
@@ -2,6 +2,7 @@
import os.path as osp
from typing import Dict, Optional, Sequence

import numpy as np
import onnxruntime as ort
import torch

@@ -58,6 +59,7 @@ def __init__(self,
if output_names is None:
output_names = [_.name for _ in sess.get_outputs()]
self.sess = sess
self._input_metas = {_.name: _ for _ in sess.get_inputs()}
self.io_binding = sess.io_binding()
self.device_id = device_id
self.device_type = 'cpu' if device == 'cpu' else 'cuda'
@@ -75,6 +77,9 @@ def forward(self, inputs: Dict[str,
"""
for name, input_tensor in inputs.items():
# set io binding for inputs/outputs
input_type = self._input_metas[name].type
if 'float16' in input_type:
input_tensor = input_tensor.to(torch.float16)
input_tensor = input_tensor.contiguous()
if self.device_type == 'cpu':
input_tensor = input_tensor.cpu()
@@ -98,6 +103,8 @@ def forward(self, inputs: Dict[str,
output_list = self.io_binding.copy_outputs_to_cpu()
outputs = {}
for output_name, numpy_tensor in zip(self._output_names, output_list):
if numpy_tensor.dtype == np.float16:
numpy_tensor = numpy_tensor.astype(np.float32)
outputs[output_name] = torch.from_numpy(numpy_tensor)

return outputs
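Sketched with plain `onnxruntime` (placeholder path, input name and input shape), the wrapper change amounts to the following: cast inputs to float16 when the converted model expects it, and cast any float16 outputs back to float32 for downstream post-processing.

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('work_dir/end2end.onnx',
                            providers=['CPUExecutionProvider'])
raw_inputs = {'input': np.random.rand(1, 3, 224, 224).astype(np.float32)}

feeds = {}
for meta in sess.get_inputs():
    array = raw_inputs[meta.name]
    # ORT reports element types as strings such as 'tensor(float16)'.
    feeds[meta.name] = array.astype(np.float16) if 'float16' in meta.type else array

outputs = [o.astype(np.float32) if o.dtype == np.float16 else o
           for o in sess.run(None, feeds)]
```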
6 changes: 5 additions & 1 deletion tests/regression/mmaction.yml
@@ -28,6 +28,10 @@ onnxruntime:
convert_image: *convert_image
deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py
backend_test: *default_backend_test
pipeline_ort_static_fp16: &pipeline_ort_static_fp16
convert_image: *convert_image
deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime-fp16_static.py
backend_test: *default_backend_test

tensorrt:
pipeline_trt_2d_static_fp32: &pipeline_trt_2d_static_fp32
@@ -45,7 +49,7 @@ models:
model_configs:
- configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py
pipelines:
- *pipeline_ort_static_fp32
- *pipeline_ort_static_fp16
- *pipeline_trt_2d_static_fp32

- name: SlowFast
7 changes: 6 additions & 1 deletion tests/regression/mmagic.yml
@@ -46,6 +46,11 @@ onnxruntime:
convert_image: *convert_image
deploy_config: configs/mmagic/super-resolution/super-resolution_onnxruntime_dynamic.py

pipeline_ort_dynamic_fp16: &pipeline_ort_dynamic_fp16
convert_image: *convert_image
deploy_config: configs/mmagic/super-resolution/super-resolution_onnxruntime-fp16_dynamic.py


tensorrt:
pipeline_trt_static_fp32: &pipeline_trt_static_fp32
convert_image: *convert_image
@@ -114,7 +119,7 @@ models:
- configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py
pipelines:
- *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32
- *pipeline_ort_dynamic_fp16
# - *pipeline_trt_dynamic_fp32
- *pipeline_trt_dynamic_fp16
# - *pipeline_trt_dynamic_int8
14 changes: 12 additions & 2 deletions tests/regression/mmdet.yml
@@ -38,6 +38,11 @@ onnxruntime:
backend_test: False
deploy_config: configs/mmdet/detection/detection_onnxruntime_dynamic.py

pipeline_ort_dynamic_fp16: &pipeline_ort_dynamic_fp16
convert_image: *convert_image
backend_test: False
deploy_config: configs/mmdet/detection/detection_onnxruntime-fp16_dynamic.py

pipeline_seg_ort_static_fp32: &pipeline_seg_ort_static_fp32
convert_image: *convert_image
backend_test: False
@@ -48,6 +53,11 @@ onnxruntime:
backend_test: False
deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py

pipeline_seg_ort_dynamic_fp16: &pipeline_seg_ort_dynamic_fp16
convert_image: *convert_image
backend_test: False
deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime-fp16_dynamic.py

tensorrt:
pipeline_trt_static_fp32: &pipeline_trt_static_fp32
convert_image: *convert_image
@@ -203,7 +213,7 @@ models:
- configs/retinanet/retinanet_r50_fpn_1x_coco.py
pipelines:
- *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32
- *pipeline_ort_dynamic_fp16
- *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32
- *pipeline_pplnn_dynamic_fp32
@@ -323,7 +333,7 @@ models:
- configs/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py
pipelines:
- *pipeline_seg_ts_fp32
- *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_ort_dynamic_fp16
- *pipeline_seg_trt_dynamic_fp32
- *pipeline_seg_openvino_dynamic_fp32

7 changes: 6 additions & 1 deletion tests/regression/mmdet3d.yml
@@ -42,6 +42,11 @@ onnxruntime:
backend_test: False
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py

pipeline_ort_dynamic_kitti_fp16: &pipeline_ort_dynamic_kitti_fp16
convert_image: *convert_image
backend_test: False
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime-fp16_dynamic.py

pipeline_ort_dynamic_nus_fp32: &pipeline_ort_dynamic_nus_fp32
convert_image: *convert_image_nus
backend_test: False
@@ -86,7 +91,7 @@ models:
- configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py
- configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car.py
pipelines:
- *pipeline_ort_dynamic_kitti_fp32
- *pipeline_ort_dynamic_kitti_fp16
- *pipeline_openvino_dynamic_kitti_fp32
- *pipeline_trt_dynamic_kitti_fp32
- name: PointPillars
13 changes: 11 additions & 2 deletions tests/regression/mmocr.yml
@@ -43,6 +43,10 @@ onnxruntime:
convert_image: *convert_image_det
deploy_config: configs/mmocr/text-detection/text-detection_onnxruntime_dynamic.py

pipeline_ort_detection_dynamic_fp16: &pipeline_ort_detection_dynamic_fp16
convert_image: *convert_image_det
deploy_config: configs/mmocr/text-detection/text-detection_onnxruntime-fp16_dynamic.py

pipeline_ort_detection_mrcnn_dynamic_fp32: &pipeline_ort_detection_mrcnn_dynamic_fp32
convert_image: *convert_image_det
deploy_config: configs/mmocr/text-detection/text-detection_mrcnn_onnxruntime_dynamic.py
@@ -56,6 +60,11 @@ onnxruntime:
convert_image: *convert_image_rec
deploy_config: configs/mmocr/text-recognition/text-recognition_onnxruntime_dynamic.py

pipeline_ort_recognition_dynamic_fp16: &pipeline_ort_recognition_dynamic_fp16
convert_image: *convert_image_rec
deploy_config: configs/mmocr/text-recognition/text-recognition_onnxruntime-fp16_dynamic.py


tensorrt:
# ======= detection =======
pipeline_trt_detection_static_fp32: &pipeline_trt_detection_static_fp32
@@ -239,7 +248,7 @@ models:
- configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py
pipelines:
- *pipeline_ts_detection_fp32
- *pipeline_ort_detection_dynamic_fp32
- *pipeline_ort_detection_dynamic_fp16
- *pipeline_trt_detection_dynamic_fp16
- *pipeline_ncnn_detection_static_fp32
- *pipeline_pplnn_detection_dynamic_fp32
@@ -303,7 +312,7 @@ models:
- configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py
pipelines:
- *pipeline_ts_recognition_fp32
- *pipeline_ort_recognition_dynamic_fp32
- *pipeline_ort_recognition_dynamic_fp16
- *pipeline_trt_recognition_dynamic_fp16_H32_C1
- *pipeline_ncnn_recognition_static_fp32
- *pipeline_pplnn_recognition_dynamic_fp32