4
4
from copy import deepcopy
5
5
6
6
from mmengine import Config
7
+ from torch .utils .data import DataLoader
7
8
8
9
from mmdeploy .apis .utils import build_task_processor
9
10
from mmdeploy .utils import get_root_logger , load_config
@@ -31,9 +32,11 @@ def get_table(onnx_path: str,
31
32
from quant_image_dataset import QuantizationImageDataset
32
33
dataset = QuantizationImageDataset (
33
34
path = image_dir , deploy_cfg = deploy_cfg , model_cfg = model_cfg )
34
- calib_dataloader ['dataset' ] = dataset
35
- dataloader = task_processor .build_dataloader (calib_dataloader )
36
- # dataloader = DataLoader(dataset, batch_size=1)
35
+
36
+ def collate (data_batch ):
37
+ return data_batch [0 ]
38
+
39
+ dataloader = DataLoader (dataset , batch_size = 1 , collate_fn = collate )
37
40
else :
38
41
dataset = task_processor .build_dataset (calib_dataloader ['dataset' ])
39
42
calib_dataloader ['dataset' ] = dataset
@@ -44,16 +47,10 @@ def get_table(onnx_path: str,
44
47
# get an available input shape randomly
45
48
for _ , input_data in enumerate (dataloader ):
46
49
input_data = data_preprocessor (input_data )
47
- input_tensor = input_data [0 ]
48
- if isinstance (input_tensor , list ):
49
- input_shape = input_tensor [0 ].shape
50
- collate_fn = lambda x : data_preprocessor (x [0 ])[0 ].to ( # noqa: E731
51
- device )
52
- else :
53
- input_shape = input_tensor .shape
54
- collate_fn = lambda x : data_preprocessor (x )[0 ].to ( # noqa: E731
55
- device )
56
- break
50
+ input_tensor = input_data ['inputs' ]
51
+ input_shape = input_tensor .shape
52
+ collate_fn = lambda x : data_preprocessor (x )['inputs' ].to ( # noqa: E731
53
+ device )
57
54
58
55
from ppq import QuantizationSettingFactory , TargetPlatform
59
56
from ppq .api import export_ppq_graph , quantize_onnx_model
0 commit comments