Commit 6c53583

Itssshikhar and claude committed
GPTQ Hessian all-reduce on PR1851 base
PR openai#1851's collect_hessians (lines 2037-2150 of _top_ref/train_gpt.py) computes each rank's Hessian on its own data-shard subset (ShuffledSequenceLoader splits files by rank) and divides only by n_calibration_batches. Without an all-reduce, only rank 0's Hessian is effectively used, since only rank 0 writes the quantized blob, so 7/8 of the calibration compute is wasted.

Fix: dist.all_reduce(SUM) each Hessian, iterating keys in sorted order to avoid deadlock if key order ever drifts across ranks, then divide by n_calibration_batches * world_size.

Smoking-gun log lines: "gptq:all-rank Hessian averaging across N ranks (denom=...)" when on, "gptq:per-rank Hessian (no all-reduce, denom=...)" when off.

Gated by the GPTQ_ALL_REDUCE env var (default 1, the bugfix behavior). The off path preserves the original upstream semantics for a clean A/B if needed.

PR1493 evidence at gptq_calibration_batches=16 (PR openai#1851's default):

    16-shard no-AR: q_ttt = 1.08060
    16-shard AR:    q_ttt = 1.07977 (delta -0.00083)

At 128 calibration batches the AR delta saturates to noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 3f661b3 · commit 6c53583
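
Why the new denominator is n_calibration_batches * world_size: each rank accumulates a sum of x^T x over its own shard, so the all-reduced sum must be divided by the total batch count to remain an average. A minimal single-process sketch of that identity (shard sizes, batch shapes, and float64 are illustrative choices here, not from train_top.py):

import torch

torch.manual_seed(0)
world_size, n_calibration_batches, d = 8, 16, 32

# Each "rank" sees its own shard of calibration batches (standing in for
# ShuffledSequenceLoader's per-rank file split) and accumulates
# H_r = sum_b x_b^T x_b over that shard only.
shards = [[torch.randn(64, d, dtype=torch.float64) for _ in range(n_calibration_batches)]
          for _ in range(world_size)]
per_rank = [sum(x.T @ x for x in shard) for shard in shards]

# Upstream semantics: divide the per-rank accumulator by n_calibration_batches;
# only rank 0's result is ever written into the quantized blob.
h_rank0 = per_rank[0] / n_calibration_batches

# Fixed semantics: sum across ranks (what dist.all_reduce with ReduceOp.SUM
# does in-place), then divide by the global batch count.
h_global = sum(per_rank) / (n_calibration_batches * world_size)

# h_global is exactly the mean of x^T x over all world_size * n_calibration_batches
# batches, i.e. the estimate the full calibration set actually paid for.
all_batches = [x for shard in shards for x in shard]
expected = sum(x.T @ x for x in all_batches) / len(all_batches)
assert torch.allclose(h_global, expected)
print((h_rank0 - h_global).abs().max())  # nonzero: the rank-0-only estimate is noisier

With 16 batches per shard the rank-0-only estimate is visibly noisier than the all-shard mean, which is consistent with the AR delta shrinking to noise at 128 calibration batches.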

1 file changed: train_top.py (14 additions, 1 deletion)

@@ -309,6 +309,7 @@ class Hyperparameters:
     compressor = os.environ.get("COMPRESSOR", "brotli")
     gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 16))
     gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 4.0))
+    gptq_all_reduce = bool(int(os.environ.get("GPTQ_ALL_REDUCE", "1")))
     phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000))
     phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 1))
     global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001))
@@ -2145,8 +2146,20 @@ def hook_fn(module, inp, out):
         block.attn._calib = False
         block.mlp._calib = False
         block.mlp.use_fused = True
+    distributed = dist.is_available() and dist.is_initialized()
+    world_size = dist.get_world_size() if distributed else 1
+    do_ar = bool(getattr(h, "gptq_all_reduce", True)) and distributed and world_size > 1
+    if do_ar:
+        log(f"gptq:all-rank Hessian averaging across {world_size} ranks "
+            f"(denom={n_calibration_batches * world_size})")
+        for name in sorted(hessians.keys()):
+            dist.all_reduce(hessians[name], op=dist.ReduceOp.SUM)
+        denom = n_calibration_batches * world_size
+    else:
+        log(f"gptq:per-rank Hessian (no all-reduce, denom={n_calibration_batches})")
+        denom = n_calibration_batches
     for name in hessians:
-        hessians[name] = hessians[name].cpu() / n_calibration_batches
+        hessians[name] = hessians[name].cpu() / denom
     return hessians

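On the sorted-key iteration: collectives pair up across ranks by issue order, not by tensor name, so every rank must enqueue the all-reduces in the same order, or mismatched tensors get combined and the job can hang. A self-contained two-process sketch of the diff's pattern, using the gloo backend on CPU (module names, shapes, and the port are illustrative, not from train_top.py):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29571"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    n_calibration_batches = 16
    # Toy per-rank accumulators standing in for the real Hessians dict.
    hessians = {"blocks.0.attn": torch.full((4, 4), float(rank + 1)),
                "blocks.0.mlp": torch.full((4, 4), 10.0 * (rank + 1))}

    # sorted() pins a rank-independent issue order; if dict order ever
    # drifted between ranks, rank A's attn could be reduced against
    # rank B's mlp, yielding wrong sums or a hang.
    for name in sorted(hessians.keys()):
        dist.all_reduce(hessians[name], op=dist.ReduceOp.SUM)

    denom = n_calibration_batches * world_size
    for name in hessians:
        hessians[name] = hessians[name].cpu() / denom

    if rank == 0:
        print({k: round(v[0, 0].item(), 4) for k, v in hessians.items()})
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

For the A/B described in the commit message, the off path would be selected by launching with GPTQ_ALL_REDUCE=0 in the environment.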