File tree Expand file tree Collapse file tree 4 files changed +15
-10
lines changed
mmdeploy/codebase/mmdet3d
tests/test_codebase/test_mmdet3d Expand file tree Collapse file tree 4 files changed +15
-10
lines changed Original file line number Diff line number Diff line change @@ -124,7 +124,7 @@ def create_input(
124
124
125
125
if data_preprocessor is not None :
126
126
collate_data = data_preprocessor (collate_data , False )
127
- inputs = collate_data ['inputs' ]
127
+ inputs = collate_data ['inputs' ][ 'imgs' ]
128
128
else :
129
129
inputs = collate_data ['inputs' ]
130
130
return collate_data , inputs
Original file line number Diff line number Diff line change 10
10
'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501
11
11
)
12
12
def basedetector__forward (self ,
13
- inputs : list ,
13
+ voxels : torch .Tensor ,
14
+ num_points : torch .Tensor ,
15
+ coors : torch .Tensor ,
14
16
data_samples = None ,
15
17
** kwargs ) -> Tuple [List [torch .Tensor ]]:
16
18
"""Extract features of images."""
17
19
18
20
batch_inputs_dict = {
19
21
'voxels' : {
20
- 'voxels' : inputs [ 0 ] ,
21
- 'num_points' : inputs [ 1 ] ,
22
- 'coors' : inputs [ 2 ]
22
+ 'voxels' : voxels ,
23
+ 'num_points' : num_points ,
24
+ 'coors' : coors
23
25
}
24
26
}
25
27
return self ._forward (batch_inputs_dict , data_samples , ** kwargs )
Original file line number Diff line number Diff line change 1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
+ from torch import Tensor
3
+
2
4
from mmdeploy .core import FUNCTION_REWRITER
3
5
4
6
5
7
@FUNCTION_REWRITER .register_rewriter (
6
8
'mmdet3d.models.detectors.single_stage_mono3d.'
7
9
'SingleStageMono3DDetector.forward' )
8
- def singlestagemono3ddetector__forward (self , inputs : list , ** kwargs ):
9
- """Rewrite this func to r .
10
+ def singlestagemono3ddetector__forward (self , inputs : Tensor , ** kwargs ):
11
+ """Rewrite to support feed inputs of Tensor type .
10
12
11
13
Args:
12
- inputs (dict ): Input dict comprises `imgs`
14
+ inputs (Tensor ): Input image
13
15
14
16
Returns:
15
17
list: two torch.Tensor
16
18
"""
17
- x = self .extract_feat (inputs )
19
+
20
+ x = self .extract_feat ({'imgs' : inputs })
18
21
results = self .bbox_head .forward (x )
19
22
return results [0 ], results [1 ]
Original file line number Diff line number Diff line change @@ -157,7 +157,7 @@ def test_pointpillars(backend_type: Backend):
157
157
cfg = deploy_cfg ,
158
158
backend = deploy_cfg .backend_config .type ,
159
159
opset = deploy_cfg .onnx_config .opset_version ):
160
- outputs = model .forward (data )
160
+ outputs = model .forward (* data )
161
161
assert len (outputs ) == 3
162
162
163
163
You can’t perform that action at this time.
0 commit comments