|
40 | 40 | use_ipex = False
|
41 | 41 | if os.environ.get('USE_IPEX') == "1":
|
42 | 42 | import intel_extension_for_pytorch as ipex
|
| 43 | + from intel_extension_for_pytorch.quantization import prepare, convert |
| 44 | + from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig |
43 | 45 | use_ipex = True
|
44 | 46 |
|
45 | 47 | def get_bs_per_stream(batch_size, stream_number):
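Reviewer note (not part of the diff): the new imports above move this script from the legacy `ipex.quantization.QuantConf` interface to the `prepare`/`convert` static-quantization API. Below is a minimal, self-contained sketch of that flow for reference. The toy model, input shape, and `configure.json` path are placeholders rather than anything from this PR; only the observer recipe and the prepare → (calibrate or load recipe) → convert → trace → freeze sequence mirror the patched `coco_eval()`.

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # registers the IPEX quantization backend
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

# Same observer recipe as the patch: per-tensor affine uint8 activations,
# per-channel symmetric int8 weights.
qconfig = QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))

# Placeholder model and shape; the real code quantizes SSD_R34_NMS(...).model
# with 3x1200x1200 channels-last inputs sized per stream.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
example_inputs = torch.randn(4, 3, 224, 224).to(memory_format=torch.channels_last)

prepared = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)

# Calibration path: run representative batches, then persist the recipe ...
with torch.no_grad():
    prepared(example_inputs)
prepared.save_qconf_summary(qconf_summary="configure.json")
# ... or, as in this patch, load a previously tuned recipe instead:
# prepared.load_qconf_summary(qconf_summary="configure.json")

converted = convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(converted, example_inputs, check_trace=False).eval()
    traced = torch.jit.freeze(traced)
    # Two warm-up runs let the profiling graph executor trigger the LLGA
    # int8 fusion pass before timed inference starts.
    for _ in range(2):
        _ = traced(example_inputs)
```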
|
@@ -238,38 +240,33 @@ def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args):
|
238 | 240 | start = time.time()
|
239 | 241 | if args.int8:
|
240 | 242 | model = model.eval()
|
241 | | - model_decode = SSD_R34_NMS(model, encoder) |
| 243 | + model_decode = SSD_R34_NMS(model, encoder).eval() |
242 | 244 | print('int8 conv_bn_fusion enabled')
|
243 | 245 | with torch.no_grad():
|
244 | 246 | model_decode.model.model = optimization.fuse(model_decode.model.model, inplace=False)
|
245 | | - |
| 247 | + qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), |
| 248 | + weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)) |
| 249 | + |
| 250 | + example_inputs = torch.randn( ((args.batch_size // args.number_instance) if args.use_multi_stream_module else args.batch_size), 3, 1200, 1200).to(memory_format=torch.channels_last) |
| 251 | + model_decode.model = prepare(model_decode.model, qconfig, example_inputs=example_inputs, inplace=False) |
246 | 252 | if args.calibration:
|
247 | 253 | print("runing int8 LLGA calibration step not support in throughput benchmark")
|
248 | 254 | else:
|
249 | 255 | print("INT8 LLGA start trace")
|
250 | 256 | # insert quant/dequant based on configure.json
|
251 | | - conf = ipex.quantization.QuantConf(configure_file=args.configure) |
252 | | - model_decode.eval() |
253 | | - if args.use_multi_stream_module: |
254 | | - batch_per_stream = args.batch_size // args.number_instance |
255 | | - print("batch_per_stream for multi_stream_module is:", batch_per_stream) |
256 | | - model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
257 | | - else: |
258 | | - model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
| 257 | + model_decode.model.load_qconf_summary(qconf_summary = args.configure) |
| 258 | + model_decode.model = convert(model_decode.model) |
| 259 | + with torch.no_grad(): |
| 260 | + model_decode = torch.jit.trace(model_decode, example_inputs, check_trace=False).eval() |
| 261 | + model_decode = torch.jit.freeze(model_decode) |
| 262 | + |
259 | 263 | print("done ipex default recipe.......................")
|
260 | | - # freeze the module |
261 | | - # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True)) |
262 | | - # model_decode = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model_decode._c, preserveParameters=True)) |
263 | 264 |
|
264 | 265 | # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
|
265 | 266 | # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
|
266 | 267 | with torch.no_grad():
|
267 | 268 | for i in range(2):
|
268 | | - # _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
269 | | - if args.use_multi_stream_module: |
270 | | - _ = model_decode(torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
271 | | - else: |
272 | | - _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)) |
| 269 | + _ = model_decode(example_inputs) |
273 | 270 |
|
274 | 271 | if args.use_throughput_benchmark:
|
275 | 272 |
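Reviewer note (not part of the diff): `example_inputs` is sized `args.batch_size // args.number_instance` because, with `args.use_multi_stream_module` set, the traced model is later run through IPEX's multi-stream runtime, which splits each incoming batch across streams so each stream only ever sees a per-stream batch. A hedged sketch of that wrapping, assuming `ipex.cpu.runtime.MultiStreamModule` and `CPUPool` (neither appears in this hunk, and the names and sizes below are illustrative only):

```python
import torch
import intel_extension_for_pytorch as ipex

# Stand-ins for args.batch_size and args.number_instance.
batch_size, num_streams = 16, 4

# A traced, frozen toy module; the real code would use the frozen model_decode.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
per_stream = torch.randn(batch_size // num_streams, 3, 224, 224)
traced = torch.jit.freeze(torch.jit.trace(model, per_stream))

# Bind the streams to a CPU pool and let MultiStreamModule fan the batch out.
cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
multi_stream = ipex.cpu.runtime.MultiStreamModule(traced, num_streams=num_streams, cpu_pool=cpu_pool)

with torch.no_grad():
    # The full batch goes in; each stream sees batch_size // num_streams samples.
    _ = multi_stream(torch.randn(batch_size, 3, 224, 224))
```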
|
|