Skip to content

Commit 49103cb

Browse files
authored
[Fix] Fix dataloader for ncnn quantization (#2018)
* fix mmpose create_input when feed ndarray * fix dataloader for ncnn-ppq
1 parent 2d85be9 commit 49103cb

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

mmdeploy/codebase/mmpose/deploy/pose_detection.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,15 @@ def create_input(self,
197197
from mmcv.transforms import Compose
198198
from mmpose.registry import TRANSFORMS
199199
cfg = self.model_cfg
200-
if isinstance(imgs, str):
201-
imgs = [mmcv.imread(imgs)]
202-
elif isinstance(imgs, (list, tuple)):
203-
img_data = []
204-
for img in imgs:
205-
if isinstance(img, str):
206-
img_data.append(mmcv.imread(img))
207-
else:
208-
img_data.append(img)
200+
if isinstance(imgs, (list, tuple)):
201+
if not isinstance(imgs[0], (np.ndarray, str)):
202+
raise AssertionError('imgs must be strings or numpy arrays')
203+
elif isinstance(imgs, (np.ndarray, str)):
204+
imgs = [imgs]
205+
else:
206+
raise AssertionError('imgs must be strings or numpy arrays')
207+
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
208+
img_data = [mmcv.imread(img) for img in imgs]
209209
imgs = img_data
210210
person_results = []
211211
bboxes = []

tools/onnx2ncnn_quant_table.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from copy import deepcopy
55

66
from mmengine import Config
7+
from torch.utils.data import DataLoader
78

89
from mmdeploy.apis.utils import build_task_processor
910
from mmdeploy.utils import get_root_logger, load_config
@@ -31,9 +32,11 @@ def get_table(onnx_path: str,
3132
from quant_image_dataset import QuantizationImageDataset
3233
dataset = QuantizationImageDataset(
3334
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
34-
calib_dataloader['dataset'] = dataset
35-
dataloader = task_processor.build_dataloader(calib_dataloader)
36-
# dataloader = DataLoader(dataset, batch_size=1)
35+
36+
def collate(data_batch):
37+
return data_batch[0]
38+
39+
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
3740
else:
3841
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
3942
calib_dataloader['dataset'] = dataset
@@ -44,16 +47,10 @@ def get_table(onnx_path: str,
4447
# get an available input shape randomly
4548
for _, input_data in enumerate(dataloader):
4649
input_data = data_preprocessor(input_data)
47-
input_tensor = input_data[0]
48-
if isinstance(input_tensor, list):
49-
input_shape = input_tensor[0].shape
50-
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
51-
device)
52-
else:
53-
input_shape = input_tensor.shape
54-
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
55-
device)
56-
break
50+
input_tensor = input_data['inputs']
51+
input_shape = input_tensor.shape
52+
collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731
53+
device)
5754

5855
from ppq import QuantizationSettingFactory, TargetPlatform
5956
from ppq.api import export_ppq_graph, quantize_onnx_model

0 commit comments

Comments
 (0)