2 changes: 1 addition & 1 deletion .github/workflows/backend-ncnn.yml
@@ -79,4 +79,4 @@ jobs:
python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
python -m pip install mmcv-lite
python tools/scripts/build_ubuntu_x64_ncnn.py
python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available() and ncnn_api.is_custom_ops_available()'
python -c 'import mmdeploy.apis.ncnn as ncnn_api; assert ncnn_api.is_available(with_custom_ops=True)'
2 changes: 1 addition & 1 deletion .github/workflows/backend-ort.yml
@@ -36,7 +36,7 @@ jobs:
python -m pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
python -m pip install mmcv-lite openmim
python tools/scripts/build_ubuntu_x64_ort.py
python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available() and ort_api.is_custom_ops_available()'
python -c 'import mmdeploy.apis.onnxruntime as ort_api; assert ort_api.is_available(with_custom_ops=True)'
- name: test mmcls full pipeline
run: |
python -m mim install $(cat requirements/codebases.txt | grep mmcls)
3 changes: 3 additions & 0 deletions csrc/mmdeploy/backend_ops/torchscript/ops/CMakeLists.txt
@@ -31,3 +31,6 @@ mmdeploy_export(${PROJECT_NAME}_obj)
mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
add_library(mmdeploy::torchscript_ops ALIAS ${PROJECT_NAME})

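# Install the TorchScript custom-ops library into the Python package tree (mmdeploy/lib).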
set(_TORCHJIT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_TORCHJIT_OPS_DIR})
61 changes: 26 additions & 35 deletions docs/en/07-developer-guide/support_new_backend.md
@@ -123,32 +123,20 @@ The backends in MMDeploy must support ONNX. The backend loads the ".onnx" file…
__all__ += ['onnx2ncnn', 'get_output_model_file']
```

Then add the codes about conversion to `tools/deploy.py` using these APIs if necessary.
Create a backend manager class that derives from `BaseBackendManager` and implement its `to_backend` class method.

**Example:**

```Python
# tools/deploy.py
# ...
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import is_available as is_available_ncnn

if not is_available_ncnn():
logging.error('ncnn support is not available.')
exit(-1)

from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file

backend_files = []
for onnx_path in onnx_files:
create_process(
f'onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
deploy_cfg: Any,
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
return ir_files
```
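
For context, a minimal sketch of what a complete manager for a converting backend such as ncnn might look like, reusing the `onnx2ncnn` and `get_output_model_file` APIs exported above. The `BACKEND_MANAGERS` registry and its import path are assumptions borrowed from the onnxruntime example later in this guide:

```Python
# Hypothetical ncnn manager -- the registry import path is an assumption.
import logging
from typing import Any, Sequence

from mmdeploy.backend.base import BACKEND_MANAGERS, BaseBackendManager


@BACKEND_MANAGERS.register('ncnn')
class NCNNManager(BaseBackendManager):

    @classmethod
    def to_backend(cls,
                   ir_files: Sequence[str],
                   deploy_cfg: Any,
                   work_dir: str,
                   log_level: int = logging.INFO,
                   device: str = 'cpu',
                   **kwargs) -> Sequence[str]:
        from mmdeploy.apis.ncnn import get_output_model_file, onnx2ncnn
        backend_files = []
        for onnx_path in ir_files:
            # onnx2ncnn writes the ncnn .param/.bin pair into work_dir
            onnx2ncnn(onnx_path, work_dir)
            backend_files += get_output_model_file(onnx_path, work_dir)
        return backend_files
```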

6. Convert OpenMMLab models to the backend (if necessary) and run inference on the backend engine. If you find incompatible operators during testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
@@ -209,23 +197,26 @@ Although the backend engines are usually implemented in C/C++, it is convenient…
self.sess.run_with_iobinding(io_binding)
```

4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py`
4. Create a backend manager class that derives from `BaseBackendManager` and implement its `build_wrapper` class method.

**Example:**

```Python
@staticmethod
def _build_wrapper(backend: Backend,
backend_files: Sequence[str],
device: str,
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None):
if backend == Backend.ONNXRUNTIME:
from mmdeploy.backend.onnxruntime import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
@BACKEND_MANAGERS.register('onnxruntime')
class ONNXRuntimeManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
from .wrapper import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
```
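
For reference, a hypothetical call site for such a manager; the model path, input name, and output name below are made up for illustration:

```Python
import numpy as np
import torch

# Build a wrapper around an exported ONNX model and run one inference.
wrapper = ONNXRuntimeManager.build_wrapper(
    backend_files=['work_dir/end2end.onnx'],  # assumed export location
    device='cpu',
    output_names=['output'])
dummy = torch.from_numpy(np.zeros((1, 3, 224, 224), dtype=np.float32))
outputs = wrapper({'input': dummy})  # returns a dict keyed by output name
```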

5. Add docstrings and unit tests for the new code :).
60 changes: 26 additions & 34 deletions docs/zh_cn/07-developer-guide/support_new_backend.md
@@ -123,32 +123,20 @@ Backends in MMDeploy must support ONNX, so that a backend can directly load ".onnx" files…
__all__ += ['onnx2ncnn', 'get_output_model_file']
```

Then, if necessary, add the conversion code to `tools/deploy.py` using these APIs.
Derive a class from `BaseBackendManager` and implement its `to_backend` class method.

**Example:**

```Python
# tools/deploy.py
# ...
elif backend == Backend.NCNN:
from mmdeploy.apis.ncnn import is_available as is_available_ncnn

if not is_available_ncnn():
logging.error('ncnn support is not available.')
exit(-1)

from mmdeploy.apis.ncnn import onnx2ncnn, get_output_model_file

backend_files = []
for onnx_path in onnx_files:
create_process(
f'mmdeploy_onnx2ncnn with {onnx_path}',
target=onnx2ncnn,
args=(onnx_path, args.work_dir),
kwargs=dict(),
ret_value=ret_value)
backend_files += get_output_model_file(onnx_path, args.work_dir)
# ...
@classmethod
def to_backend(cls,
ir_files: Sequence[str],
deploy_cfg: Any,
work_dir: str,
log_level: int = logging.INFO,
device: str = 'cpu',
**kwargs) -> Sequence[str]:
return ir_files
```

6. Convert the OpenMMLab models (if necessary) and run inference on the backend engine. If you find incompatible operators during testing, you can try to rewrite the original model for the backend following the [rewriter tutorial](support_new_model.md) or add custom operators.
@@ -210,22 +198,26 @@ Backends in MMDeploy must support ONNX, so that a backend can directly load ".onnx" files…
self.sess.run_with_iobinding(io_binding)
```

4. Add a default initialization method for the new wrapper in `mmdeploy/codebase/base/backend_model.py`
4. Derive a class from `BaseBackendManager` and implement its `build_wrapper` class method.

**Example:**

```Python
@staticmethod
def _build_wrapper(backend: Backend,
backend_files: Sequence[str],
device: str,
output_names: Optional[Sequence[str]] = None):
if backend == Backend.ONNXRUNTIME:
from mmdeploy.backend.onnxruntime import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
@BACKEND_MANAGERS.register('onnxruntime')
class ONNXRuntimeManager(BaseBackendManager):
@classmethod
def build_wrapper(cls,
backend_files: Sequence[str],
device: str = 'cpu',
input_names: Optional[Sequence[str]] = None,
output_names: Optional[Sequence[str]] = None,
deploy_cfg: Optional[Any] = None,
**kwargs):
from .wrapper import ORTWrapper
return ORTWrapper(
onnx_file=backend_files[0],
device=device,
output_names=output_names)
```

5. Add docstrings and unit tests for the new backend code :).
29 changes: 12 additions & 17 deletions mmdeploy/apis/__init__.py
@@ -1,19 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data
from .extract_model import extract_model
from .inference import inference_model
from .pytorch2onnx import torch2onnx
from .pytorch2torchscript import torch2torchscript
from .utils import build_task_processor, get_predefined_partition_cfg
from .visualize import visualize_model

# mmcv & mmengine dependency
try:
from .calibration import create_calib_input_data
from .extract_model import extract_model
from .inference import inference_model
from .pytorch2onnx import torch2onnx
from .pytorch2torchscript import torch2torchscript
from .utils import build_task_processor, get_predefined_partition_cfg
from .visualize import visualize_model

__all__ = [
'create_calib_input_data', 'extract_model', 'inference_model',
'torch2onnx', 'torch2torchscript', 'build_task_processor',
'get_predefined_partition_cfg', 'visualize_model'
]
except Exception:
pass
__all__ = [
'create_calib_input_data', 'extract_model', 'inference_model',
'torch2onnx', 'torch2torchscript', 'build_task_processor',
'get_predefined_partition_cfg', 'visualize_model'
]
9 changes: 5 additions & 4 deletions mmdeploy/apis/calibration.py
@@ -4,11 +4,7 @@

from mmengine import Config

from mmdeploy.core import patch_model
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend, get_ir_config,
load_config)
from .core import PIPELINE_MANAGER, no_mp
from .utils import create_calib_input_data as create_calib_input_data_impl


@PIPELINE_MANAGER.register_pipeline()
@@ -34,6 +30,11 @@ def create_calib_input_data(calib_file: str,
dataset_type (str, optional): The dataset type. Defaults to 'val'.
device (str, optional): Device to create dataset. Defaults to 'cpu'.
"""

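# Imports are deferred to call time so that importing mmdeploy.apis does not require these dependencies.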
from mmdeploy.core import patch_model
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend,
get_ir_config, load_config)
from .utils import create_calib_input_data as create_calib_input_data_impl
with no_mp():
if dataset_cfg is None:
dataset_cfg = model_cfg
2 changes: 1 addition & 1 deletion mmdeploy/apis/extract_model.py
@@ -5,7 +5,6 @@
import onnx

from .core import PIPELINE_MANAGER
from .onnx import extract_partition


@PIPELINE_MANAGER.register_pipeline()
@@ -63,5 +62,6 @@ def extract_model(model: Union[str, onnx.ModelProto],
onnx.ModelProto: The extracted model.
"""

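# Deferred import of the partition helper; only needed when extraction actually runs.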
from .onnx import extract_partition
return extract_partition(model, start_marker, end_marker, start_name_map,
end_name_map, dynamic_axes, save_file)
6 changes: 3 additions & 3 deletions mmdeploy/apis/inference.py
@@ -3,9 +3,6 @@

import mmengine
import numpy as np
import torch

from mmdeploy.utils import get_input_shape, load_config


def inference_model(model_cfg: Union[str, mmengine.Config],
@@ -37,6 +34,9 @@ def inference_model(model_cfg: Union[str, mmengine.Config],
Returns:
Any: The inference results
"""
import torch

from mmdeploy.utils import get_input_shape, load_config
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

from mmdeploy.apis.utils import build_task_processor
4 changes: 2 additions & 2 deletions mmdeploy/apis/ncnn/__init__.py
@@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.ncnn import from_onnx as _from_onnx
from mmdeploy.backend.ncnn import is_available, is_custom_ops_available
from mmdeploy.backend.ncnn import is_available
from ..core import PIPELINE_MANAGER

from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx)

__all__ = ['is_available', 'is_custom_ops_available', 'from_onnx']
__all__ = ['is_available', 'from_onnx']

if is_available():
try:
4 changes: 2 additions & 2 deletions mmdeploy/apis/onnxruntime/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.onnxruntime import is_available, is_custom_ops_available
from mmdeploy.backend.onnxruntime import is_available

__all__ = ['is_available', 'is_custom_ops_available']
__all__ = ['is_available']
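
With `is_custom_ops_available` removed, the custom-ops check goes through the remaining entry point, as the workflow changes above do — a minimal sketch assuming the new keyword argument:

```Python
from mmdeploy.apis.onnxruntime import is_available

# Passes only if the backend and its custom-ops library are both usable.
assert is_available(with_custom_ops=True)
```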
10 changes: 6 additions & 4 deletions mmdeploy/apis/pytorch2onnx.py
@@ -4,11 +4,7 @@

import mmengine

from mmdeploy.apis.core.pipeline_manager import no_mp
from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes,
get_input_shape, get_onnx_config, load_config)
from .core import PIPELINE_MANAGER
from .onnx import export


@PIPELINE_MANAGER.register_pipeline()
@@ -48,6 +44,12 @@ def torch2onnx(img: Any,
defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
"""

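# Deferred imports keep module import lightweight.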
from mmdeploy.apis.core.pipeline_manager import no_mp
from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes,
get_input_shape, get_onnx_config, load_config)
from .onnx import export

# load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmengine.mkdir_or_exist(osp.abspath(work_dir))
8 changes: 5 additions & 3 deletions mmdeploy/apis/pytorch2torchscript.py
@@ -3,11 +3,8 @@
from typing import Any, Optional, Union

import mmengine
import torch

from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp
from mmdeploy.utils import get_backend, get_input_shape, load_config
from .torch_jit import trace


@PIPELINE_MANAGER.register_pipeline()
@@ -32,6 +29,11 @@ def torch2torchscript(img: Any,
defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
"""
import torch

from mmdeploy.utils import get_backend, get_input_shape, load_config
from .torch_jit import trace

# load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmengine.mkdir_or_exist(osp.abspath(work_dir))
4 changes: 2 additions & 2 deletions mmdeploy/apis/tensorrt/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.tensorrt import is_available, is_custom_ops_available
from mmdeploy.backend.tensorrt import is_available
from ..core import PIPELINE_MANAGER

__all__ = ['is_available', 'is_custom_ops_available']
__all__ = ['is_available']

if is_available():
from mmdeploy.backend.tensorrt import from_onnx as _from_onnx
5 changes: 3 additions & 2 deletions mmdeploy/apis/utils/__init__.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibration import create_calib_input_data
from .utils import build_task_processor, get_predefined_partition_cfg
from .utils import (build_task_processor, get_predefined_partition_cfg,
to_backend)

__all__ = [
'create_calib_input_data', 'build_task_processor',
'get_predefined_partition_cfg'
'get_predefined_partition_cfg', 'to_backend'
]
6 changes: 3 additions & 3 deletions mmdeploy/apis/utils/calibration.py
@@ -2,12 +2,9 @@
from copy import deepcopy
from typing import Callable, Dict, Optional

import h5py
import torch
import tqdm
from torch.utils.data import DataLoader

from mmdeploy.core import RewriterContext, reset_mark_function_count
from ..core import PIPELINE_MANAGER


@@ -46,7 +43,10 @@ def create_calib_input_data(calib_file: str,
'val', defaults to 'val'.
device (str): Specifying the device to run on, defaults to 'cpu'.
"""
import h5py
import tqdm

from mmdeploy.core import RewriterContext, reset_mark_function_count
backend = 'default'

with h5py.File(calib_file, mode='w') as file: