
Commit e500124

grimoirelvhan028 authored and committed
[Refactor] Add backend manager for 1.x (#1515)
* backend manager 1.x
* update pplnn init
* rename file
* add to backend
* add check env and misc
* fix action
* fix ut
* fix comment
1 parent 85d0895 commit e500124


66 files changed (+1967, -989 lines)

.github/workflows/backend-ncnn.yml

Lines changed: 1 addition & 1 deletion
@@ -79,4 +79,4 @@ jobs:
       python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
       python -m pip install mmcv-lite
       python tools/scripts/build_ubuntu_x64_ncnn.py
-      python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available() and ncnn_api.is_custom_ops_available()'
+      python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available(with_custom_ops=True)'
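This assertion exercises the consolidated availability API: the separate `is_custom_ops_available()` helper is folded into `is_available()` as a keyword argument (the ONNX Runtime workflow below receives the same change). A minimal sketch of the new call, assuming an environment where the ncnn backend and its custom ops have been built:

```Python
import mmdeploy.apis.ncnn as ncnn_api

# One call now answers both questions: is the backend usable at all,
# and are its compiled custom ops present as well?
assert ncnn_api.is_available()
assert ncnn_api.is_available(with_custom_ops=True)
```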

.github/workflows/backend-ort.yml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ jobs:
       python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
       python -m pip install mmcv-lite openmim
       python tools/scripts/build_ubuntu_x64_ort.py
-      python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available() and ort_api.is_custom_ops_available()'
+      python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available(with_custom_ops=True)'
     - name: test mmcls full pipeline
       run: |
         python -m mim install $(cat requirements/codebases.txt | grep mmcls)

csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -31,3 +31,6 @@ mmdeploy_export(${PROJECT_NAME}_obj)
 mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
 target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
 add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME})
+
+set(_TORCHJIT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
+install(TARGETS ${PROJECT_NAME} DESTINATION ${_TORCHJIT_OPS_DIR})
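The new install step copies the built TorchScript custom-ops library into `mmdeploy/lib` inside the source tree, where the Python package can find it at runtime. A hedged sketch of how such a library is typically registered (the loading code is not part of this hunk, and the glob pattern is illustrative):

```Python
import glob
import os

import torch

import mmdeploy

# Locate the ops library installed next to the Python package and
# register its operators with TorchScript.
lib_dir = os.path.join(os.path.dirname(mmdeploy.__file__), 'lib')
for lib_path in glob.glob(os.path.join(lib_dir, '*torchscript_ops*')):
    torch.ops.load_library(lib_path)
```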

docs/en/07-developer-guide/support_new_backend.md

Lines changed: 26 additions & 35 deletions
@@ -123,32 +123,20 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi
 __all__ += ['onnx2ncnn', 'get_output_model_file']
 ```
 
-Then add the codes about conversion to `tools/deploy.py` using these APIs if necessary.
+Create a backend manager class which derives from `BaseBackendManager` and implements its `to_backend` class method.
 
 **Example:**
 
 ```Python
-# tools/deploy.py
-# ...
-elif backend == Backend.NCNN:
-    from mmdeploy.apis.ncnn import is_available as is_available_ncnn
-
-    if not is_available_ncnn():
-        logging.error('ncnn support is not available.')
-        exit(-1)
-
-    from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
-
-    backend_files = []
-    for onnx_path in onnx_files:
-        create_process(
-            f'onnx2ncnn with {onnx_path}',
-            target=onnx2ncnn,
-            args=(onnx_path, args.work_dir),
-            kwargs=dict(),
-            ret_value=ret_value)
-        backend_files += get_output_model_file(onnx_path, args.work_dir)
-# ...
+@classmethod
+def to_backend(cls,
+               ir_files: Sequence[str],
+               deploy_cfg: Any,
+               work_dir: str,
+               log_level: int = logging.INFO,
+               device: str = 'cpu',
+               **kwargs) -> Sequence[str]:
+    return ir_files
 ```
 
 6. Convert the models of OpenMMLab to backends (if necessary) and inference on backend engine. If you find some incompatible operators when testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
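With a manager registered, conversion is dispatched through the registry instead of hard-coded backend branches in `tools/deploy.py`. A minimal usage sketch, assuming the `get_backend_manager` lookup helper that accompanies the `BACKEND_MANAGERS` registry (config and file paths are illustrative):

```Python
from mmengine import Config

from mmdeploy.backend.base import get_backend_manager

# Resolve the registered manager by backend name and hand it the IR files.
deploy_cfg = Config.fromfile('deploy_cfg.py')
manager = get_backend_manager('ncnn')
backend_files = manager.to_backend(
    ir_files=['work_dir/end2end.onnx'],
    deploy_cfg=deploy_cfg,
    work_dir='work_dir')
```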
@@ -209,23 +197,26 @@ Although the backend engines are usually implemented in C/C++, it is convenient
         self.sess.run_with_iobinding(io_binding)
 ```
 
-4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py`
+4. Create a backend manager class which derives from `BaseBackendManager` and implements its `build_wrapper` class method.
 
 **Example:**
 
 ```Python
-@staticmethod
-def _build_wrapper(backend: Backend,
-                   backend_files: Sequence[str],
-                   device: str,
-                   input_names: Optional[Sequence[str]] = None,
-                   output_names: Optional[Sequence[str]] = None):
-    if backend == Backend.ONNXRUNTIME:
-        from mmdeploy.backend.onnxruntime import ORTWrapper
-        return ORTWrapper(
-            onnx_file=backend_files[0],
-            device=device,
-            output_names=output_names)
+@BACKEND_MANAGERS.register('onnxruntime')
+class ONNXRuntimeManager(BaseBackendManager):
+    @classmethod
+    def build_wrapper(cls,
+                      backend_files: Sequence[str],
+                      device: str = 'cpu',
+                      input_names: Optional[Sequence[str]] = None,
+                      output_names: Optional[Sequence[str]] = None,
+                      deploy_cfg: Optional[Any] = None,
+                      **kwargs):
+        from .wrapper import ORTWrapper
+        return ORTWrapper(
+            onnx_file=backend_files[0],
+            device=device,
+            output_names=output_names)
 ```
 
 5. Add docstring and unit tests for new code :).
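For context, a hedged sketch of how the registered manager's `build_wrapper` might be exercised end to end; the model path, input name, and tensor shape are illustrative:

```Python
import torch

from mmdeploy.backend.base import get_backend_manager

# Build the wrapper through the registry rather than branching on
# Backend.ONNXRUNTIME by hand, then run one dummy inference.
manager = get_backend_manager('onnxruntime')
wrapper = manager.build_wrapper(
    backend_files=['work_dir/end2end.onnx'],
    device='cpu',
    output_names=['output'])
outputs = wrapper({'input': torch.rand(1, 3, 224, 224)})
```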

docs/zh_cn/07-developer-guide/support_new_backend.md

Lines changed: 26 additions & 34 deletions
@@ -123,32 +123,20 @@ The backends in MMDeploy must support ONNX, so the backend can directly load the ".onnx"
 __all__ += ['onnx2ncnn', 'get_output_model_file']
 ```
 
-Then, as needed, use these APIs to add the conversion code to `tools/deploy.py`.
+Derive a class from `BaseBackendManager` and implement its `to_backend` class method.
 
 **Example**
 
 ```Python
-# tools/deploy.py
-# ...
-elif backend == Backend.NCNN:
-    from mmdeploy.apis.ncnn import is_available as is_available_ncnn
-
-    if not is_available_ncnn():
-        logging.error('ncnn support is not available.')
-        exit(-1)
-
-    from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file
-
-    backend_files = []
-    for onnx_path in onnx_files:
-        create_process(
-            f'mmdeploy_onnx2ncnn with {onnx_path}',
-            target=onnx2ncnn,
-            args=(onnx_path, args.work_dir),
-            kwargs=dict(),
-            ret_value=ret_value)
-        backend_files += get_output_model_file(onnx_path, args.work_dir)
-# ...
+@classmethod
+def to_backend(cls,
+               ir_files: Sequence[str],
+               deploy_cfg: Any,
+               work_dir: str,
+               log_level: int = logging.INFO,
+               device: str = 'cpu',
+               **kwargs) -> Sequence[str]:
+    return ir_files
 ```
 
 6. Convert the OpenMMLab models to the backend (if necessary) and run inference on the backend engine. If you find some incompatible operators during testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md), or add custom operators.
@@ -210,22 +198,26 @@ The backends in MMDeploy must support ONNX, so the backend can directly load the ".onnx"
         self.sess.run_with_iobinding(io_binding)
 ```
 
-4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py`
+4. Derive a class from `BaseBackendManager` and implement its `build_wrapper` class method.
 
 **Example**
 
 ```Python
-@staticmethod
-def _build_wrapper(backend: Backend,
-                   backend_files: Sequence[str],
-                   device: str,
-                   output_names: Optional[Sequence[str]] = None):
-    if backend == Backend.ONNXRUNTIME:
-        from mmdeploy.backend.onnxruntime import ORTWrapper
-        return ORTWrapper(
-            onnx_file=backend_files[0],
-            device=device,
-            output_names=output_names)
+@BACKEND_MANAGERS.register('onnxruntime')
+class ONNXRuntimeManager(BaseBackendManager):
+    @classmethod
+    def build_wrapper(cls,
+                      backend_files: Sequence[str],
+                      device: str = 'cpu',
+                      input_names: Optional[Sequence[str]] = None,
+                      output_names: Optional[Sequence[str]] = None,
+                      deploy_cfg: Optional[Any] = None,
+                      **kwargs):
+        from .wrapper import ORTWrapper
+        return ORTWrapper(
+            onnx_file=backend_files[0],
+            device=device,
+            output_names=output_names)
 ```
 
 5. Add docstrings and unit tests for the new backend engine code :).

mmdeploy/apis/__init__.py

Lines changed: 12 additions & 17 deletions
@@ -1,19 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .calibration import create_calib_input_data
+from .extract_model import extract_model
+from .inference import inference_model
+from .pytorch2onnx import torch2onnx
+from .pytorch2torchscript import torch2torchscript
+from .utils import build_task_processor, get_predefined_partition_cfg
+from .visualize import visualize_model
 
-# mmcv & mmengine dependency
-try:
-    from .calibration import create_calib_input_data
-    from .extract_model import extract_model
-    from .inference import inference_model
-    from .pytorch2onnx import torch2onnx
-    from .pytorch2torchscript import torch2torchscript
-    from .utils import build_task_processor, get_predefined_partition_cfg
-    from .visualize import visualize_model
-
-    __all__ = [
-        'create_calib_input_data', 'extract_model', 'inference_model',
-        'torch2onnx', 'torch2torchscript', 'build_task_processor',
-        'get_predefined_partition_cfg', 'visualize_model'
-    ]
-except Exception:
-    pass
+__all__ = [
+    'create_calib_input_data', 'extract_model', 'inference_model',
+    'torch2onnx', 'torch2torchscript', 'build_task_processor',
+    'get_predefined_partition_cfg', 'visualize_model'
+]
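Dropping the `try: ... except Exception: pass` guard means a broken environment now fails loudly at import time instead of silently leaving `mmdeploy.apis` empty. The trade-off works because, as the file diffs below show, heavy `torch`/`mmcv` imports move from module scope into the function bodies that need them. What the new surface looks like to a caller:

```Python
# Import errors now surface here, with a real traceback, rather than
# as a confusing AttributeError later on:
from mmdeploy.apis import inference_model, torch2onnx
```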

mmdeploy/apis/calibration.py

Lines changed: 5 additions & 4 deletions
@@ -4,11 +4,7 @@
 
 from mmengine import Config
 
-from mmdeploy.core import patch_model
-from mmdeploy.utils import (IR, cfg_apply_marks, get_backend, get_ir_config,
-                            load_config)
 from .core import PIPELINE_MANAGER, no_mp
-from .utils import create_calib_input_data as create_calib_input_data_impl
 
 
 @PIPELINE_MANAGER.register_pipeline()
@@ -34,6 +30,11 @@ def create_calib_input_data(calib_file: str,
         dataset_type (str, optional): The dataset type. Defaults to 'val'.
         device (str, optional): Device to create dataset. Defaults to 'cpu'.
     """
+
+    from mmdeploy.core import patch_model
+    from mmdeploy.utils import (IR, cfg_apply_marks, get_backend,
+                                get_ir_config, load_config)
+    from .utils import create_calib_input_data as create_calib_input_data_impl
     with no_mp():
         if dataset_cfg is None:
             dataset_cfg = model_cfg

mmdeploy/apis/extract_model.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,6 @@
 import onnx
 
 from .core import PIPELINE_MANAGER
-from .onnx import extract_partition
 
 
 @PIPELINE_MANAGER.register_pipeline()
@@ -63,5 +62,6 @@ def extract_model(model: Union[str, onnx.ModelProto],
         onnx.ModelProto: The extracted model.
     """
 
+    from .onnx import extract_partition
     return extract_partition(model, start_marker, end_marker, start_name_map,
                              end_name_map, dynamic_axes, save_file)

mmdeploy/apis/inference.py

Lines changed: 3 additions & 3 deletions
@@ -3,9 +3,6 @@
 
 import mmengine
 import numpy as np
-import torch
-
-from mmdeploy.utils import get_input_shape, load_config
 
 
 def inference_model(model_cfg: Union[str, mmengine.Config],
@@ -37,6 +34,9 @@ def inference_model(model_cfg: Union[str, mmengine.Config],
     Returns:
         Any: The inference results
     """
+    import torch
+
+    from mmdeploy.utils import get_input_shape, load_config
     deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
 
     from mmdeploy.apis.utils import build_task_processor
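The same deferred-import idiom repeats across `calibration.py`, `extract_model.py`, and this file: the module itself imports cheaply, and heavy dependencies load on first call. A self-contained illustration of the pattern (placeholder function and dependency, not mmdeploy code):

```Python
def run_once(data):
    # Deferred import: numpy is required only when run_once is called,
    # so importing this module never fails on a numpy-less machine.
    import numpy as np
    return np.asarray(data).sum()

print(run_once([1, 2, 3]))  # prints 6
```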

mmdeploy/apis/ncnn/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdeploy.backend.ncnn import from_onnx as _from_onnx
-from mmdeploy.backend.ncnn import is_available, is_custom_ops_available
+from mmdeploy.backend.ncnn import is_available
 from ..core import PIPELINE_MANAGER
 
 from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)
 
-__all__ = ['is_available', 'is_custom_ops_available', 'from_onnx']
+__all__ = ['is_available', 'from_onnx']
 
 if is_available():
     try:
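With `is_custom_ops_available` gone from the public surface, callers use the two remaining exports. A minimal sketch, assuming a built ncnn backend and assuming `from_onnx` takes an ONNX model path plus an output file prefix (its signature is not shown in this hunk):

```Python
from mmdeploy.apis.ncnn import from_onnx, is_available

# Guard conversion on the consolidated availability check.
if is_available(with_custom_ops=True):
    # Assumed signature: produces end2end.param / end2end.bin.
    from_onnx('work_dir/end2end.onnx', 'work_dir/end2end')
```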
