Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docker/Base/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,13 @@ RUN wget -c $TENSORRT_URL && \
ENV TENSORRT_DIR=/root/workspace/TensorRT
ENV LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH
ENV PATH=$TENSORRT_DIR/bin:$PATH

# openvino
RUN wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_ubuntu20_2022.3.0.9052.9752fafe8eb_x86_64.tgz &&\
tar -zxvf ./l_openvino_toolkit*.tgz &&\
rm ./l_openvino_toolkit*.tgz &&\
mv ./l_openvino* ./openvino_toolkit &&\
bash ./openvino_toolkit/install_dependencies/install_openvino_dependencies.sh

ENV OPENVINO_DIR=/root/workspace/openvino_toolkit
ENV InferenceEngine_DIR=$OPENVINO_DIR/runtime/cmake
25 changes: 24 additions & 1 deletion mmdeploy/apis/onnx/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable

import torch

from mmdeploy.core import FUNCTION_REWRITER


def update_squeeze_unsqueeze_opset13_pass(graph, params_dict, torch_out):
"""Update Squeeze/Unsqueeze axes for opset13."""
for node in graph.nodes():
if node.kind() in ['onnx::Squeeze', 'onnx::Unsqueeze'] and \
node.hasAttribute('axes'):
axes = node['axes']
axes_node = graph.create('onnx::Constant')
axes_node.t_('value', torch.LongTensor(axes))
node.removeAttribute('axes')
node.addInput(axes_node.output())
axes_node.insertBefore(node)
return graph, params_dict, torch_out


@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(*args, **kwargs):
"""Rewriter of _model_to_graph, add custom passes."""
ctx = FUNCTION_REWRITER.get_context()
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)

if hasattr(ctx, 'opset'):
opset_version = ctx.opset
else:
from mmdeploy.utils import get_ir_config
opset_version = get_ir_config(ctx.cfg).get('opset_version', 11)
if opset_version >= 13:
graph, params_dict, torch_out = update_squeeze_unsqueeze_opset13_pass(
graph, params_dict, torch_out)
custom_passes = getattr(ctx, 'onnx_custom_passes', None)

if custom_passes is not None:
Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/apis/onnx/passes/optimize_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def optimize_onnx(ctx, graph, params_dict, torch_out):
logger.warning(
'Can not optimize model, please build torchscipt extension.\n'
'More details: '
'https://github.com/open-mmlab/mmdeploy/tree/1.x/docs/en/experimental/onnx_optimizer.md' # noqa
'https://github.com/open-mmlab/mmdeploy/tree/main/docs/en/experimental/onnx_optimizer.md' # noqa
)
return graph, params_dict, torch_out
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
stride_w: int):
"""Map ops to onnx symbolics."""
# zero_h and zero_w is used to provide shape to GridPriorsTRT
feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
zero_h = g.op(
'ConstantOfShape',
feat_h,
Expand Down
8 changes: 5 additions & 3 deletions mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def build_backend_model(self,

def create_input(
self,
pcd: str,
pcd: Union[str, Sequence[str]],
input_shape: Sequence[int] = None,
data_preprocessor: Optional[BaseDataPreprocessor] = None
) -> Tuple[Dict, torch.Tensor]:
"""Create input for detector.

Args:
pcd (str): Input pcd file path.
pcd (str, Sequence[str]): Input pcd file path.
input_shape (Sequence[int], optional): model input shape.
Defaults to None.
data_preprocessor (Optional[BaseDataPreprocessor], optional):
Expand All @@ -115,7 +115,9 @@ def create_input(
test_pipeline = Compose(test_pipeline)
box_type_3d, box_mode_3d = \
get_box_type(cfg.test_dataloader.dataset.box_type_3d)

# do not support batch inference
if isinstance(pcd, (list, tuple)):
pcd = pcd[0]
data = []
data_ = dict(
lidar_points=dict(lidar_path=pcd),
Expand Down
48 changes: 24 additions & 24 deletions tests/regression/mmseg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ torchscript:

models:
- name: FCN
metafile: configs/fcn/fcn.yml
metafile: configs/fcn/metafile.yaml
model_configs:
- configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -134,7 +134,7 @@ models:
- *pipeline_openvino_dynamic_fp32

- name: PSPNet
metafile: configs/pspnet/pspnet.yml
metafile: configs/pspnet/metafile.yaml
model_configs:
- configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -146,7 +146,7 @@ models:
- *pipeline_openvino_static_fp32

- name: deeplabv3
metafile: configs/deeplabv3/deeplabv3.yml
metafile: configs/deeplabv3/metafile.yaml
model_configs:
- configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -158,7 +158,7 @@ models:
- *pipeline_openvino_dynamic_fp32

- name: deeplabv3+
metafile: configs/deeplabv3plus/deeplabv3plus.yml
metafile: configs/deeplabv3plus/metafile.yaml
model_configs:
- configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -170,7 +170,7 @@ models:
- *pipeline_openvino_dynamic_fp32

- name: Fast-SCNN
metafile: configs/fastscnn/fastscnn.yml
metafile: configs/fastscnn/metafile.yaml
model_configs:
- configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
pipelines:
Expand All @@ -181,7 +181,7 @@ models:
- *pipeline_openvino_static_fp32

- name: UNet
metafile: configs/unet/unet.yml
metafile: configs/unet/metafile.yaml
model_configs:
- configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines:
Expand All @@ -192,7 +192,7 @@ models:
- *pipeline_pplnn_dynamic_fp32

- name: ANN
metafile: configs/ann/ann.yml
metafile: configs/ann/metafile.yaml
model_configs:
- configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -201,7 +201,7 @@ models:
- *pipeline_ts_fp32

- name: APCNet
metafile: configs/apcnet/apcnet.yml
metafile: configs/apcnet/metafile.yaml
model_configs:
- configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -211,7 +211,7 @@ models:
- *pipeline_ts_fp32

- name: BiSeNetV1
metafile: configs/bisenetv1/bisenetv1.yml
metafile: configs/bisenetv1/metafile.yaml
model_configs:
- configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
pipelines:
Expand All @@ -222,7 +222,7 @@ models:
- *pipeline_ts_fp32

- name: BiSeNetV2
metafile: configs/bisenetv2/bisenetv2.yml
metafile: configs/bisenetv2/metafile.yaml
model_configs:
- configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
pipelines:
Expand All @@ -233,7 +233,7 @@ models:
- *pipeline_ts_fp32

- name: CGNet
metafile: configs/cgnet/cgnet.yml
metafile: configs/cgnet/metafile.yaml
model_configs:
- configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
pipelines:
Expand All @@ -244,7 +244,7 @@ models:
- *pipeline_ts_fp32

- name: EMANet
metafile: configs/emanet/emanet.yml
metafile: configs/emanet/metafile.yaml
model_configs:
- configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
pipelines:
Expand All @@ -254,7 +254,7 @@ models:
- *pipeline_ts_fp32

- name: EncNet
metafile: configs/encnet/encnet.yml
metafile: configs/encnet/metafile.yaml
model_configs:
- configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -264,7 +264,7 @@ models:
- *pipeline_ts_fp32

- name: ERFNet
metafile: configs/erfnet/erfnet.yml
metafile: configs/erfnet/metafile.yaml
model_configs:
- configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines:
Expand All @@ -275,7 +275,7 @@ models:
- *pipeline_ts_fp32

- name: FastFCN
metafile: configs/fastfcn/fastfcn.yml
metafile: configs/fastfcn/metafile.yaml
model_configs:
- configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
pipelines:
Expand All @@ -286,7 +286,7 @@ models:
- *pipeline_ts_fp32

- name: GCNet
metafile: configs/gcnet/gcnet.yml
metafile: configs/gcnet/metafile.yaml
model_configs:
- configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -295,7 +295,7 @@ models:
- *pipeline_ts_fp32

- name: ICNet
metafile: configs/icnet/icnet.yml
metafile: configs/icnet/metafile.yaml
model_configs:
- configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
pipelines:
Expand All @@ -305,7 +305,7 @@ models:
- *pipeline_ts_fp32

- name: ISANet
metafile: configs/isanet/isanet.yml
metafile: configs/isanet/metafile.yaml
model_configs:
- configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -314,7 +314,7 @@ models:
- *pipeline_openvino_static_fp32_512x512

- name: OCRNet
metafile: configs/ocrnet/ocrnet.yml
metafile: configs/ocrnet/metafile.yaml
model_configs:
- configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
pipelines:
Expand All @@ -325,7 +325,7 @@ models:
- *pipeline_ts_fp32

- name: PointRend
metafile: configs/point_rend/point_rend.yml
metafile: configs/point_rend/metafile.yaml
model_configs:
- configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
pipelines:
Expand All @@ -334,7 +334,7 @@ models:
- *pipeline_ts_fp32

- name: Semantic FPN
metafile: configs/sem_fpn/sem_fpn.yml
metafile: configs/sem_fpn/metafile.yaml
model_configs:
- configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
pipelines:
Expand All @@ -345,7 +345,7 @@ models:
- *pipeline_ts_fp32

- name: STDC
metafile: configs/stdc/stdc.yml
metafile: configs/stdc/metafile.yaml
model_configs:
- configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
- configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
Expand All @@ -357,14 +357,14 @@ models:
- *pipeline_ts_fp32

- name: UPerNet
metafile: configs/upernet/upernet.yml
metafile: configs/upernet/metafile.yaml
model_configs:
- configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
pipelines:
- *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16
- name: Segmenter
metafile: configs/segmenter/segmenter.yml
metafile: configs/segmenter/metafile.yaml
model_configs:
- configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
pipelines:
Expand Down
3 changes: 3 additions & 0 deletions tools/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
# get metric
model_info = meta_info[model_config_name]
metafile_metric_info = model_info['Results']
# deal with mmseg case
if not isinstance(metafile_metric_info, (list, tuple)):
metafile_metric_info = [metafile_metric_info]
pytorch_metric = dict()
using_dataset = set()
using_task = set()
Expand Down