From fe7aef1a2d0dc42430f8d1645a3539748f9ad482 Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 11 Dec 2023 12:32:11 +0800 Subject: [PATCH 01/10] support ONNX&TensorRT exportation of RTMO --- ...detection_yolox-pose_onnxruntime_static.py | 13 +++ ...olox-pose_tensorrt-fp16_dynamic-640x640.py | 35 ++++++ .../codebase/mmpose/deploy/pose_detection.py | 36 +++++++ .../mmpose/deploy/pose_detection_model.py | 6 ++ .../codebase/mmpose/models/heads/__init__.py | 5 +- .../codebase/mmpose/models/heads/rtmo_head.py | 100 ++++++++++++++++++ setup.py | 4 +- 7 files changed, 195 insertions(+), 4 deletions(-) create mode 100644 configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py create mode 100644 configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py create mode 100644 mmdeploy/codebase/mmpose/models/heads/rtmo_head.py diff --git a/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py b/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py new file mode 100644 index 0000000000..f95c0ef9b9 --- /dev/null +++ b/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py @@ -0,0 +1,13 @@ +_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py'] + +onnx_config = dict(output_names=['dets', 'keypoints']) + +codebase_config = dict( + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )) diff --git a/configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py b/configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py new file mode 100644 index 0000000000..c46ab9cbad --- /dev/null +++ b/configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py @@ -0,0 +1,35 @@ +_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt-fp16.py'] + +onnx_config = dict( + output_names=['dets', 'keypoints'], + dynamic_axes={ + 'input': { + 0: 'batch', + }, + 'dets': { + 0: 'batch', + }, + 'keypoints': { + 0: 'batch' + } + }) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 640, 640], + opt_shape=[1, 3, 640, 640], + max_shape=[1, 3, 640, 640]))) + ]) + +codebase_config = dict( + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )) diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 5e6b0c5c6f..9621b92f38 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect import os from collections import defaultdict from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -179,6 +180,41 @@ def build_backend_model( **kwargs) return model.eval().to(self.device) + def build_pytorch_model(self, + model_checkpoint: Optional[str] = None, + cfg_options: Optional[Dict] = None, + **kwargs) -> torch.nn.Module: + """Initialize torch model and switch to deploy mode. + + Args: + model_checkpoint (str): The checkpoint file of torch model, + defaults to `None`. + cfg_options (dict): Optional config key-pair parameters. + + Returns: + nn.Module: An initialized torch model generated by other OpenMMLab + codebases. + """ + # Initialize the PyTorch model using parent class method + torch_model = super().build_pytorch_model(model_checkpoint, + cfg_options, **kwargs) + + # Check if called from torch2onnx within 'apis/pytorch2onnx.py' + callers = inspect.stack() + is_torch2onnx_call = ( + len(callers) > 1 and callers[1].function == 'torch2onnx' + and callers[1].filename.endswith( + os.path.join('apis', 'pytorch2onnx.py'))) + + # If model has a 'switch_to_deploy' method and is called from + # torch2onnx, activate this method + if is_torch2onnx_call and hasattr(torch_model, + 'switch_to_deploy') and callable( + torch_model.switch_to_deploy): + torch_model.switch_to_deploy() + + return torch_model + def create_input(self, imgs: Union[str, np.ndarray, Sequence], input_shape: Sequence[int] = None, diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index a2ec9a21ad..fc460b8a42 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -101,6 +101,10 @@ def forward(self, codebase_cfg = get_codebase_config(self.deploy_cfg) codec = self.model_cfg.codec + + if codec.type == 'YOLOXPoseAnnotationProcessor': + return self.pack_yolox_pose_result(batch_outputs, data_samples) + if isinstance(codec, (list, tuple)): codec = codec[-1] @@ -184,6 +188,8 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor], """ assert preds[0].shape[0] == len(data_samples) batched_dets, batched_kpts = preds + # print(f'batched_dets:\n {batched_dets.flatten()[:10]}\n\n') + # print(f'batched_kpts:\n {batched_kpts.flatten()[:10]}\n\n') for data_sample_idx, data_sample in enumerate(data_samples): bboxes = batched_dets[data_sample_idx, :, :4] bbox_scores = batched_dets[data_sample_idx, :, 4] diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py index 10bd18a0d9..d276162aff 100644 --- a/mmdeploy/codebase/mmpose/models/heads/__init__.py +++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403 +from . import simcc_head # noqa: F401,F403 +from . import mspn_head, rtmo_head, yolox_pose_head -__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head'] +__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head'] diff --git a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py new file mode 100644 index 0000000000..0481c57927 --- /dev/null +++ b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from mmpose.structures.bbox import bbox_xyxy2cs +from torch import Tensor + +from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.mmcv.ops.nms import multiclass_nms +from mmdeploy.utils import Backend, get_backend + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmpose.models.heads.hybrid_heads.' + 'rtmo_head.RTMOHead.forward') +def predict(self, + x: Tuple[Tensor], + batch_data_samples: List = [], + test_cfg: Optional[dict] = None): + """Get predictions and transform to bbox and keypoints results. + Args: + x (Tuple[Tensor]): The input tensor from upstream network. + batch_data_samples: Batch image meta info. Defaults to None. + test_cfg: The runtime config for testing process. + + Returns: + Tuple[Tensor]: Predict bbox and keypoint results. + - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor, + has shape (batch_size, num_instances, 5), the last dimension 5 + arrange as (x1, y1, x2, y2, score). + - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D + Tensor, has shape (batch_size, num_instances, num_keypoints, 5), + the last dimension 3 arrange as (x, y, score). + """ + + # deploy context + ctx = FUNCTION_REWRITER.get_context() + backend = get_backend(ctx.cfg) + deploy_cfg = ctx.cfg + + cfg = self.test_cfg if test_cfg is None else test_cfg + + # get predictions + cls_scores, bbox_preds, _, kpt_vis, pose_vecs = self.head_module(x)[:5] + assert len(cls_scores) == len(bbox_preds) + num_imgs = cls_scores[0].shape[0] + + # flatten cls_scores, bbox_preds and objectness + scores = self._flatten_predictions(cls_scores).sigmoid() + flatten_bbox_preds = self._flatten_predictions(bbox_preds) + flatten_pose_vecs = self._flatten_predictions(pose_vecs) + flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid() + bboxes = self.decode_bbox(flatten_bbox_preds, self.flatten_priors, + self.flatten_stride) + + if backend == Backend.TENSORRT: + # pad for batched_nms because its output index is filled with -1 + bboxes = torch.cat( + [bboxes, + bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))], + dim=1) + + scores = torch.cat( + [scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1) + + # nms parameters + post_params = get_post_processing_params(deploy_cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = cfg.get('nms_thr', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.get('pre_top_k', -1) + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + # do nms + _, _, nms_indices = multiclass_nms( + bboxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + output_index=True) + + batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1) + + # filter bounding boxes + dets = torch.cat([bboxes, scores], dim=2) + dets = dets[batch_inds, nms_indices, ...] + bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], 1.25), dim=-1) + + # filter and decode keypoints + pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...] + kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...] + grids = self.flatten_priors[nms_indices, ...] + keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids) + pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1) + + return dets, pred_kpts diff --git a/setup.py b/setup.py index bd6cf1dc42..f981415112 100644 --- a/setup.py +++ b/setup.py @@ -142,7 +142,7 @@ def get_extensions(): # argument if platform.system() != 'Windows': if parse_version(torch.__version__) <= parse_version('1.12.1'): - extra_compile_args['cxx'] = ['-std=c++14'] + extra_compile_args['cxx'] = ['-std=c++11'] else: extra_compile_args['cxx'] = ['-std=c++17'] @@ -165,7 +165,7 @@ def get_extensions(): # argument if 'nvcc' in extra_compile_args and platform.system() != 'Windows': if parse_version(torch.__version__) <= parse_version('1.12.1'): - extra_compile_args['nvcc'] += ['-std=c++14'] + extra_compile_args['nvcc'] += ['-std=c++11'] else: extra_compile_args['nvcc'] += ['-std=c++17'] From 6ac11026185cc7c50191b7ac0e37158a59e53434 Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 11 Dec 2023 12:55:29 +0800 Subject: [PATCH 02/10] add configs for rtmo --- ...pose-detection_rtmo_onnxruntime_dynamic.py | 25 +++++++++++++++++++ ...ion_rtmo_tensorrt-fp16_dynamic-640x640.py} | 5 ++-- ...detection_yolox-pose_onnxruntime_static.py | 13 ---------- .../mmpose/deploy/pose_detection_model.py | 6 ----- setup.py | 4 +-- 5 files changed, 30 insertions(+), 23 deletions(-) create mode 100644 configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py rename configs/mmpose/{pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py => pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py} (94%) delete mode 100644 configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py diff --git a/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py b/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py new file mode 100644 index 0000000000..c1fbdaaeb0 --- /dev/null +++ b/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py @@ -0,0 +1,25 @@ +_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py'] + +onnx_config = dict( + output_names=['dets', 'keypoints'], + dynamic_axes={ + 'input': { + 0: 'batch', + }, + 'dets': { + 0: 'batch', + }, + 'keypoints': { + 0: 'batch' + } + }) + +codebase_config = dict( + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=2000, + keep_top_k=50, + background_label_id=-1, + )) diff --git a/configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py b/configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py similarity index 94% rename from configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py rename to configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py index c46ab9cbad..cedc8f7097 100644 --- a/configs/mmpose/pose-detection_yolox-pose_tensorrt-fp16_dynamic-640x640.py +++ b/configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py @@ -13,6 +13,7 @@ 0: 'batch' } }) + backend_config = dict( common_config=dict(max_workspace_size=1 << 30), model_inputs=[ @@ -29,7 +30,7 @@ score_threshold=0.05, iou_threshold=0.5, max_output_boxes_per_class=200, - pre_top_k=5000, - keep_top_k=100, + pre_top_k=2000, + keep_top_k=50, background_label_id=-1, )) diff --git a/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py b/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py deleted file mode 100644 index f95c0ef9b9..0000000000 --- a/configs/mmpose/pose-detection_yolox-pose_onnxruntime_static.py +++ /dev/null @@ -1,13 +0,0 @@ -_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py'] - -onnx_config = dict(output_names=['dets', 'keypoints']) - -codebase_config = dict( - post_processing=dict( - score_threshold=0.05, - iou_threshold=0.5, - max_output_boxes_per_class=200, - pre_top_k=5000, - keep_top_k=100, - background_label_id=-1, - )) diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index fc460b8a42..a2ec9a21ad 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -101,10 +101,6 @@ def forward(self, codebase_cfg = get_codebase_config(self.deploy_cfg) codec = self.model_cfg.codec - - if codec.type == 'YOLOXPoseAnnotationProcessor': - return self.pack_yolox_pose_result(batch_outputs, data_samples) - if isinstance(codec, (list, tuple)): codec = codec[-1] @@ -188,8 +184,6 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor], """ assert preds[0].shape[0] == len(data_samples) batched_dets, batched_kpts = preds - # print(f'batched_dets:\n {batched_dets.flatten()[:10]}\n\n') - # print(f'batched_kpts:\n {batched_kpts.flatten()[:10]}\n\n') for data_sample_idx, data_sample in enumerate(data_samples): bboxes = batched_dets[data_sample_idx, :, :4] bbox_scores = batched_dets[data_sample_idx, :, 4] diff --git a/setup.py b/setup.py index f981415112..bd6cf1dc42 100644 --- a/setup.py +++ b/setup.py @@ -142,7 +142,7 @@ def get_extensions(): # argument if platform.system() != 'Windows': if parse_version(torch.__version__) <= parse_version('1.12.1'): - extra_compile_args['cxx'] = ['-std=c++11'] + extra_compile_args['cxx'] = ['-std=c++14'] else: extra_compile_args['cxx'] = ['-std=c++17'] @@ -165,7 +165,7 @@ def get_extensions(): # argument if 'nvcc' in extra_compile_args and platform.system() != 'Windows': if parse_version(torch.__version__) <= parse_version('1.12.1'): - extra_compile_args['nvcc'] += ['-std=c++11'] + extra_compile_args['nvcc'] += ['-std=c++14'] else: extra_compile_args['nvcc'] += ['-std=c++17'] From 21dc0ba5805c6596bede3d183ae40dd207eb99d4 Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 11 Dec 2023 12:59:35 +0800 Subject: [PATCH 03/10] replace bbox expansion factor with parameter bbox_padding --- mmdeploy/codebase/mmpose/models/heads/rtmo_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py index 0481c57927..4fe073751a 100644 --- a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py @@ -88,7 +88,7 @@ def predict(self, # filter bounding boxes dets = torch.cat([bboxes, scores], dim=2) dets = dets[batch_inds, nms_indices, ...] - bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], 1.25), dim=-1) + bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1) # filter and decode keypoints pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...] From e669d1615bb913f2ea81601dd9a2c8530f49048b Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 11 Dec 2023 19:00:41 +0800 Subject: [PATCH 04/10] refine code --- mmdeploy/codebase/mmpose/models/heads/__init__.py | 3 +-- mmdeploy/codebase/mmpose/models/heads/rtmo_head.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py index d276162aff..45ece714ad 100644 --- a/mmdeploy/codebase/mmpose/models/heads/__init__.py +++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import simcc_head # noqa: F401,F403 -from . import mspn_head, rtmo_head, yolox_pose_head +from . import mspn_head, rtmo_head, simcc_head, yolox_pose_head __all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head'] diff --git a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py index 4fe073751a..9c0926d24e 100644 --- a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py @@ -46,7 +46,7 @@ def predict(self, assert len(cls_scores) == len(bbox_preds) num_imgs = cls_scores[0].shape[0] - # flatten cls_scores, bbox_preds and objectness + # flatten and concat predictions scores = self._flatten_predictions(cls_scores).sigmoid() flatten_bbox_preds = self._flatten_predictions(bbox_preds) flatten_pose_vecs = self._flatten_predictions(pose_vecs) @@ -85,15 +85,15 @@ def predict(self, batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1) - # filter bounding boxes + # filter predictions dets = torch.cat([bboxes, scores], dim=2) dets = dets[batch_inds, nms_indices, ...] - bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1) - - # filter and decode keypoints pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...] kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...] grids = self.flatten_priors[nms_indices, ...] + + # filter and decode keypoints + bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1) keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids) pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1) From c3ec9813bc5fc3a64d4bbdab0cbf06228954f40f Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 11 Dec 2023 19:07:31 +0800 Subject: [PATCH 05/10] refine comment --- mmdeploy/codebase/mmpose/models/heads/rtmo_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py index 9c0926d24e..20bc748ac2 100644 --- a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py @@ -92,7 +92,7 @@ def predict(self, kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...] grids = self.flatten_priors[nms_indices, ...] - # filter and decode keypoints + # decode keypoints bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1) keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids) pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1) From 22078610e35a90c2cef39e6403cd1c394973fb18 Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 13 Dec 2023 14:03:35 +0800 Subject: [PATCH 06/10] apply model.switch_to_deploy in BaseTask.build_pytorch_model --- mmdeploy/codebase/base/task.py | 4 +++ .../codebase/mmpose/deploy/pose_detection.py | 36 ------------------- 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 048433e070..c52565dbbb 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -126,6 +126,10 @@ def build_pytorch_model(self, if hasattr(model, 'backbone') and hasattr(model.backbone, 'switch_to_deploy'): model.backbone.switch_to_deploy() + + if hasattr(model, 'switch_to_deploy') and callable(model.switch_to_deploy): + model.switch_to_deploy() + model = model.to(self.device) model.eval() return model diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 9621b92f38..5e6b0c5c6f 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -import inspect import os from collections import defaultdict from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -180,41 +179,6 @@ def build_backend_model( **kwargs) return model.eval().to(self.device) - def build_pytorch_model(self, - model_checkpoint: Optional[str] = None, - cfg_options: Optional[Dict] = None, - **kwargs) -> torch.nn.Module: - """Initialize torch model and switch to deploy mode. - - Args: - model_checkpoint (str): The checkpoint file of torch model, - defaults to `None`. - cfg_options (dict): Optional config key-pair parameters. - - Returns: - nn.Module: An initialized torch model generated by other OpenMMLab - codebases. - """ - # Initialize the PyTorch model using parent class method - torch_model = super().build_pytorch_model(model_checkpoint, - cfg_options, **kwargs) - - # Check if called from torch2onnx within 'apis/pytorch2onnx.py' - callers = inspect.stack() - is_torch2onnx_call = ( - len(callers) > 1 and callers[1].function == 'torch2onnx' - and callers[1].filename.endswith( - os.path.join('apis', 'pytorch2onnx.py'))) - - # If model has a 'switch_to_deploy' method and is called from - # torch2onnx, activate this method - if is_torch2onnx_call and hasattr(torch_model, - 'switch_to_deploy') and callable( - torch_model.switch_to_deploy): - torch_model.switch_to_deploy() - - return torch_model - def create_input(self, imgs: Union[str, np.ndarray, Sequence], input_shape: Sequence[int] = None, From 97231f87e1ec1b9721ac109ba9a2b05a9d454519 Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 13 Dec 2023 14:05:53 +0800 Subject: [PATCH 07/10] fix lint --- mmdeploy/codebase/base/task.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index c52565dbbb..71efbf545e 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -127,9 +127,10 @@ def build_pytorch_model(self, 'switch_to_deploy'): model.backbone.switch_to_deploy() - if hasattr(model, 'switch_to_deploy') and callable(model.switch_to_deploy): - model.switch_to_deploy() - + if hasattr(model, 'switch_to_deploy') and callable( + model.switch_to_deploy): + model.switch_to_deploy() + model = model.to(self.device) model.eval() return model From e0d716649c903353cefa6ceed663585b11da1a04 Mon Sep 17 00:00:00 2001 From: lupeng Date: Thu, 14 Dec 2023 11:11:59 +0800 Subject: [PATCH 08/10] add rtmo into regression test --- tests/regression/mmpose.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/regression/mmpose.yml b/tests/regression/mmpose.yml index ad3b7b8744..7e1b461936 100644 --- a/tests/regression/mmpose.yml +++ b/tests/regression/mmpose.yml @@ -150,3 +150,13 @@ models: input_img: *img_human_pose test_img: *img_human_pose deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py + + - name: RTMO + metafile: configs/body_2d_keypoint/rtmo/body7/rtmo_body7.yml + model_configs: + - configs/body_2d_keypoint/rtmo/body7/rtmo-s_8xb32-600e_body7-640x640.py + pipelines: + - convert_image: + input_img: *img_human_pose + test_img: *img_human_pose + deploy_config: configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py From 8080b4d0d2c2f1044c07c8bb3729bb5123543c8d Mon Sep 17 00:00:00 2001 From: lupeng Date: Thu, 14 Dec 2023 11:33:42 +0800 Subject: [PATCH 09/10] add rtmo with trt backend into regression test --- tests/regression/mmpose.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/regression/mmpose.yml b/tests/regression/mmpose.yml index 7e1b461936..41554a6622 100644 --- a/tests/regression/mmpose.yml +++ b/tests/regression/mmpose.yml @@ -160,3 +160,7 @@ models: input_img: *img_human_pose test_img: *img_human_pose deploy_config: configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py + - convert_image: + input_img: *img_human_pose + test_img: *img_human_pose + deploy_config: configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py From ce2658eedeca169c340d0617df474644ca884cdf Mon Sep 17 00:00:00 2001 From: lupeng Date: Thu, 14 Dec 2023 14:48:52 +0800 Subject: [PATCH 10/10] add rtmo into supported model list --- docs/en/04-supported-codebases/mmpose.md | 1 + docs/zh_cn/04-supported-codebases/mmpose.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/04-supported-codebases/mmpose.md b/docs/en/04-supported-codebases/mmpose.md index 6f6ee4ab50..8c822cebc9 100644 --- a/docs/en/04-supported-codebases/mmpose.md +++ b/docs/en/04-supported-codebases/mmpose.md @@ -161,3 +161,4 @@ TODO | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y | | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y | | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y | +| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N | diff --git a/docs/zh_cn/04-supported-codebases/mmpose.md b/docs/zh_cn/04-supported-codebases/mmpose.md index 961ba31f22..617dbd670c 100644 --- a/docs/zh_cn/04-supported-codebases/mmpose.md +++ b/docs/zh_cn/04-supported-codebases/mmpose.md @@ -165,3 +165,4 @@ task_processor.visualize( | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y | | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y | | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y | +| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |