-
Notifications
You must be signed in to change notification settings - Fork 675
[Feature] Support ONNX and TensorRT exportation of RTMO models #2597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
RunningLeon
merged 10 commits into
open-mmlab:main
from
Ben-Louis:lupeng/support-onestage-rtmpose
Dec 14, 2023
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
fe7aef1
support ONNX&TensorRT exportation of RTMO
Ben-Louis 6ac1102
add configs for rtmo
Ben-Louis 21dc0ba
replace bbox expansion factor with parameter bbox_padding
Ben-Louis e669d16
refine code
Ben-Louis c3ec981
refine comment
Ben-Louis 2207861
apply model.switch_to_deploy in BaseTask.build_pytorch_model
Ben-Louis 97231f8
fix lint
Ben-Louis e0d7166
add rtmo into regression test
Ben-Louis 8080b4d
add rtmo with trt backend into regression test
Ben-Louis ce2658e
add rtmo into supported model list
Ben-Louis File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py'] | ||
RunningLeon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
onnx_config = dict( | ||
output_names=['dets', 'keypoints'], | ||
dynamic_axes={ | ||
'input': { | ||
0: 'batch', | ||
}, | ||
'dets': { | ||
0: 'batch', | ||
}, | ||
'keypoints': { | ||
0: 'batch' | ||
} | ||
}) | ||
|
||
codebase_config = dict( | ||
post_processing=dict( | ||
score_threshold=0.05, | ||
iou_threshold=0.5, | ||
max_output_boxes_per_class=200, | ||
pre_top_k=2000, | ||
keep_top_k=50, | ||
background_label_id=-1, | ||
)) |
36 changes: 36 additions & 0 deletions
36
configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt-fp16.py'] | ||
|
||
onnx_config = dict( | ||
output_names=['dets', 'keypoints'], | ||
dynamic_axes={ | ||
'input': { | ||
0: 'batch', | ||
}, | ||
'dets': { | ||
0: 'batch', | ||
}, | ||
'keypoints': { | ||
0: 'batch' | ||
} | ||
}) | ||
|
||
backend_config = dict( | ||
common_config=dict(max_workspace_size=1 << 30), | ||
model_inputs=[ | ||
dict( | ||
input_shapes=dict( | ||
input=dict( | ||
min_shape=[1, 3, 640, 640], | ||
opt_shape=[1, 3, 640, 640], | ||
max_shape=[1, 3, 640, 640]))) | ||
]) | ||
|
||
codebase_config = dict( | ||
post_processing=dict( | ||
score_threshold=0.05, | ||
iou_threshold=0.5, | ||
max_output_boxes_per_class=200, | ||
pre_top_k=2000, | ||
keep_top_k=50, | ||
background_label_id=-1, | ||
)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403 | ||
from . import mspn_head, rtmo_head, simcc_head, yolox_pose_head | ||
|
||
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head'] | ||
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
RunningLeon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from mmpose.structures.bbox import bbox_xyxy2cs | ||
from torch import Tensor | ||
|
||
from mmdeploy.codebase.mmdet import get_post_processing_params | ||
from mmdeploy.core import FUNCTION_REWRITER | ||
from mmdeploy.mmcv.ops.nms import multiclass_nms | ||
from mmdeploy.utils import Backend, get_backend | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmpose.models.heads.hybrid_heads.' | ||
'rtmo_head.RTMOHead.forward') | ||
def predict(self, | ||
x: Tuple[Tensor], | ||
batch_data_samples: List = [], | ||
test_cfg: Optional[dict] = None): | ||
"""Get predictions and transform to bbox and keypoints results. | ||
Args: | ||
x (Tuple[Tensor]): The input tensor from upstream network. | ||
batch_data_samples: Batch image meta info. Defaults to None. | ||
test_cfg: The runtime config for testing process. | ||
|
||
Returns: | ||
Tuple[Tensor]: Predict bbox and keypoint results. | ||
- dets (Tensor): Predict bboxes and scores, which is a 3D Tensor, | ||
has shape (batch_size, num_instances, 5), the last dimension 5 | ||
arrange as (x1, y1, x2, y2, score). | ||
- pred_kpts (Tensor): Predict keypoints and scores, which is a 4D | ||
Tensor, has shape (batch_size, num_instances, num_keypoints, 5), | ||
the last dimension 3 arrange as (x, y, score). | ||
""" | ||
|
||
# deploy context | ||
ctx = FUNCTION_REWRITER.get_context() | ||
backend = get_backend(ctx.cfg) | ||
deploy_cfg = ctx.cfg | ||
|
||
cfg = self.test_cfg if test_cfg is None else test_cfg | ||
|
||
# get predictions | ||
cls_scores, bbox_preds, _, kpt_vis, pose_vecs = self.head_module(x)[:5] | ||
assert len(cls_scores) == len(bbox_preds) | ||
num_imgs = cls_scores[0].shape[0] | ||
|
||
# flatten and concat predictions | ||
scores = self._flatten_predictions(cls_scores).sigmoid() | ||
flatten_bbox_preds = self._flatten_predictions(bbox_preds) | ||
flatten_pose_vecs = self._flatten_predictions(pose_vecs) | ||
flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid() | ||
bboxes = self.decode_bbox(flatten_bbox_preds, self.flatten_priors, | ||
self.flatten_stride) | ||
|
||
if backend == Backend.TENSORRT: | ||
# pad for batched_nms because its output index is filled with -1 | ||
bboxes = torch.cat( | ||
[bboxes, | ||
bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))], | ||
dim=1) | ||
|
||
scores = torch.cat( | ||
[scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1) | ||
|
||
# nms parameters | ||
post_params = get_post_processing_params(deploy_cfg) | ||
max_output_boxes_per_class = post_params.max_output_boxes_per_class | ||
iou_threshold = cfg.get('nms_thr', post_params.iou_threshold) | ||
score_threshold = cfg.get('score_thr', post_params.score_threshold) | ||
pre_top_k = post_params.get('pre_top_k', -1) | ||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) | ||
|
||
# do nms | ||
_, _, nms_indices = multiclass_nms( | ||
bboxes, | ||
scores, | ||
max_output_boxes_per_class, | ||
iou_threshold, | ||
score_threshold, | ||
pre_top_k=pre_top_k, | ||
keep_top_k=keep_top_k, | ||
output_index=True) | ||
|
||
batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1) | ||
|
||
# filter predictions | ||
dets = torch.cat([bboxes, scores], dim=2) | ||
dets = dets[batch_inds, nms_indices, ...] | ||
pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...] | ||
kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...] | ||
grids = self.flatten_priors[nms_indices, ...] | ||
|
||
# decode keypoints | ||
bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1) | ||
keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids) | ||
pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1) | ||
|
||
return dets, pred_kpts |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.