Non-record: Focal Loss (gamma=2.0) — val_bpb=1.1460 #1233
ibarrajo wants to merge 1 commit into openai:main
Conversation
Replaces standard cross-entropy with focal loss `(1-p)^2 * CE` during training to down-weight easy tokens and focus gradient on hard tokens. Built on Approach B (Int5 GPTQ + 33.6M params). Focal loss at gamma=2.0 hurts BPB by +0.028 vs baseline, suggesting the technique over-suppresses gradients from well-predicted tokens that still carry useful signal.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Community Review — Non-record: Focal Loss (gamma=2.0) — val_bpb=1.1460

Compliance: LOOKS CLEAN — legal score-first-per-chunk TTT (PR #1413 pattern)

Analysis

### PR #1233 — Approach H: Focal Loss + Int5 GPTQ + 33.6M params

Head SHA: 65c71c5

---

### Check 1: ILLEGAL n-gram family bug (target XOR'd into hash key)

CLEAR — Not present. Lines 495–501 define
Per-token NLL rescaled by detached, clipped, mean-1-normalized ratio of own NLL to batch-mean NLL, raised to alpha (warmup-ramped). Bit-identical to PR openai#1413 (1.0810 main frontier) when LOSS_REWEIGHT_ALPHA=0. Patch is 4 surgical edits to PR openai#1413 train_gpt.py: hyperparameters (+4 env vars), GPT.__init__ (+_train_step buffer), GPT.forward (constant-branch on alpha==0 else weighted CE), step_fn (fill _train_step each step). Wrapped LZMA script grew 308 bytes; tightest base seed keeps ~7.7KB headroom under 16MB cap. README acknowledges prior negative results (PR openai#1360 Gaussian reweight, PR openai#1233 focal gamma=2, PR openai#1380 focal investigation) and frames this as replication on a stronger TTT-heavy base where train-time hardness focus could interact with eval-time TTT in ways the older bases can't show.
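Read literally, that reweighting might look something like the sketch below; the function name, the clip value, and the exact order of clip/normalize/power are assumptions, not the PR's actual `train_gpt.py` edits.

```python
import torch.nn.functional as F

def reweighted_ce(logits, targets, alpha: float, clip: float = 4.0):
    # Per-token NLL, no reduction yet.
    nll = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    if alpha == 0.0:
        # Constant branch: plain mean cross-entropy, matching the
        # "bit-identical when LOSS_REWEIGHT_ALPHA=0" claim above.
        return nll.mean()
    # Ratio of each token's NLL to the batch-mean NLL, detached so the
    # weights carry no gradient, clipped, raised to the warmup-ramped
    # alpha, then renormalized to mean 1.
    w = (nll / nll.mean()).detach().clamp(max=clip).pow(alpha)
    w = w / w.mean()
    return (w * nll).mean()
```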
Summary
`(1-p)^gamma * CE` with gamma=2.0 replaces standard cross-entropy during training, down-weighting easy tokens to focus gradient signal on hard tokens. The `(1-p)^2` weighting is layered on top of the Approach B base (Int5 GPTQ + 33.6M params) alongside its other techniques.
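A minimal sketch of that substitution, assuming per-token cross-entropy is computed with `reduction="none"` (the function name and tensor shapes are illustrative, not the repo's actual code):

```python
import torch.nn.functional as F

def focal_ce(logits, targets, gamma: float = 2.0):
    # Per-token cross-entropy: -log p of the target token.
    ce = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    # p = exp(-ce); (1 - p)^gamma shrinks the contribution of tokens the
    # model already predicts confidently, leaving hard tokens dominant.
    p = (-ce).exp()
    return ((1.0 - p).pow(gamma) * ce).mean()
```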
Results

Delta: +0.028 BPB vs baseline — focal loss hurts at gamma=2.0.
Analysis: Why Focal Loss Hurts
Focal loss at gamma=2.0 over-suppresses gradients from well-predicted tokens. In language modeling (unlike object detection, where focal loss originated), even "easy" tokens carry useful distributional signal. The `(1-p)^2` factor reduces their gradient contribution too aggressively, slowing overall learning. A lower gamma (0.5–1.0) or curriculum-style scheduling might work better, but was not explored.
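To make "too aggressively" concrete, here is the focal weight multiplying each token's CE term at gamma=2 (plain arithmetic, not a measurement from this PR):

```latex
(1 - p_t)^2 =
\begin{cases}
0.25   & p_t = 0.50 \\
0.01   & p_t = 0.90 \\
0.0001 & p_t = 0.99
\end{cases}
```

A token predicted at 99% confidence still has nonzero CE, but its gradient is scaled by 1e-4, so whatever distributional signal it carries is effectively discarded.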
Key Changes

- `forward()`: `loss = ((1 - (-ce).exp()).pow(gamma) * ce).mean()` (see the sketch below)
- `FOCAL_GAMMA` env var (default 2.0, set to 0.0 for standard CE)
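A sketch of how the two changes could fit together; the env-var plumbing and function name are assumptions, only the loss expression and the FOCAL_GAMMA default come from this PR:

```python
import os

import torch.nn.functional as F

# Assumed plumbing: read the focal exponent once from the environment.
FOCAL_GAMMA = float(os.environ.get("FOCAL_GAMMA", "2.0"))

def training_loss(logits, targets, gamma: float = FOCAL_GAMMA):
    ce = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    if gamma == 0.0:
        # FOCAL_GAMMA=0.0 recovers standard mean cross-entropy.
        return ce.mean()
    # Focal weighting of the per-token CE, as in the forward() change above.
    return ((1 - (-ce).exp()).pow(gamma) * ce).mean()
```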
Rule Compliance

Test Plan
🤖 Generated with Claude Code