Skip to content

Commit 5f869c3

Browse files
committed
phase-6: add gptq-lite export path
Replace single fixed clip percentile (99.99984) with per-row optimal clip search across 5 percentiles [0.999, 0.9995, 0.9999, 0.99999, 1.0]. Each row uses the percentile giving minimum reconstruction MSE. Deterministic, zero training cost. Used by openai#2 submission (est. -0.003 q_gap).
1 parent 3b02f8f commit 5f869c3

1 file changed

Lines changed: 23 additions & 12 deletions

File tree

train_gpt.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,7 @@ def eval_val_sliding(
369369
# --- int8 export quantization settings ---
# Tensors at or below this element count are presumably kept in float rather
# than quantized (used by keep_float_tensor) — TODO confirm against caller.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
# Storage dtype for tensors that stay in float on export.
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
# Storage dtype for the per-row quantization scales.
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Candidate clip quantiles (fractions in (0, 1]) tried per row during
# quantization; each row keeps the one minimizing reconstruction MSE.
GPTQ_CLIP_PERCENTILES = [0.999, 0.9995, 0.9999, 0.99999, 1.0]
374373

375374
def tensor_nbytes(t: Tensor) -> int:
    """Return the storage size of *t* in bytes (element count x element size)."""
    numel = int(t.numel())
    elem_size = int(t.element_size())
    return numel * elem_size
@@ -386,17 +385,29 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s
386385
def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]:
387386
max_val = 2 ** (bits - 1) - 1
388387
t32 = t.float()
388+
if t32.ndim == 2 and t32.numel() > 0:
389+
best_q = None
390+
best_scale = None
391+
best_mse = torch.full((t32.shape[0],), float("inf"))
392+
for pct in GPTQ_CLIP_PERCENTILES:
393+
clip_abs = torch.quantile(t32.abs(), pct, dim=1)
394+
scale = (clip_abs / max_val).clamp_min(1.0 / max_val)
395+
clipped = torch.clamp(t32, -clip_abs[:, None], clip_abs[:, None])
396+
q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8)
397+
recon = q.float() * scale[:, None]
398+
mse = (t32 - recon).pow(2).mean(dim=1)
399+
improved = mse < best_mse
400+
if improved.any():
401+
if best_q is None:
402+
best_q, best_scale, best_mse = q.clone(), scale.clone(), mse.clone()
403+
else:
404+
best_q[improved] = q[improved]
405+
best_scale[improved] = scale[improved]
406+
best_mse[improved] = mse[improved]
407+
return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
389408
if t32.ndim == 2:
390-
clip_abs = (
391-
torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
392-
if t32.numel()
393-
else torch.empty((t32.shape[0],), dtype=torch.float32)
394-
)
395-
clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
396-
scale = (clip_abs / max_val).clamp_min(1.0 / max_val)
397-
q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous()
398-
return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
399-
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
409+
return torch.zeros_like(t32, dtype=torch.int8), torch.empty((t32.shape[0],), dtype=INT8_PER_ROW_SCALE_DTYPE)
410+
clip_abs = float(t32.abs().max().item()) if t32.numel() else 0.0
400411
scale = torch.tensor(clip_abs / max_val if clip_abs > 0 else 1.0, dtype=torch.float32)
401412
q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous()
402413
return q, scale

0 commit comments

Comments (0)