40 | 40 | use_ipex = False |
41 | 41 | if os.environ.get('USE_IPEX') == "1": |
42 | 42 | import intel_extension_for_pytorch as ipex |
| 43 | + from intel_extension_for_pytorch.quantization import prepare, convert |
| 44 | + from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig |
43 | 45 | use_ipex = True |
44 | 46 |
45 | 47 | def get_bs_per_stream(batch_size, stream_number): |
@@ -238,38 +240,33 @@ def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args): |
238 | 240 | start = time.time() |
239 | 241 | if args.int8: |
240 | 242 | model = model.eval() |
241 | | - model_decode = SSD_R34_NMS(model, encoder) |
| 243 | + model_decode = SSD_R34_NMS(model, encoder).eval() |
242 | 244 | print('int8 conv_bn_fusion enabled') |
243 | 245 | with torch.no_grad(): |
244 | 246 | model_decode.model.model = optimization.fuse(model_decode.model.model, inplace=False) |
245 | | - |
| 247 | + qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), |
| 248 | + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)) |
| 249 | + |
| 250 | + example_inputs = torch.randn( ((args.batch_size // args.number_instance) if args.use_multi_stream_module else args.batch_size), 3, 1200, 1200).to(memory_format=torch.channels_last) |
| 251 | + model_decode.model = prepare(model_decode.model, qconfig, example_inputs=example_inputs, inplace=False) |
246 | 252 | if args.calibration: |
247 | 253 | print("running int8 LLGA calibration step is not supported in throughput benchmark")
248 | 254 | else: |
249 | 255 | print("INT8 LLGA start trace") |
250 | 256 | # insert quant/dequant based on configure.json |
251 | | - conf = ipex.quantization.QuantConf(configure_file=args.configure) |
252 | | - model_decode.eval() |
253 | | - if args.use_multi_stream_module: |
254 | | - batch_per_stream = args.batch_size // args.number_instance |
255 | | - print("batch_per_stream for multi_stream_module is:", batch_per_stream) |
256 | | - model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
257 | | - else: |
258 | | - model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
| 257 | + model_decode.model.load_qconf_summary(qconf_summary = args.configure) |
| 258 | + model_decode.model = convert(model_decode.model) |
| 259 | + with torch.no_grad(): |
| 260 | + model_decode = torch.jit.trace(model_decode, example_inputs, check_trace=False).eval() |
| 261 | + model_decode = torch.jit.freeze(model_decode) |
| 262 | + |
259 | 263 | print("done ipex default recipe.......................") |
260 | | - # freeze the module |
261 | | - # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True)) |
262 | | - # model_decode = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model_decode._c, preserveParameters=True)) |
263 | 264 |
264 | 265 | # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile |
265 | 266 | # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph |
266 | 267 | with torch.no_grad(): |
267 | 268 | for i in range(2): |
268 | | - # _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
269 | | - if args.use_multi_stream_module: |
270 | | - _ = model_decode(torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
271 | | - else: |
272 | | - _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
| 269 | + _ = model_decode(example_inputs) |
273 | 270 |
274 | 271 | if args.use_throughput_benchmark: |
275 | 272 |
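
For readers reviewing the change, here is a minimal, self-contained sketch of the static-quantization flow the patch switches to. The ToyNet module, the 2x3x224x224 input shape, the four calibration passes, and the calib.json file name are illustrative stand-ins (the real code quantizes the fused SSD_R34_NMS model, uses 3x1200x1200 inputs, and reads args.configure); the API calls themselves mirror the patch: prepare/convert from intel_extension_for_pytorch.quantization, a QConfig with per-tensor quint8 activations and per-channel symmetric qint8 weights, load_qconf_summary/save_qconf_summary, then torch.jit.trace and torch.jit.freeze.

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch as ipex  # noqa: F401, registers IPEX ops/passes
    from intel_extension_for_pytorch.quantization import prepare, convert
    from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

    class ToyNet(nn.Module):  # stand-in for the fused SSD_R34_NMS model
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    model = ToyNet().eval()
    example_inputs = torch.randn(2, 3, 224, 224).to(memory_format=torch.channels_last)

    # Same observer recipe as the patch: asymmetric per-tensor activations (quint8),
    # symmetric per-channel weights (qint8).
    qconfig = QConfig(
        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
        weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))

    # Insert observers; the args.calibration branch of the patch stops after this step.
    prepared = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)
    with torch.no_grad():
        for _ in range(4):  # calibration passes (count is illustrative)
            prepared(torch.randn_like(example_inputs))
    prepared.save_qconf_summary(qconf_summary="calib.json")  # analogue of args.configure

    # Inference branch: reload the recipe, convert, trace, freeze, then run twice so the
    # LLGA (oneDNN Graph) pass replaces prim::profile nodes with LlgaFusionGroup int8 kernels.
    prepared.load_qconf_summary(qconf_summary="calib.json")
    quantized = convert(prepared)
    with torch.no_grad():
        traced = torch.jit.trace(quantized, example_inputs, check_trace=False).eval()
        frozen = torch.jit.freeze(traced)
        for _ in range(2):
            _ = frozen(example_inputs)

Freezing folds weights and quantization parameters into constants, and, as the in-code comments note, the first post-freeze run only inserts profiling nodes; the second run triggers the LLGA fusion pass, which is why the patch warms the model up twice before benchmarking.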