Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ def create_input(self,
from mmcv.transforms import Compose
from mmpose.registry import TRANSFORMS
cfg = self.model_cfg
if isinstance(imgs, str):
imgs = [mmcv.imread(imgs)]
elif isinstance(imgs, (list, tuple)):
img_data = []
for img in imgs:
if isinstance(img, str):
img_data.append(mmcv.imread(img))
else:
img_data.append(img)
if isinstance(imgs, (list, tuple)):
if not isinstance(imgs[0], (np.ndarray, str)):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
img_data = [mmcv.imread(img) for img in imgs]
imgs = img_data
person_results = []
bboxes = []
Expand Down
23 changes: 10 additions & 13 deletions tools/onnx2ncnn_quant_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy

from mmengine import Config
from torch.utils.data import DataLoader

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_root_logger, load_config
Expand Down Expand Up @@ -31,9 +32,11 @@ def get_table(onnx_path: str,
from quant_image_dataset import QuantizationImageDataset
dataset = QuantizationImageDataset(
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
calib_dataloader['dataset'] = dataset
dataloader = task_processor.build_dataloader(calib_dataloader)
# dataloader = DataLoader(dataset, batch_size=1)

def collate(data_batch):
return data_batch[0]

dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
else:
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
calib_dataloader['dataset'] = dataset
Expand All @@ -44,16 +47,10 @@ def get_table(onnx_path: str,
# get an available input shape randomly
for _, input_data in enumerate(dataloader):
input_data = data_preprocessor(input_data)
input_tensor = input_data[0]
if isinstance(input_tensor, list):
input_shape = input_tensor[0].shape
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
device)
else:
input_shape = input_tensor.shape
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
device)
break
input_tensor = input_data['inputs']
input_shape = input_tensor.shape
collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731
device)

from ppq import QuantizationSettingFactory, TargetPlatform
from ppq.api import export_ppq_graph, quantize_onnx_model
Expand Down