|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import copy |
| 3 | + |
| 4 | +import torch |
| 5 | +from mmdet.models.detectors.base import ForwardResults |
| 6 | +from mmdet.structures import DetDataSample |
| 7 | +from mmdet.structures.det_data_sample import OptSampleList |
| 8 | + |
| 9 | +from mmdeploy.core import FUNCTION_REWRITER, mark |
| 10 | +from mmdeploy.utils import is_dynamic_shape |
| 11 | + |
| 12 | + |
| 13 | +@mark('detr_predict', inputs=['input'], outputs=['dets', 'labels', 'masks']) |
| 14 | +def __predict_impl(self, batch_inputs, data_samples, rescale): |
| 15 | + """Rewrite and adding mark for `predict`. |
| 16 | +
|
| 17 | + Encapsulate this function for rewriting `predict` of DetectionTransformer. |
| 18 | + 1. Add mark for DetectionTransformer. |
| 19 | + 2. Support both dynamic and static export to onnx. |
| 20 | + """ |
| 21 | + img_feats = self.extract_feat(batch_inputs) |
| 22 | + head_inputs_dict = self.forward_transformer(img_feats, data_samples) |
| 23 | + results_list = self.bbox_head.predict( |
| 24 | + **head_inputs_dict, rescale=rescale, batch_data_samples=data_samples) |
| 25 | + return results_list |
| 26 | + |
| 27 | + |
| 28 | +@torch.fx.wrap |
| 29 | +def _set_metainfo(data_samples, img_shape): |
| 30 | + """Set the metainfo. |
| 31 | +
|
| 32 | + Code in this function cannot be traced by fx. |
| 33 | + """ |
| 34 | + |
| 35 | + # fx can not trace deepcopy correctly |
| 36 | + data_samples = copy.deepcopy(data_samples) |
| 37 | + if data_samples is None: |
| 38 | + data_samples = [DetDataSample()] |
| 39 | + |
| 40 | + # note that we can not use `set_metainfo`, deepcopy would crash the |
| 41 | + # onnx trace. |
| 42 | + for data_sample in data_samples: |
| 43 | + data_sample.set_field( |
| 44 | + name='img_shape', value=img_shape, field_type='metainfo') |
| 45 | + |
| 46 | + return data_samples |
| 47 | + |
| 48 | + |
| 49 | +@FUNCTION_REWRITER.register_rewriter( |
| 50 | + 'mmdet.models.detectors.base_detr.DetectionTransformer.predict') |
| 51 | +def detection_transformer__predict(self, |
| 52 | + batch_inputs: torch.Tensor, |
| 53 | + data_samples: OptSampleList = None, |
| 54 | + rescale: bool = True, |
| 55 | + **kwargs) -> ForwardResults: |
| 56 | + """Rewrite `predict` for default backend. |
| 57 | +
|
| 58 | + Support configured dynamic/static shape for model input and return |
| 59 | + detection result as Tensor instead of numpy array. |
| 60 | +
|
| 61 | + Args: |
| 62 | + batch_inputs (Tensor): Inputs with shape (N, C, H, W). |
| 63 | + data_samples (List[:obj:`DetDataSample`]): The Data |
| 64 | + Samples. It usually includes information such as |
| 65 | + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. |
| 66 | + rescale (Boolean): rescale result or not. |
| 67 | +
|
| 68 | + Returns: |
| 69 | + tuple[Tensor]: Detection results of the |
| 70 | + input images. |
| 71 | + - dets (Tensor): Classification bboxes and scores. |
| 72 | + Has a shape (num_instances, 5) |
| 73 | + - labels (Tensor): Labels of bboxes, has a shape |
| 74 | + (num_instances, ). |
| 75 | + """ |
| 76 | + ctx = FUNCTION_REWRITER.get_context() |
| 77 | + |
| 78 | + deploy_cfg = ctx.cfg |
| 79 | + |
| 80 | + # get origin input shape as tensor to support onnx dynamic shape |
| 81 | + is_dynamic_flag = is_dynamic_shape(deploy_cfg) |
| 82 | + img_shape = torch._shape_as_tensor(batch_inputs)[2:] |
| 83 | + if not is_dynamic_flag: |
| 84 | + img_shape = [int(val) for val in img_shape] |
| 85 | + |
| 86 | + # set the metainfo |
| 87 | + data_samples = _set_metainfo(data_samples, img_shape) |
| 88 | + |
| 89 | + return __predict_impl(self, batch_inputs, data_samples, rescale) |
0 commit comments