Skip to content

Commit b291a1c

Browse files
committed
fix mmdet3d
1 parent 985a4f3 commit b291a1c

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

mmdeploy/codebase/mmdet3d/deploy/mono_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def create_input(
124124

125125
if data_preprocessor is not None:
126126
collate_data = data_preprocessor(collate_data, False)
127-
inputs = collate_data['inputs']
127+
inputs = collate_data['inputs']['imgs']
128128
else:
129129
inputs = collate_data['inputs']
130130
return collate_data, inputs

mmdeploy/codebase/mmdet3d/models/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@
1010
'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501
1111
)
1212
def basedetector__forward(self,
13-
inputs: list,
13+
voxels: torch.Tensor,
14+
num_points: torch.Tensor,
15+
coors: torch.Tensor,
1416
data_samples=None,
1517
**kwargs) -> Tuple[List[torch.Tensor]]:
1618
"""Extract features of images."""
1719

1820
batch_inputs_dict = {
1921
'voxels': {
20-
'voxels': inputs[0],
21-
'num_points': inputs[1],
22-
'coors': inputs[2]
22+
'voxels': voxels,
23+
'num_points': num_points,
24+
'coors': coors
2325
}
2426
}
2527
return self._forward(batch_inputs_dict, data_samples, **kwargs)
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from torch import Tensor
3+
24
from mmdeploy.core import FUNCTION_REWRITER
35

46

57
@FUNCTION_REWRITER.register_rewriter(
68
'mmdet3d.models.detectors.single_stage_mono3d.'
79
'SingleStageMono3DDetector.forward')
8-
def singlestagemono3ddetector__forward(self, inputs: list, **kwargs):
9-
"""Rewrite this func to r.
10+
def singlestagemono3ddetector__forward(self, inputs: Tensor, **kwargs):
11+
"""Rewrite to support feed inputs of Tensor type.
1012
1113
Args:
12-
inputs (dict): Input dict comprises `imgs`
14+
inputs (Tensor): Input image
1315
1416
Returns:
1517
list: two torch.Tensor
1618
"""
17-
x = self.extract_feat(inputs)
19+
20+
x = self.extract_feat({'imgs': inputs})
1821
results = self.bbox_head.forward(x)
1922
return results[0], results[1]

tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_pointpillars(backend_type: Backend):
157157
cfg=deploy_cfg,
158158
backend=deploy_cfg.backend_config.type,
159159
opset=deploy_cfg.onnx_config.opset_version):
160-
outputs = model.forward(data)
160+
outputs = model.forward(*data)
161161
assert len(outputs) == 3
162162

163163

0 commit comments

Comments
 (0)