@@ -97,6 +97,10 @@
 # intel
 import intel_extension_for_pytorch as ipex
 from torch.utils import ThroughputBenchmark
+
+# int8
+from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
+from intel_extension_for_pytorch.quantization import prepare, convert
 # For distributed run
 import extend_distributed as ext_dist
 
@@ -403,12 +407,18 @@ def trace_model(args, dlrm, test_ld):
         dlrm.emb_l.bfloat16()
         dlrm = ipex.optimize(dlrm, dtype=torch.bfloat16, inplace=True)
     elif args.int8:
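+        # int8 recipe: symmetric per-tensor activations, symmetric per-channel weights.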
-        conf = ipex.quantization.QuantConf(args.int8_configure)
-        dlrm = ipex.quantization.convert(dlrm, conf, (X, lS_o, lS_i))
+        qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
+                          weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
+        prepared_dlrm = prepare(dlrm, qconfig, example_inputs=(X, lS_o, lS_i), inplace=True)
+        prepared_dlrm.load_qconf_summary(qconf_summary=args.int8_configure)
+        dlrm = convert(prepared_dlrm)
     else:
         dlrm = ipex.optimize(dlrm, dtype=torch.float, inplace=True)
     if args.int8:
-        dlrm = freeze(dlrm)
+        dlrm = torch.jit.trace(dlrm, [X, lS_o, lS_i])
+        dlrm = torch.jit.freeze(dlrm)
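+        # Two warm-up calls so the TorchScript profiling executor finalizes the fused int8 graph.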
+        dlrm(X, lS_o, lS_i)
+        dlrm(X, lS_o, lS_i)
     else:
         with torch.cpu.amp.autocast(enabled=args.bf16):
             dlrm = torch.jit.trace(dlrm, (X, lS_o, lS_i), check_trace=True)
@@ -666,6 +676,7 @@ def run():
     parser.add_argument("--ipex-merged-emb", action="store_true", default=False)
     parser.add_argument("--num-warmup-iters", type=int, default=1000)
     parser.add_argument("--int8", action="store_true", default=False)
+    parser.add_argument("--calibration", action="store_true", default=False)
     parser.add_argument("--int8-configure", type=str, default="./int8_configure.json")
     parser.add_argument("--dist-backend", type=str, default="ccl")
 
@@ -743,6 +754,7 @@ def run():
         sigmoid_top=ln_top.size - 2,
         loss_threshold=args.loss_threshold,
     )
+
     if args.ipex_merged_emb:
         dlrm.emb_l = ipex.nn.modules.MergedEmbeddingBagWithSGD.from_embeddingbag_list(dlrm.emb_l, lr=args.learning_rate)
         dlrm.need_linearize_indices_and_offsets = torch.BoolTensor([False])
@@ -805,6 +817,26 @@ def run():
         print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
 
     ext_dist.barrier()
+
+    if args.calibration:
+        assert args.load_model != "", "need to load model weights to do calibration"
+        dlrm.eval()
+        qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
+                          weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
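+        # Take one training batch as example inputs to build the prepared model.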
+        for j, inputBatch in enumerate(train_ld):
+            X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
+            example_inputs = (X, lS_o, lS_i)
+            prepared_dlrm = prepare(dlrm, qconfig, example_inputs=example_inputs, inplace=True)
+            break
+
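+        # Feed a few batches so the observers record activation ranges.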
+        for j, inputBatch in enumerate(train_ld):
+            X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
+            prepared_dlrm(X, lS_o, lS_i)
+            if j == 2:
+                break
+        prepared_dlrm.save_qconf_summary(qconf_summary=args.int8_configure)
+        print("calibration done, saved config file to", args.int8_configure)
+        exit()
+
     print("time/loss/accuracy (if enabled):")
 
     if args.bf16 and not args.inference_only:
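For reference, the two-phase flow this patch wires into DLRM can be exercised in isolation. The sketch below is a minimal, self-contained illustration of the same IPEX static-quantization API (prepare, calibrate, save_qconf_summary, then load_qconf_summary, convert, trace/freeze); `TinyMLP`, the tensor shapes, and the loop count are illustrative placeholders, not part of the patch. In the patched script, phase 1 corresponds to a run with `--calibration` (after loading trained weights) and phase 2 to a run with `--int8 --int8-configure`.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
from intel_extension_for_pytorch.quantization import prepare, convert

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyMLP().eval()
qconfig = QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
example = torch.randn(8, 16)

# Phase 1 (--calibration): run a few batches through the prepared model so the
# observers record activation ranges, then dump the recipe to JSON.
prepared = prepare(model, qconfig, example_inputs=(example,), inplace=False)
for _ in range(3):
    prepared(torch.randn(8, 16))
prepared.save_qconf_summary(qconf_summary="./int8_configure.json")

# Phase 2 (--int8): reload the recipe, convert to an int8 model, then trace and
# freeze so the quantized kernels are fused into a static graph.
prepared = prepare(TinyMLP().eval(), qconfig, example_inputs=(example,), inplace=False)
prepared.load_qconf_summary(qconf_summary="./int8_configure.json")
quantized = convert(prepared)
traced = torch.jit.freeze(torch.jit.trace(quantized, example))
traced(example)  # warm-up call to trigger JIT graph optimization
```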