Skip to content

Commit ef5cfda

Browse files
resouerclaude
andcommitted
Apply openai#1445 improved training config
WD=0.095, MATRIX_LR=0.022, EMA=0.9965, RECUR_START=2000, WARMDOWN=0.72 These settings push SP4096 base from ~1.090 to ~1.089 per PR openai#1445. Combined with SLOT (-0.013): target 1.076. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c171fc0 commit ef5cfda

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

train_gpt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Hyperparameters():
3939

4040
# Training length
4141
iterations = int(os.environ.get('ITERATIONS', 20000))
42-
warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667))
42+
warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72))
4343
warmup_steps = int(os.environ.get('WARMUP_STEPS', 20))
4444
train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8))
4545
train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048))
@@ -79,7 +79,7 @@ class Hyperparameters():
7979
head_lr = float(os.environ.get('HEAD_LR', 0.008))
8080
tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03))
8181
tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005))
82-
matrix_lr = float(os.environ.get('MATRIX_LR', 0.02))
82+
matrix_lr = float(os.environ.get('MATRIX_LR', 0.022))
8383
scalar_lr = float(os.environ.get('SCALAR_LR', 0.02))
8484
muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99))
8585
muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5))
@@ -92,13 +92,13 @@ class Hyperparameters():
9292
eval_stride = int(os.environ.get('EVAL_STRIDE', 64))
9393
muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95))
9494
adam_wd = float(os.environ.get('ADAM_WD', 0.02))
95-
muon_wd = float(os.environ.get('MUON_WD', 0.090))
96-
embed_wd = float(os.environ.get('EMBED_WD', 0.090))
97-
ema_decay = float(os.environ.get('EMA_DECAY', 0.997))
95+
muon_wd = float(os.environ.get('MUON_WD', 0.095))
96+
embed_wd = float(os.environ.get('EMBED_WD', 0.095))
97+
ema_decay = float(os.environ.get('EMA_DECAY', 0.9965))
9898

9999
# Depth Recurrence (Modification 2)
100100
recur_layers = os.environ.get("RECUR_LAYERS", "4,5")
101-
recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000))
101+
recur_start_step = int(os.environ.get("RECUR_START_STEP", 2000))
102102

103103
# Parallel Residuals (Modification 5)
104104
parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "7"))

0 commit comments

Comments
 (0)