Skip to content

Conversation

RangiLyu
Copy link
Member

@RangiLyu RangiLyu commented Jan 16, 2023

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily receiving feedbacks. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

RTMDet-Ins ONNX and TensorRT support.

Modification

  1. Add rewriter for RTMDetInsHead
  2. Fix the rescale of the object detection model class(need to check other models)

Inference Results(TRT)

Please use configs/mmdet/instance-seg/instance-seg_rtmdet-ins_tensorrt_static-640x640.py

image

mask AP is aligned with pytorch RTMDet-tiny

Evaluate annotation type *segm*
DONE (t=38.74s).
Accumulating evaluation results...
DONE (t=7.01s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.354
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.550
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.376
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.131
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.383
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.567
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.302
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.475
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.502
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.252
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.569
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.718
01/17 12:42:14 - mmengine - INFO - segm_mAP_copypaste: 0.354 0.550 0.376 0.131 0.383 0.567

Notice

Currently only support single batch inference.

@CLAassistant
Copy link

CLAassistant commented Jan 16, 2023

CLA assistant check
All committers have signed the CLA.

@RangiLyu RangiLyu changed the title [WIP][Feature] Support RTMDet-Ins. [Feature] Support RTMDet-Ins. Jan 17, 2023
@1095788063
Copy link

转换后推理无结果,还有转换时的推理图片报错和位置也不对。

@1095788063
Copy link

image

@1095788063
Copy link

image

@1095788063
Copy link

模型已经加载成功,但是推理每个结果返回
image

@1095788063
Copy link

image
没有结果返回,但是却有耗时

@1095788063
Copy link

D:\Python\Python38\python.exe D:\ayjdata\Code\Deep_learning\OpenMMLab\mmdeploy\mmdeploy-rtmdet_ins_2023-01-18\tools\check_env.py
01/18 15:30:06 - mmengine - INFO -

01/18 15:30:06 - mmengine - INFO - Environmental information
01/18 15:30:09 - mmengine - INFO - sys.platform: win32
01/18 15:30:09 - mmengine - INFO - Python: 3.8.10 (tags/v3.8.10:3d8993a, May 3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)]
01/18 15:30:09 - mmengine - INFO - CUDA available: True
01/18 15:30:09 - mmengine - INFO - numpy_random_seed: 2147483648
01/18 15:30:09 - mmengine - INFO - GPU 0: NVIDIA GeForce RTX 3070
01/18 15:30:09 - mmengine - INFO - CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3
01/18 15:30:09 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.3, V11.3.58
01/18 15:30:09 - mmengine - INFO - MSVC: 用于 x64 的 Microsoft (R) C/C++ 优化编译器 19.29.30147 版
01/18 15:30:09 - mmengine - INFO - GCC: n/a
01/18 15:30:09 - mmengine - INFO - PyTorch: 1.11.0+cu113
01/18 15:30:09 - mmengine - INFO - PyTorch compiling details: PyTorch built with:

  • C++ Version: 199711
  • MSVC 192829337
  • Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v2.5.2 (Git Hash a9302535553c73243c632ad3c4c80beec3d19a1e)
  • OpenMP 2019
  • LAPACK is enabled (usually provided by MKL)
  • CPU capability usage: AVX2
  • CUDA Runtime 11.3
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  • CuDNN 8.2
  • Magma 2.5.4
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=C:/actions-runner/_work/pytorch/pytorch/builder/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -IC:/actions-runner/_work/pytorch/pytorch/builder/windows/mkl/include -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF,

01/18 15:30:09 - mmengine - INFO - TorchVision: 0.12.0+cu113
01/18 15:30:09 - mmengine - INFO - OpenCV: 4.5.5
01/18 15:30:09 - mmengine - INFO - MMEngine: 0.4.0
01/18 15:30:09 - mmengine - INFO - MMCV: 2.0.0rc3
01/18 15:30:09 - mmengine - INFO - MMCV Compiler: MSVC 192829924
01/18 15:30:09 - mmengine - INFO - MMCV CUDA Compiler: 11.3
01/18 15:30:09 - mmengine - INFO - MMDeploy: 1.0.0rc1+
01/18 15:30:09 - mmengine - INFO -

01/18 15:30:09 - mmengine - INFO - Backend information
01/18 15:30:09 - mmengine - INFO - tensorrt: 8.2.5.1
01/18 15:30:09 - mmengine - INFO - tensorrt custom ops: Available
01/18 15:30:10 - mmengine - INFO - ONNXRuntime: 1.8.1
01/18 15:30:10 - mmengine - INFO - ONNXRuntime-gpu: None
01/18 15:30:10 - mmengine - INFO - ONNXRuntime custom ops: NotAvailable
01/18 15:30:10 - mmengine - INFO - pplnn: None
01/18 15:30:10 - mmengine - INFO - ncnn: None
01/18 15:30:10 - mmengine - INFO - snpe: None
01/18 15:30:10 - mmengine - INFO - openvino: None
01/18 15:30:10 - mmengine - INFO - torchscript: 1.11.0+cu113
01/18 15:30:10 - mmengine - INFO - torchscript custom ops: Available
01/18 15:30:10 - mmengine - INFO - rknn-toolkit: None
01/18 15:30:10 - mmengine - INFO - rknn-toolkit2: None
01/18 15:30:10 - mmengine - INFO - ascend: None
01/18 15:30:10 - mmengine - INFO - coreml: None
01/18 15:30:10 - mmengine - INFO - tvm: None
01/18 15:30:10 - mmengine - INFO -

01/18 15:30:10 - mmengine - INFO - Codebase information
01/18 15:30:10 - mmengine - INFO - mmdet: 3.0.0rc5
01/18 15:30:10 - mmengine - INFO - mmseg: None
01/18 15:30:10 - mmengine - INFO - mmcls: None
01/18 15:30:10 - mmengine - INFO - mmocr: None
01/18 15:30:10 - mmengine - INFO - mmedit: None
01/18 15:30:10 - mmengine - INFO - mmdet3d: None
01/18 15:30:10 - mmengine - INFO - mmpose: None
01/18 15:30:10 - mmengine - INFO - mmrotate: None
01/18 15:30:10 - mmengine - INFO - mmaction: None

进程已结束,退出代码0

@mattiasbax
Copy link

I also see weird output converting to onnx using the demo input with provided pretrained checkpoints:

i.e:
python ./tools/deploy.py configs/mmdet/instance-seg/instance-seg_onnxruntime_static.py ../mmdetection/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py ../mmdetection/checkpoints/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth demo/resources/det.jpg --work-dir work_dir/RTMInSeg --device cuda:0

produces:

output_onnxruntime

@hanrui1sensetime
Copy link
Collaborator

Seems get a wrong result.

python ./tools/deploy.py configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py ../../mmdetection/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth demo/resources/det.jpg --work-dir rtmdet-ins-trt --device cuda:0

image
image

@RangiLyu
Copy link
Member Author

RangiLyu commented Feb 2, 2023

@mattiasbax @hanrui1sensetime @1095788063 Try to use configs/mmdet/instance-seg/instance-seg_rtmdet-ins_onnxruntime_static-640x640.py and configs/mmdet/instance-seg/instance-seg_rtmdet-ins_tensorrt_static-640x640.py

@RangiLyu
Copy link
Member Author

RangiLyu commented Feb 2, 2023

@lvhan028 @hanrui1sensetime request review

Copy link
Collaborator

@hanrui1sensetime hanrui1sensetime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May update UT, regression tests, and docs.

@hanrui1sensetime
Copy link
Collaborator

Using configs you have mentioned above, the visualize is still failed:
Error log is that

  File "/opt/conda/lib/python3.8/site-packages/mmengine/visualization/visualizer.py", line 828, in draw_binary_masks
    assert img.shape[:2] == binary_masks.shape[
AssertionError: `binary_marks` must have the same shape with image

My script is:

python tools/deploy.py configs/mmdet/instance-seg/instance-seg_rtmdet-ins_onnxruntime_static-640x640.py ../mmdetection/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth ../mmdetection/demo/demo.jpg --work-dir ./rtmdet-ins-ort --device cuda:0 --dump-info

In this script: img.shape[:2] is (427, 640), while binary_mask.shape[1:] is (427, 427), FYI.

@RangiLyu
Copy link
Member Author

RangiLyu commented Feb 3, 2023

@AllentDan Please review mmdeploy/codebase/mmdet/deploy/object_detection_model.py

@RangiLyu
Copy link
Member Author

RangiLyu commented Feb 3, 2023

May update UT, regression tests, and docs.

Unit tests and docs will be added in another PR later.

@RangiLyu
Copy link
Member Author

RangiLyu commented Feb 3, 2023

I also see weird output converting to onnx using the demo input with provided pretrained checkpoints:

i.e: python ./tools/deploy.py configs/mmdet/instance-seg/instance-seg_onnxruntime_static.py ../mmdetection/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py ../mmdetection/checkpoints/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth demo/resources/det.jpg --work-dir work_dir/RTMInSeg --device cuda:0

produces:

output_onnxruntime

We have fixed a hard code bug in MMDeploy and now the visualized result should be correct.

Copy link
Collaborator

@hanrui1sensetime hanrui1sensetime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Member

@AllentDan AllentDan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RTMDet always creates inputs with static shapes. So there is no need to specify a shape in deploy_cfg. Please set input_shape=None keep pipeline[i].keep_ratio = False.

@RangiLyu RangiLyu requested review from AllentDan and removed request for grimoire February 3, 2023 06:56
Copy link
Member

@AllentDan AllentDan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lvhan028 lvhan028 merged commit 87e4814 into open-mmlab:dev-1.x Feb 6, 2023
lvhan028 pushed a commit that referenced this pull request Mar 1, 2023
* [Feature] Support RTMDet-Ins.

* fix visualize

* fix rewrite trt

* add config

* support torch 1.13

* fix keep ratio resize

* resolve scale factor bug

* set to None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants