-
Notifications
You must be signed in to change notification settings - Fork 675
[Feature] Support rtmdet for dev-1.x #1104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
14 changes: 14 additions & 0 deletions
14
configs/mmdet/detection/detection_tensorrt_static-640x640.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
_base_ = ['../_base_/base_tensorrt_static-300x300.py'] | ||
|
||
onnx_config = dict(input_shape=(640, 640)) | ||
|
||
backend_config = dict( | ||
common_config=dict(max_workspace_size=1 << 30), | ||
model_inputs=[ | ||
dict( | ||
input_shapes=dict( | ||
input=dict( | ||
min_shape=[1, 3, 640, 640], | ||
opt_shape=[1, 3, 640, 640], | ||
max_shape=[1, 3, 640, 640]))) | ||
]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Optional | ||
|
||
import torch | ||
from mmengine.config import ConfigDict | ||
from mmengine.structures import InstanceData | ||
from torch import Tensor | ||
|
||
from mmdeploy.codebase.mmdet import get_post_processing_params | ||
from mmdeploy.codebase.mmdet.models.layers import multiclass_nms | ||
from mmdeploy.core import FUNCTION_REWRITER | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter( | ||
func_name='mmdet.models.dense_heads.rtmdet_head.' | ||
'RTMDetHead.predict_by_feat') | ||
def rtmdet_head__predict_by_feat(ctx, | ||
self, | ||
cls_scores: List[Tensor], | ||
bbox_preds: List[Tensor], | ||
batch_img_metas: Optional[List[dict]] = None, | ||
cfg: Optional[ConfigDict] = None, | ||
rescale: bool = False, | ||
with_nms: bool = True) -> List[InstanceData]: | ||
"""Rewrite `predict_by_feat` of `RTMDet` for default backend. | ||
|
||
Rewrite this function to deploy model, transform network output for a | ||
batch into bbox predictions. | ||
|
||
Args: | ||
ctx: Context that contains original meta information. | ||
cls_scores (list[Tensor]): Classification scores for all | ||
scale levels, each is a 4D-tensor, has shape | ||
(batch_size, num_priors * num_classes, H, W). | ||
bbox_preds (list[Tensor]): Box energies / deltas for all | ||
scale levels, each is a 4D-tensor, has shape | ||
(batch_size, num_priors * 4, H, W). | ||
batch_img_metas (list[dict], Optional): Batch image meta info. | ||
Defaults to None. | ||
cfg (ConfigDict, optional): Test / postprocessing | ||
configuration, if None, test_cfg would be used. | ||
Defaults to None. | ||
rescale (bool): If True, return boxes in original image space. | ||
Defaults to False. | ||
with_nms (bool): If True, do nms before return boxes. | ||
Defaults to True. | ||
|
||
Returns: | ||
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor, | ||
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch | ||
size and the score between 0 and 1. The shape of the second | ||
tensor in the tuple is (N, num_box), and each element | ||
represents the class label of the corresponding box. | ||
""" | ||
assert len(cls_scores) == len(bbox_preds) | ||
device = cls_scores[0].device | ||
cfg = self.test_cfg if cfg is None else cfg | ||
batch_size = bbox_preds[0].shape[0] | ||
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] | ||
mlvl_priors = self.prior_generator.grid_priors( | ||
featmap_sizes, device=device) | ||
|
||
flatten_cls_scores = [ | ||
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, | ||
self.cls_out_channels) | ||
for cls_score in cls_scores | ||
] | ||
flatten_bbox_preds = [ | ||
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) | ||
for bbox_pred in bbox_preds | ||
] | ||
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() | ||
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) | ||
priors = torch.cat(mlvl_priors) | ||
tl_x = (priors[..., 0] - flatten_bbox_preds[..., 0]) | ||
tl_y = (priors[..., 1] - flatten_bbox_preds[..., 1]) | ||
br_x = (priors[..., 0] + flatten_bbox_preds[..., 2]) | ||
br_y = (priors[..., 1] + flatten_bbox_preds[..., 3]) | ||
bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1) | ||
# directly multiply score factor and feed to nms | ||
max_scores, _ = torch.max(flatten_cls_scores, 1) | ||
mask = max_scores >= cfg.score_thr | ||
scores = flatten_cls_scores.where(mask, flatten_cls_scores.new_zeros(1)) | ||
if not with_nms: | ||
return bboxes, scores | ||
|
||
deploy_cfg = ctx.cfg | ||
post_params = get_post_processing_params(deploy_cfg) | ||
max_output_boxes_per_class = post_params.max_output_boxes_per_class | ||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) | ||
score_threshold = cfg.get('score_thr', post_params.score_threshold) | ||
pre_top_k = post_params.pre_top_k | ||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) | ||
|
||
return multiclass_nms(bboxes, scores, max_output_boxes_per_class, | ||
iou_threshold, score_threshold, pre_top_k, | ||
keep_top_k) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since RTMDetHead share the same
predict_by_feat
with base_dense_head, we can export the model without this rewriter. Is there any optimization in this rewriter?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use the same logic in YOLOX's predict_by_feat for a fair comparison with YOLO series.