Skip to content

Commit 0bded92

Browse files
authored
re-enable int8 for api change (#579)
1 parent 0b286b0 commit 0bded92

File tree

4 files changed

+11187
-5308
lines changed

4 files changed

+11187
-5308
lines changed

models/recommendation/pytorch/dlrm/product/dlrm_s_pytorch.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@
9797
# intel
9898
import intel_extension_for_pytorch as ipex
9999
from torch.utils import ThroughputBenchmark
100+
101+
# int8
102+
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
103+
from intel_extension_for_pytorch.quantization import prepare, convert
100104
# For distributed run
101105
import extend_distributed as ext_dist
102106

@@ -403,12 +407,18 @@ def trace_model(args, dlrm, test_ld):
403407
dlrm.emb_l.bfloat16()
404408
dlrm = ipex.optimize(dlrm, dtype=torch.bfloat16, inplace=True)
405409
elif args.int8:
406-
conf = ipex.quantization.QuantConf(args.int8_configure)
407-
dlrm = ipex.quantization.convert(dlrm, conf, (X, lS_o, lS_i))
410+
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
411+
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
412+
prepared_dlrm = prepare(dlrm, qconfig, example_inputs=(X, lS_o, lS_i), inplace=True)
413+
prepared_dlrm.load_qconf_summary(qconf_summary = args.int8_configure)
414+
dlrm = convert(prepared_dlrm)
408415
else:
409416
dlrm = ipex.optimize(dlrm, dtype=torch.float, inplace=True)
410417
if args.int8:
411-
dlrm = freeze(dlrm)
418+
dlrm = torch.jit.trace(dlrm, [X, lS_o, lS_i])
419+
dlrm = torch.jit.freeze(dlrm)
420+
dlrm(X, lS_o, lS_i)
421+
dlrm(X, lS_o, lS_i)
412422
else:
413423
with torch.cpu.amp.autocast(enabled=args.bf16):
414424
dlrm = torch.jit.trace(dlrm, (X, lS_o, lS_i), check_trace=True)
@@ -666,6 +676,7 @@ def run():
666676
parser.add_argument("--ipex-merged-emb", action="store_true", default=False)
667677
parser.add_argument("--num-warmup-iters", type=int, default=1000)
668678
parser.add_argument("--int8", action="store_true", default=False)
679+
parser.add_argument("--calibration", action="store_true", default=False)
669680
parser.add_argument("--int8-configure", type=str, default="./int8_configure.json")
670681
parser.add_argument("--dist-backend", type=str, default="ccl")
671682

@@ -743,6 +754,7 @@ def run():
743754
sigmoid_top=ln_top.size - 2,
744755
loss_threshold=args.loss_threshold,
745756
)
757+
746758
if args.ipex_merged_emb:
747759
dlrm.emb_l = ipex.nn.modules.MergedEmbeddingBagWithSGD.from_embeddingbag_list(dlrm.emb_l, lr=args.learning_rate)
748760
dlrm.need_linearize_indices_and_offsets = torch.BoolTensor([False])
@@ -805,6 +817,26 @@ def run():
805817
print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
806818

807819
ext_dist.barrier()
820+
821+
if args.calibration:
822+
assert args.load_model != "", "need load weight to do calibration"
823+
dlrm.eval()
824+
qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
825+
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
826+
for j, inputBatch in enumerate(train_ld):
827+
X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
828+
example_inputs = (X, lS_o, lS_i)
829+
prepared_dlrm = prepare(dlrm, qconfig, example_inputs=example_inputs, inplace=True)
830+
break
831+
832+
for j, inputBatch in enumerate(train_ld):
833+
prepared_dlrm(X, lS_o, lS_i)
834+
if j == 2:
835+
break
836+
prepared_dlrm.save_qconf_summary(qconf_summary = args.int8_configure)
837+
print("calibration done, save config file to ", args.int8_configure)
838+
exit()
839+
808840
print("time/loss/accuracy (if enabled):")
809841

810842
if args.bf16 and not args.inference_only:

0 commit comments

Comments
 (0)