
Commit 9eb3211

ssd enable new int8 (#580)
* v1
* enable new int8 method
1 parent 0a07b93 commit 9eb3211


4 files changed (+26471, -13684 lines)


models/object_detection/pytorch/ssd-resnet34/inference/cpu/infer.py

Lines changed: 14 additions & 11 deletions
@@ -35,6 +35,8 @@
 use_ipex = False
 if os.environ.get('USE_IPEX') == "1":
     import intel_extension_for_pytorch as ipex
+    from intel_extension_for_pytorch.quantization import prepare, convert
+    from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
     use_ipex = True


@@ -192,35 +194,36 @@ def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args):
         print('int8 conv_bn_fusion enabled')
         with torch.no_grad():
             model.model = optimization.fuse(model.model, inplace=False)
-
+        qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                          weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+        example_inputs = torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last)
+        prepared_model = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)
         if args.calibration:
             print("runing int8 LLGA calibration step\n")
-            conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine) # qscheme can is torch.per_tensor_affine, torch.per_tensor_symmetric
             with torch.no_grad():
                 for nbatch, (img, img_id, img_size, bbox, label) in enumerate(val_dataloader):
                     print("nbatch:{}".format(nbatch))
-                    with ipex.quantization.calibrate(conf):
-                        ploc, plabel = model(img)
+                    ploc, plabel = prepared_model(img)
                     if nbatch == args.iteration:
                         break
-            conf.save(args.configure)
+            prepared_model.save_qconf_summary(qconf_summary = args.configure)
             return

         else:
             print("INT8 LLGA start trace")
             # insert quant/dequant based on configure.json
-            conf = ipex.quantization.QuantConf(configure_file=args.configure)
-            model = ipex.quantization.convert(model, conf, torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
+            prepared_model.load_qconf_summary(qconf_summary = args.configure)
+            convert_model = convert(prepared_model)
+            with torch.no_grad():
+                model = torch.jit.trace(convert_model, example_inputs, check_trace=False).eval()
+            model = torch.jit.freeze(model)
             print("done ipex default recipe.......................")
-            # freeze the module
-            # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))

             # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
             # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
             with torch.no_grad():
                 for i in range(2):
-                    #_, _ = model(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
-                    _, _ = model(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
+                    _, _ = model(example_inputs)

         print('runing int8 real inputs inference path')
         with torch.no_grad():
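
For reference, the quantization flow this file moves to can be exercised outside the SSD harness. Below is a minimal sketch of the same prepare -> calibrate -> save/load qconf summary -> convert -> trace/freeze sequence, assuming intel_extension_for_pytorch is installed; ToyModel, the random calibration batches, and the "qconf.json" path are illustrative stand-ins for SSD-ResNet34, the COCO val loader, and args.configure, not part of this commit. The observer choices mirror the per_tensor_affine scheme the removed QuantConf used.

# Minimal sketch of the new IPEX static int8 recipe adopted by this commit.
# ToyModel and "qconf.json" are hypothetical placeholders, not the commit's code.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 (registers IPEX ops)
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


model = ToyModel().eval()

# Same observers as the diff: asymmetric per-tensor activations (quint8),
# symmetric per-channel weights (qint8).
qconfig = QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
example_inputs = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)

# 1) Insert observers.
prepared = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)

# 2) Calibration: running the prepared model collects statistics; then persist
#    the recipe (the diff writes it to args.configure).
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last))
prepared.save_qconf_summary(qconf_summary="qconf.json")

# 3) Deployment: reload the recipe, convert, trace, and freeze so the LLGA
#    fusion pass can turn the graph into an int8 model.
prepared.load_qconf_summary(qconf_summary="qconf.json")
converted = convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(converted, example_inputs, check_trace=False).eval()
    traced = torch.jit.freeze(traced)
    # Two warm-up runs, as in the diff: the first inserts prim::profile nodes,
    # the second triggers the LLGA pass and produces LlgaFusionGroup nodes.
    for _ in range(2):
        traced(example_inputs)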

models/object_detection/pytorch/ssd-resnet34/inference/cpu/infer_weight_sharing.py

Lines changed: 15 additions & 18 deletions
@@ -40,6 +40,8 @@
 use_ipex = False
 if os.environ.get('USE_IPEX') == "1":
     import intel_extension_for_pytorch as ipex
+    from intel_extension_for_pytorch.quantization import prepare, convert
+    from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
     use_ipex = True

 def get_bs_per_stream(batch_size, stream_number):
@@ -238,38 +240,33 @@ def coco_eval(model, val_dataloader, cocoGt, encoder, inv_map, args):
     start = time.time()
     if args.int8:
         model = model.eval()
-        model_decode = SSD_R34_NMS(model, encoder)
+        model_decode = SSD_R34_NMS(model, encoder).eval()
         print('int8 conv_bn_fusion enabled')
         with torch.no_grad():
             model_decode.model.model = optimization.fuse(model_decode.model.model, inplace=False)
-
+        qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
+                          weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+
+        example_inputs = torch.randn( ((args.batch_size // args.number_instance) if args.use_multi_stream_module else args.batch_size), 3, 1200, 1200).to(memory_format=torch.channels_last)
+        model_decode.model = prepare(model_decode.model, qconfig, example_inputs=example_inputs, inplace=False)
        if args.calibration:
             print("runing int8 LLGA calibration step not support in throughput benchmark")
         else:
             print("INT8 LLGA start trace")
             # insert quant/dequant based on configure.json
-            conf = ipex.quantization.QuantConf(configure_file=args.configure)
-            model_decode.eval()
-            if args.use_multi_stream_module:
-                batch_per_stream = args.batch_size // args.number_instance
-                print("batch_per_stream for multi_stream_module is:", batch_per_stream)
-                model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last))
-            else:
-                model_decode = ipex.quantization.convert(model_decode, conf, torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
+            model_decode.model.load_qconf_summary(qconf_summary = args.configure)
+            model_decode.model = convert(model_decode.model)
+            with torch.no_grad():
+                model_decode = torch.jit.trace(model_decode, example_inputs, check_trace=False).eval()
+            model_decode = torch.jit.freeze(model_decode)
+
             print("done ipex default recipe.......................")
-            # freeze the module
-            # model = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model._c, preserveParameters=True))
-            # model_decode = torch.jit._recursive.wrap_cpp_module(torch._C._freeze_module(model_decode._c, preserveParameters=True))

             # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
             # At the 2nd run, the llga pass will be triggered and the model is turned into an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
             with torch.no_grad():
                 for i in range(2):
-                    # _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
-                    if args.use_multi_stream_module:
-                        _ = model_decode(torch.randn(batch_per_stream, 3, 1200, 1200).to(memory_format=torch.channels_last))
-                    else:
-                        _ = model_decode(torch.randn(args.batch_size, 3, 1200, 1200).to(memory_format=torch.channels_last))
+                    _ = model_decode(example_inputs)

     if args.use_throughput_benchmark:
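
One detail specific to the weight-sharing script: the traced example input uses a per-stream batch rather than the global batch, so the frozen graph matches what each stream feeds it at runtime. A small sketch of that shape calculation, using the argument names from the diff with made-up values (batch_size=64, number_instance=4 are illustrative, not defaults):

# Sketch of the per-stream example-input shape used in infer_weight_sharing.py.
import torch

batch_size = 64              # illustrative value for args.batch_size
number_instance = 4          # illustrative value for args.number_instance
use_multi_stream_module = True

# With the multi-stream module enabled, each stream processes
# batch_size // number_instance images, so the model is traced at that size.
per_trace_batch = (batch_size // number_instance) if use_multi_stream_module else batch_size
example_inputs = torch.randn(per_trace_batch, 3, 1200, 1200).to(memory_format=torch.channels_last)
print(example_inputs.shape)  # torch.Size([16, 3, 1200, 1200])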