Skip to content

Commit a8d3edb

Browse files
committed
Bake the claimed SpinQuant and phased-TTT defaults into the PR openai#1695 surface
The public openai#1695 run command relies on several env-var overrides that materially change the surface: SpinQuant on, phased TTT on, matrix LR 0.026, warmdown 0.75, embed_bits 7, embed_clip 20, chunk size 48, and the higher LoRA layer alpha. This branch bakes those settings into defaults so that the reproduction lane tests the claimed surface rather than the inert default one.

Constraint: Must preserve the public PR surface and only move claimed run-command settings into code defaults.
Rejected: Reproduce with env vars only — the current evaluator path does not forward arbitrary env vars to remote jobs.
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Any future public frontier PR must pass a claimed-surface/default-surface comparison before it is treated as a serious candidate family.
Tested: python3 -m py_compile train_gpt.py evaluate.py
Not-tested: GPU execution on Heimdall
1 parent 516864d commit a8d3edb

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

train_gpt.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,15 @@ class Hyperparameters:
1717
seed = int(os.environ.get("SEED", 1337))
1818
run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
1919
iterations = int(os.environ.get("ITERATIONS", 20000))
20-
warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.72))
20+
warmdown_frac = float(os.environ.get("WARMDOWN_FRAC", 0.75))
2121
warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
2222
train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786432))
2323
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
2424
train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
2525
max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 6e2))
2626
val_batch_tokens = int(os.environ.get("VAL_BATCH_TOKENS", 524288))
2727
eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
28-
val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
28+
val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 20000))
2929
sliding_window_enabled = bool(int(os.environ.get("SLIDING_WINDOW_ENABLED", "0")))
3030
vocab_size = int(os.environ.get("VOCAB_SIZE", 8192))
3131
num_layers = int(os.environ.get("NUM_LAYERS", 11))
@@ -47,7 +47,7 @@ class Hyperparameters:
4747
# canonical weights (attn c_q/c_k/c_v/proj, mlp fc/proj) using 4 globally
4848
# shared orthogonal matrices. State dict W <- W @ R, Hessians H <- R^T H R.
4949
# See install_spinquant_rotations / _spinquant_rotate_sd_and_H.
50-
spinquant_enabled = bool(int(os.environ.get("SPINQUANT_ENABLED", "0")))
50+
spinquant_enabled = bool(int(os.environ.get("SPINQUANT_ENABLED", "1")))
5151
spinquant_seed = int(os.environ.get("SPINQUANT_SEED", "20260416"))
5252
ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
5353
qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0))
@@ -62,7 +62,7 @@ class Hyperparameters:
6262
head_lr = float(os.environ.get("HEAD_LR", 0.008))
6363
tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
6464
tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
65-
matrix_lr = float(os.environ.get("MATRIX_LR", 0.022))
65+
matrix_lr = float(os.environ.get("MATRIX_LR", 0.026))
6666
scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
6767
muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.97))
6868
muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
@@ -85,8 +85,8 @@ class Hyperparameters:
8585
ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 96))
8686
ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.0001))
8787
lora_plus_ratio = float(os.environ.get("LORA_PLUS_RATIO", 1.0))
88-
ttt_lora_layer_lr_alpha = float(os.environ.get("TTT_LORA_LAYER_LR_ALPHA", 0.0))
89-
ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 32))
88+
ttt_lora_layer_lr_alpha = float(os.environ.get("TTT_LORA_LAYER_LR_ALPHA", 0.5))
89+
ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 48))
9090
ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 2048))
9191
ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))
9292
ttt_grad_steps = int(os.environ.get("TTT_GRAD_STEPS", 1))
@@ -104,7 +104,7 @@ class Hyperparameters:
104104
# Phased TTT: split prefix docs into N phases. Between phases, run SGD on
105105
# the base model using all scored-prefix tokens. Score-first-then-update
106106
# legality is preserved — only already-scored tokens feed the SGD.
107-
phased_ttt_enabled = bool(int(os.environ.get("PHASED_TTT_ENABLED", "0")))
107+
phased_ttt_enabled = bool(int(os.environ.get("PHASED_TTT_ENABLED", "1")))
108108
phased_ttt_prefix_docs = int(os.environ.get("PHASED_TTT_PREFIX_DOCS", 2000))
109109
phased_ttt_num_phases = int(os.environ.get("PHASED_TTT_NUM_PHASES", 3))
110110
global_ttt_lr = float(os.environ.get("GLOBAL_TTT_LR", 0.001))
@@ -121,9 +121,9 @@ class Hyperparameters:
121121
gptq_calibration_batches = int(os.environ.get("GPTQ_CALIBRATION_BATCHES", 64))
122122
gptq_reserve_seconds = float(os.environ.get("GPTQ_RESERVE_SECONDS", 13.0))
123123
matrix_bits = int(os.environ.get("MATRIX_BITS", 6))
124-
embed_bits = int(os.environ.get("EMBED_BITS", 8))
124+
embed_bits = int(os.environ.get("EMBED_BITS", 7))
125125
matrix_clip_sigmas = float(os.environ.get("MATRIX_CLIP_SIGMAS", 12.85))
126-
embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 15.0))
126+
embed_clip_sigmas = float(os.environ.get("EMBED_CLIP_SIGMAS", 20.0))
127127
mlp_clip_sigmas = float(os.environ.get("MLP_CLIP_SIGMAS", 12.0))
128128
attn_clip_sigmas = float(os.environ.get("ATTN_CLIP_SIGMAS", 13.0))
129129
distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ

0 commit comments

Comments (0)