QAT + Architecture Exploration (Non-Record) #20
mattleonard16 wants to merge 5 commits into openai:main from
Conversation
Add fake_quantize_per_row() using straight-through estimator to simulate int8 per-row quantization during training. Thread training flag through CastedLinear -> CausalSelfAttention -> MLP -> Block -> GPT -> loss. Closes the quantization gap between pre-quant and post-quant val_bpb.
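For reference, a minimal sketch of what such a helper can look like in MLX. The scale/round lines match the snippet Codex quotes further down; the stop_gradient residual is the standard straight-through-estimator trick and an assumption about how the backward pass is wired here:

```python
import mlx.core as mx

def fake_quantize_per_row(w: mx.array) -> mx.array:
    """Simulate symmetric int8 per-row quantization; identity in the backward pass."""
    w32 = w.astype(mx.float32)
    row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
    scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))  # guard against zero rows
    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
    # STE: forward sees the quantized weights, gradients flow as if untouched.
    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
```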
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c879d0d668
```diff
  # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
  # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
- x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+ x = self(input_ids, training=True).reshape(-1, self.tok_emb.weight.shape[1])
```
Split the fake-quant training loss from evaluation
Because both eval_val() and the final_int8_zlib_roundtrip check call compiled_loss, hard-wiring training=True here makes every validation pass run through fake_quantize_per_row() too. That means the logged val_loss/val_bpb no longer represent the full-precision model during training, and after reloading the dequantized int8 checkpoint, the round-trip check fake-quantizes those weights a second time. If the intent is "train with QAT, evaluate without it", this needs a separate eval path.
```diff
- def __call__(self, x: mx.array) -> mx.array:
-     return x @ self.weight.astype(x.dtype).T
+ def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+     w = fake_quantize_per_row(self.weight) if training else self.weight
```
Fake-quantize the tied embedding matrix too
The new QAT path only wraps CastedLinear.weight, but the tied tok_emb.weight is still used in full precision for both the input embedding lookup and the LM-head projection. quantize_state_dict_int8() will nevertheless export that tensor as int8 whenever it exceeds INT8_KEEP_FLOAT_MAX_NUMEL (the default 1024×512 table already does), so a large chunk of the deployed quantization error is never present in the training objective. In practice the post-export gap can stay dominated by the embedding table even though the linear layers were trained with QAT.
GPT.loss() was hard-wiring training=True, causing eval_val() and the int8 roundtrip check to run through fake_quantize_per_row(). Add training param to loss() and split compiled_loss (eval, no QAT) from compiled_loss_and_grad (train, QAT). Also apply fake quantization to tok_emb.weight in the LM-head matmul since it exceeds INT8_KEEP_FLOAT_MAX_NUMEL and gets int8-quantized on export.
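A sketch of that split, assuming GPT.loss() now accepts the training flag described above; the mx.compile wrapping that yields compiled_loss / compiled_loss_and_grad is left out for brevity:

```python
import mlx.nn as nn

def make_loss_fns(model):
    # Eval path: full-precision weights for eval_val() and the int8 round-trip check.
    def eval_loss(model, input_ids, target_ids):
        return model.loss(input_ids, target_ids, training=False)

    # Train path: weights pass through fake_quantize_per_row() via the STE.
    def train_loss(model, input_ids, target_ids):
        return model.loss(input_ids, target_ids, training=True)

    return eval_loss, nn.value_and_grad(model, train_loss)
```

nn.value_and_grad differentiates with respect to the model's trainable parameters, so QAT noise only ever enters the gradient path, never the logged validation metrics.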
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 795412d609
```diff
- x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+ x = self(input_ids, training=training).reshape(-1, self.tok_emb.weight.shape[1])
  y = target_ids.reshape(-1)
+ lm_weight = fake_quantize_per_row(self.tok_emb.weight) if training else self.tok_emb.weight
```
Fake-quantize the input embedding path too
This change now fake-quantizes tok_emb.weight only for the LM-head reuse here, but GPT.__call__() still starts from self.tok_emb(input_ids) in full precision. Because quantize_state_dict_int8() exports that same tied tensor as int8 for the default 1024×512 table, the deployed model sees quantization error on every input lookup as well as on the logits projection. Training therefore still omits a large, always-on source of post-export distortion, so the measured QAT gain can remain overly optimistic even after this partial fix.
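One way to close that hole, sketched under the assumption that GPT.__call__() can replace its embedding lookup with a manual gather (this is the "manual index lookup" the next commit message describes):

```python
import mlx.core as mx

def embed_tokens(self, input_ids: mx.array, training: bool = False) -> mx.array:
    # Gather from the (optionally) STE-quantized table so the input lookup
    # sees the same int8 noise the export bakes into tok_emb.weight.
    w = fake_quantize_per_row(self.tok_emb.weight) if training else self.tok_emb.weight
    return w[input_ids]
```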
```diff
- def __call__(self, x: mx.array) -> mx.array:
-     return x @ self.weight.astype(x.dtype).T
+ def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+     w = fake_quantize_per_row(self.weight) if training else self.weight
```
Skip fake-quantizing matrices that stay float at export
For smaller architecture sweeps, this branch fake-quantizes every CastedLinear weight unconditionally, but quantize_state_dict_int8() only converts tensors larger than INT8_KEEP_FLOAT_MAX_NUMEL and writes smaller matrices to passthrough. In configurations like MODEL_DIM<=256, several attention/MLP weights fall under that cutoff, so QAT trains against noise that will never exist in the serialized model. That can hurt convergence and make QAT-vs-baseline comparisons misleading for the exact exploration runs this commit introduces.
CastedLinear was being fake-quantized unconditionally during training, but quantize_state_dict_int8() keeps tensors with numel <= 65,536 in float. For MODEL_DIM<=256, several attention weights fall under that cutoff, so QAT was training against noise that disappears at export. Also fake-quantize the embedding lookup path (tok_emb.weight gets int8-quantized at export) via a manual index lookup on the STE-quantized weight, and add the same numel guard to the LM-head projection in loss().
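A minimal guard capturing the numel rule from this commit (the helper name is hypothetical; the 65,536 cutoff is the value cited above):

```python
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # exporter keeps tensors at or below this in float

def maybe_fake_quantize(w, training: bool):
    # Only inject quantization noise on tensors that quantize_state_dict_int8()
    # will actually serialize as int8; smaller matrices stay float at export.
    if training and w.size > INT8_KEEP_FLOAT_MAX_NUMEL:
        return fake_quantize_per_row(w)
    return w
```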
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bda498113d
```python
row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))
w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
```
Use exporter clipping in fake_quantize_per_row
quantize_float_array() clips each 2D row at INT8_CLIP_Q before computing its int8 scale, but this new training path derives scale from the absolute row max and never clips. When a row contains even one large outlier (a realistic case for embeddings and projection weights), QAT trains against a much coarser quantizer than quantize_state_dict_int8() will actually serialize, so the measured pre/post-quantization gap no longer reflects the deployed model. If this experiment is intended to validate deployment-aware training, the fake-quant path needs to mirror the exporter’s clipping rule.
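What mirroring the exporter could look like, assuming INT8_CLIP_Q is a per-row quantile of the absolute values; the review does not spell out its semantics, so both that interpretation and the 0.9995 value below are guesses:

```python
import mlx.core as mx

INT8_CLIP_Q = 0.9995  # assumed per-row quantile; the exporter's real rule may differ

def fake_quantize_per_row_clipped(w: mx.array) -> mx.array:
    w32 = w.astype(mx.float32)
    n = w32.shape[1]
    k = min(n - 1, int(INT8_CLIP_Q * n))
    # Per-row clip threshold: the INT8_CLIP_Q quantile of |row|, found by sorting.
    clip = mx.sort(mx.abs(w32), axis=1)[:, k : k + 1]
    scale = mx.maximum(clip / 127.0, mx.array(1.0 / 127.0))
    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
    # Same straight-through estimator as the unclipped helper.
    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
```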
Current Finding
I implemented QAT with STE in the MLX training stack to reduce degradation from post-training quantization. Early baseline runs suggest the pre-quant vs post-quant val_bpb gap widens over longer training, reaching 0.0325 BPB in a 4-hour baseline, which makes deployment-aware training increasingly important for competitive runs.
Implementation
- fake_quantize_per_row() to simulate int8 per-row quantization during training
- training flag threaded through CastedLinear -> CausalSelfAttention -> MLP -> Block -> GPT -> loss

Hypothesis
Training with fake quantization in the loop should reduce the final quantization gap and improve the deployed model's post-quantized validation BPB under the 16MB artifact constraint.
Why Compute Is Needed
Meaningful validation requires full 10-minute 8xH100 runs and multi-seed significance testing. Local Apple Silicon is sufficient for correctness checks and short smoke tests, but not for the training scale needed to validate small BPB gains.
Planned Extensions
Test Plan
- Local import check (from train_gpt_mlx import GPT)