Skip to content

Commit 691ad9e

Browse files
authored
update DLRM int8 config with correct calibration set (#1330)
1 parent 2e1d7b0 commit 691ad9e

File tree

4 files changed

+90
-88
lines changed

4 files changed

+90
-88
lines changed

models/recommendation/pytorch/torchrec_dlrm/dlrm_main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,8 @@ def train_val_test(
800800
results = TrainValTestResults()
801801

802802
if args.inference_only:
803+
# MLPerf uses the validation set to compute the test AUROC; see:
804+
# https://github.com/mlcommons/inference/blob/master/recommendation/dlrm_v2/pytorch/python/multihot_criteo.py#L99-L107
803805
test_auroc = _evaluate(
804806
args.limit_val_batches,
805807
model.model,
@@ -1016,7 +1018,7 @@ def main(argv: List[str]) -> None:
10161018
)
10171019

10181020
train_model = DLRMTrain(dlrm_model)
1019-
if args.test_auroc:
1021+
if args.test_auroc or args.calibration:
10201022
assert args.snapshot_dir
10211023
load_snapshot(train_model, args.snapshot_dir)
10221024
# embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
@@ -1149,7 +1151,7 @@ def main(argv: List[str]) -> None:
11491151
test_dataloader = RestartableMap(multihot.convert_to_multi_hot, test_dataloader)
11501152

11511153
if args.ipex_optimize:
1152-
model.model, optimizer = ipex_optimize(args, model.model, optimizer, val_dataloader)
1154+
model.model, optimizer = ipex_optimize(args, model.model, optimizer, test_dataloader)
11531155

11541156
train_val_test(
11551157
args,

0 commit comments

Comments
 (0)