
Commit fc9994e

amrayach and claude committed
Replace GPTQ path with PR openai#1019 faithful transplant
Previous replay_ref_hfix still showed 66/66 worse layers with the Hessian normalization fix. Rather than continuing to debug from symptoms, this transplants PR openai#1019's complete GPTQ slice verbatim:

- collect_hessians: PR openai#1019 hook pattern with pre-init and param_name keys
- quantize_int6_gptq: verbatim from PR openai#1019 lines 1171-1224
- gptq_mixed_quantize_int6: direct param_name key lookup, PR openai#1019 quantizer

Source: pr-1019-gptq:records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/train_gpt.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9cea7e9 commit fc9994e
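For orientation, a minimal sketch of the transplanted quantizer's call contract, assuming `quantize_int6_gptq` from the patched train_gpt.py is in scope; the names and defaults are the ones visible in the diff below, while the tensors are random stand-ins, not the training checkpoint:

    import torch

    # weight: (rows, cols) weight matrix; H: (cols, cols) accumulated X^T X.
    weight = torch.randn(64, 128)
    X = torch.randn(256, 128)   # stand-in calibration activations
    H = X.T @ X                 # SPD stand-in for a collected Hessian

    q, s = quantize_int6_gptq(weight, hessian=H, clip_range=31, block_size=128)
    w_hat = q.float() * s.float()[:, None]  # per-row dequantization, matching the diff's recon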

1 file changed

Lines changed: 98 additions & 167 deletions

records/track_non_record_16mb/2026-03-29_full_hessian_gptq/train_gpt.py
@@ -479,30 +479,6 @@ def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
 # GPTQ: HESSIAN COLLECTION + QUANTIZATION
 # -----------------------------
 
-def _make_hessian_hook(hessians: dict[str, Tensor], name: str, n_samples: dict[str, int]):
-    """Create a forward hook that accumulates H = X^T X for a linear layer."""
-    def hook(_module, input, _output):
-        x = input[0].detach().float()
-        if x.ndim > 2:
-            x = x.reshape(-1, x.shape[-1])  # (B*T, D)
-        if name not in hessians:
-            hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device="cpu")
-        hessians[name].add_((x.T @ x).cpu())
-        n_samples[name] = n_samples.get(name, 0) + x.shape[0]
-    return hook
-
-
-def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> dict[str, Tensor]:
-    """Match the PR Hessian preparation path before quantization."""
-    finalized: dict[str, Tensor] = {}
-    for name, H in hessians.items():
-        H = H.float() / max(num_batches, 1)
-        damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6)
-        H = H + damp * torch.eye(H.shape[0], dtype=H.dtype)
-        finalized[name] = H
-    return finalized
-
-
 def collect_hessians(
     model: nn.Module,
     data_files: str,
@@ -511,145 +487,110 @@ def collect_hessians(
     seq_len: int = 2048,
     batch_size: int = 4,
 ) -> dict[str, Tensor]:
-    """Collect H = X^T X for each int6-targeted CastedLinear layer using training data.
+    """Collect H = X^T X for CastedLinear layers (PR #1019 faithful transplant).
 
-    Returns dict mapping module_name (e.g. "blocks.0.attn.c_q") -> H (float32, CPU).
-    Uses forward_logits (no targets needed) and TokenStream (rank-0 only).
+    Returns dict mapping param_name (e.g. "blocks.0.attn.c_q.weight") -> H (float32, CPU).
     """
     hessians: dict[str, Tensor] = {}
-    n_samples: dict[str, int] = {}
-    handles = []
-    skipped_names = []
-    sd_keys = set(model.state_dict().keys())
-
+    hooks: list = []
     for name, module in model.named_modules():
         if isinstance(module, CastedLinear):
-            weight_key = name + ".weight"
-            assert weight_key in sd_keys, f"No state_dict key for {weight_key}"
-            cat = _classify_param(weight_key)
-            if cat not in ("mlp", "attn"):
-                skipped_names.append(name)
-                continue
-            handles.append(module.register_forward_hook(_make_hessian_hook(hessians, name, n_samples)))
-
-    print(f"gptq: hooking {len(handles)} layers, skipped: {skipped_names}", flush=True)
+            param_name = name + ".weight"
+            cols = module.weight.shape[1]
+            hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu')
+            def make_hook(pname):
+                def hook_fn(module, input, output):
+                    x = input[0].detach().float()
+                    if x.ndim == 3:
+                        x = x.reshape(-1, x.shape[-1])
+                    hessians[pname] += (x.T @ x).cpu()
+                return hook_fn
+            hooks.append(module.register_forward_hook(make_hook(param_name)))
+
+    print(f"gptq: hooking {len(hooks)} layers", flush=True)
 
-    stream = TokenStream(data_files)
     num_batches = num_samples // batch_size
+    stream = TokenStream(data_files)
     model.eval()
-    with torch.inference_mode():
-        for i in range(num_batches):
+    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        for _ in range(num_batches):
             tokens = stream.take(batch_size * seq_len).to(device=device, dtype=torch.int64)
             x = tokens.reshape(batch_size, seq_len)
-            with torch.autocast(device_type="cuda", dtype=LOWP_DTYPE, enabled=True):
-                model.forward_logits(x)
+            model.forward_logits(x)
 
-    for h in handles:
+    for h in hooks:
         h.remove()
 
-    result = _finalize_hessians(hessians, num_batches)
-    print(f"gptq: collected {len(result)} Hessians, samples per layer: "
-          f"{dict(list(n_samples.items())[:3])}...", flush=True)
-    return result
-
-
-def gptq_quantize_layer(
-    W: Tensor,
-    H: Tensor,
-    block_size: int = 128,
-    percdamp: float = 0.01,
-    clip_range: int = 31,
-    actorder: bool = True,
-) -> tuple[Tensor, Tensor, bool, dict[str, object]]:
-    """GPTQ-quantize weight matrix W using the PR-grounded loop.
-
-    Returns (q: int8 in [-clip_range, clip_range], scale: fp16 per-row, degraded: bool, stats).
-    """
-    W_orig = W.float().clone()
-    H = H.float().clone()
-    d_row, d_col = W_orig.shape
-
-    dead = (H.diag() == 0)
-    H[dead, dead] = 1.0
-
-    damp = percdamp * H.diag().mean().clamp_min(1e-6)
-    H.diagonal().add_(damp)
-
-    perm = torch.arange(d_col, device=H.device)
-    if actorder:
-        perm = torch.argsort(H.diag(), descending=True)
-    invperm = torch.argsort(perm)
-    W_perm = W_orig[:, perm].clone()
-    W_perm[:, dead[perm]] = 0.0
+    # Finalize: normalize and damp (PR #1019 lines 1131-1136)
+    n_samples_example = {}
+    for name in hessians:
+        H = hessians[name]
+        n_samples_example[name] = num_batches * batch_size * seq_len
+        H /= num_batches
+        damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6)
+        H += damp * torch.eye(H.shape[0])
+        hessians[name] = H
+
+    print(f"gptq: collected {len(hessians)} Hessians, samples per layer: "
+          f"{dict(list(n_samples_example.items())[:3])}...", flush=True)
+    return hessians
+
+
+def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128):
+    """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation.
+    If hessian is None, falls back to percentile search.
+    Verbatim from PR #1019 (lines 1171-1224)."""
+    t32 = weight.float()
+    if t32.ndim != 2 or hessian is None:
+        return quantize_int6_per_row(t32, clip_range)
+    rows, cols = t32.shape
+    H = hessian.float().clone()
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    damp = 0.01 * torch.mean(torch.diag(H))
+    H[torch.arange(cols), torch.arange(cols)] += damp
+    perm = torch.argsort(torch.diag(H), descending=True)
+    inv_perm = torch.argsort(perm)
+    W = t32[:, perm].clone()
+    W[:, dead[perm]] = 0
     H = H[perm][:, perm]
-
-    try:
-        Hinv_chol = torch.cholesky_inverse(torch.linalg.cholesky(H))
-        Hinv_chol = torch.linalg.cholesky(Hinv_chol, upper=True)
-    except torch.linalg.LinAlgError:
-        cond_est = H.diag().max().item() / max(H.diag().min().item(), 1e-12)
-        print(f"gptq:WARNING Cholesky failed, cond~{cond_est:.1e}, falling back to naive", flush=True)
-        q, s = quantize_int6_per_row(W_orig, clip_range=clip_range)
-        return q, s, True, {
-            "mse": _quantization_mse(W_orig, q, s),
-            "best_pct": None,
-            "dead_cols": int(dead.sum().item()),
-            "max_block_mse": None,
-            "worst_block_start": None,
-            "cholesky_fallback": True,
-        }
-
-    best_q, best_scale, best_stats = None, None, None
-    best_err = float("inf")
-    for pct, row_clip in _iter_int6_row_clips(W_orig):
-        scale = (row_clip / float(clip_range)).clamp_min(1.0 / float(clip_range)).to(torch.float16)
-        scale_f = scale.float()
-        Q = torch.zeros((d_row, d_col), dtype=torch.int8, device=W_perm.device)
-        W_work = W_perm.clone()
-        max_block_mse = -1.0
-        worst_block_start = 0
-
-        for block_start in range(0, d_col, block_size):
-            block_end = min(block_start + block_size, d_col)
-            W_block = W_work[:, block_start:block_end].clone()
-            Err = torch.zeros((d_row, block_end - block_start), dtype=W_work.dtype, device=W_work.device)
-            Hinv_block = Hinv_chol[block_start:block_end, block_start:block_end]
-
-            for j in range(block_end - block_start):
-                w_col = W_block[:, j]
-                d = Hinv_block[j, j]
-                q_col = torch.clamp(torch.round(w_col / scale_f), -clip_range, clip_range)
-                Q[:, block_start + j] = q_col.to(torch.int8)
-                err = (w_col - q_col.float() * scale_f) / d
-                Err[:, j] = err
-                W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
-
-            if block_end < d_col:
-                W_work[:, block_end:] -= Err @ Hinv_chol[block_start:block_end, block_end:]
-
-            block_recon = Q[:, block_start:block_end].float() * scale_f[:, None]
-            block_mse = float((W_perm[:, block_start:block_end] - block_recon).pow(2).mean().item())
-            if block_mse > max_block_mse:
-                max_block_mse = block_mse
-                worst_block_start = block_start
-
-        recon = Q.float() * scale_f[:, None]
-        mse = float((W_perm - recon).pow(2).mean().item())
+    Hinv = torch.linalg.cholesky(H)
+    Hinv = torch.cholesky_inverse(Hinv)
+    Hinv = torch.linalg.cholesky(Hinv, upper=True)
+    best_q = None; best_scale = None; best_err = float('inf')
+    for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
+        if pct < 1.0:
+            row_clip = torch.quantile(t32.abs(), pct, dim=1)
+        else:
+            row_clip = t32.abs().amax(dim=1)
+        s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
+        sf = s.float()
+        Q = torch.zeros_like(W, dtype=torch.int8)
+        W_work = W.clone()
+        for i1 in range(0, cols, block_size):
+            i2 = min(i1 + block_size, cols)
+            count = i2 - i1
+            W1 = W_work[:, i1:i2].clone()
+            Q1 = torch.zeros(rows, count, dtype=torch.int8)
+            Err1 = torch.zeros(rows, count)
+            Hinv1 = Hinv[i1:i2, i1:i2]
+            for i in range(count):
+                w = W1[:, i]
+                d = Hinv1[i, i]
+                q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8)
+                Q1[:, i] = q
+                err = (w - q.float() * sf) / d
+                W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0)
+                Err1[:, i] = err
+            Q[:, i1:i2] = Q1
+            if i2 < cols:
+                W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]
+        recon = Q.float() * sf[:, None]
+        mse = (W - recon).pow(2).mean().item()
         if mse < best_err:
-            best_q = Q
-            best_scale = scale
-            best_err = mse
-            best_stats = {
-                "mse": mse,
-                "best_pct": pct,
-                "dead_cols": int(dead.sum().item()),
-                "max_block_mse": max_block_mse,
-                "worst_block_start": worst_block_start,
-                "cholesky_fallback": False,
-            }
-
-    best_q = best_q[:, invperm]
-    return best_q, best_scale, False, best_stats
+            best_q, best_scale, best_err = Q, s, mse
+    best_q = best_q[:, inv_perm]
+    return best_q, best_scale
 
 
 def gptq_mixed_quantize_int6(
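The core of `quantize_int6_gptq` above is the standard GPTQ recurrence: quantize one column, divide its rounding error by the corresponding diagonal entry of the upper Cholesky factor of H^-1, and push the result onto the not-yet-quantized columns. A self-contained toy version of just that recurrence, with no blocking and no percentile search (illustrative only, not part of the commit):

    import torch

    def gptq_round_columns(W: torch.Tensor, H: torch.Tensor, scale: torch.Tensor, clip: int = 31):
        # Upper Cholesky factor of H^{-1}, built the same way as in quantize_int6_gptq.
        Hinv = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)
        W = W.clone().float()
        Q = torch.zeros_like(W, dtype=torch.int8)
        for j in range(W.shape[1]):
            q = torch.clamp(torch.round(W[:, j] / scale), -clip, clip)
            Q[:, j] = q.to(torch.int8)
            err = (W[:, j] - q * scale) / Hinv[j, j]                 # compensated rounding error
            W[:, j:] -= err.unsqueeze(1) * Hinv[j, j:].unsqueeze(0)  # spread onto later columns
        return Q

The blocked loop in the diff is the same algebra; it just defers the cross-block update (`W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]`) so the rank-1 updates stay cache-local.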
@@ -662,8 +603,8 @@ def gptq_mixed_quantize_int6(
 ):
     """Like mixed_quantize_int6, but uses GPTQ for layers with Hessians.
 
-    hessians: dict mapping module_name (e.g. "blocks.0.attn.c_q") -> H tensor.
-    state_dict keys use param names (e.g. "blocks.0.attn.c_q.weight").
+    hessians: dict mapping param_name (e.g. "blocks.0.attn.c_q.weight") -> H tensor.
+    state_dict keys use the same param names.
     """
     result: dict[str, Tensor] = {}
     meta: dict[str, object] = {}
681622
meta[name] = "passthrough_ctrl"
682623
continue
683624
if cat in int6_cats and t.ndim >= 1:
684-
module_name = name.rsplit(".weight", 1)[0] if name.endswith(".weight") else name
685-
H = hessians.get(module_name)
625+
H = hessians.get(name)
686626
if H is not None and t.ndim == 2:
687627
legacy_q, legacy_s = quantize_int6_per_row_legacy(t, clip_range=clip_range)
688628
naive_q, naive_s = quantize_int6_per_row(t, clip_range=clip_range)
689-
q, s, degraded, gptq_stats = gptq_quantize_layer(
690-
t, H, block_size=block_size, clip_range=clip_range, actorder=actorder,
691-
)
629+
q, s = quantize_int6_gptq(t, hessian=H, clip_range=clip_range, block_size=block_size)
692630
legacy_mse = _quantization_mse(t, legacy_q, legacy_s)
693631
naive_mse = _quantization_mse(t, naive_q, naive_s)
694-
gptq_mse = float(gptq_stats["mse"])
632+
gptq_mse = _quantization_mse(t, q, s)
695633
diagnostics.append({
696-
"name": module_name,
634+
"name": name,
697635
"legacy_rowmax_mse": legacy_mse,
698636
"percentile_naive_mse": naive_mse,
699637
"gptq_mse": gptq_mse,
700638
"gptq_minus_legacy_rowmax_mse": gptq_mse - legacy_mse,
701639
"gptq_minus_percentile_naive_mse": gptq_mse - naive_mse,
702640
"gptq_worse_than_legacy_rowmax": gptq_mse > legacy_mse,
703641
"gptq_worse_than_percentile_naive": gptq_mse > naive_mse,
704-
**gptq_stats,
705642
})
706-
if degraded:
707-
fallback_count += 1
708-
print(f"gptq: DEGRADED layer {module_name}", flush=True)
709-
else:
710-
gptq_count += 1
643+
gptq_count += 1
711644
else:
712645
q, s = quantize_int6_per_row(t, clip_range=clip_range)
713646
naive_count += 1
@@ -1732,8 +1665,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
             log0(
                 "gptq: legacy_worse "
                 f"{diag['name']} delta_mse:{diag['gptq_minus_legacy_rowmax_mse']:.6e} "
-                f"gptq_mse:{diag['gptq_mse']:.6e} legacy_mse:{diag['legacy_rowmax_mse']:.6e} "
-                f"pct:{diag['best_pct']} block:{diag['worst_block_start']}"
+                f"gptq_mse:{diag['gptq_mse']:.6e} legacy_mse:{diag['legacy_rowmax_mse']:.6e}"
             )
         if not worse_naive:
             log0("gptq: no layers worse than percentile naive int6")
@@ -1742,8 +1674,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
             log0(
                 "gptq: percentile_worse "
                 f"{diag['name']} delta_mse:{diag['gptq_minus_percentile_naive_mse']:.6e} "
-                f"gptq_mse:{diag['gptq_mse']:.6e} naive_mse:{diag['percentile_naive_mse']:.6e} "
-                f"pct:{diag['best_pct']} block:{diag['worst_block_start']}"
+                f"gptq_mse:{diag['gptq_mse']:.6e} naive_mse:{diag['percentile_naive_mse']:.6e}"
             )
         del hessians, sd_cpu_r0
 
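The diagnostics in the last three hunks compare GPTQ against both baselines through `_quantization_mse`, which is defined outside this diff; a plausible reading, consistent with the `recon` lines in the quantizer (an assumption, not the committed helper):

    def _quantization_mse(w, q, s):
        # Assumed convention: int8 codes q, per-row fp16 scale s -> mean squared reconstruction error.
        recon = q.float() * s.float()[:, None]
        return float((w.float() - recon).pow(2).mean().item())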