Each step follows test-first: write pytest tests → implement until tests pass.
Prerequisites before starting:
- A saved MLX checkpoint must exist for the diagnostic probes (S8-S10). If none exists, run `python train_gpt_mlx.py` with `MAX_WALLCLOCK_SECONDS=300 ITERATIONS=5000` to produce one (~5 min).
- Training data shards at `data/datasets/fineweb10B_sp1024/` (see `data/README.md` for download).
- Phase A (Foundation): S1 → S2
- Phase B (DAG Discovery): S3 → S4 → S5
- Phase C (Experiment Pipeline, parallel with B after S2): S6 → S7
- Phase D (Diagnostic Probes, parallel with B and C after S2): S8, S9, S10, S11
- Phase E (Integration): S12
S1 | Complexity: Simple | Deps: None
- `uv add causal-learn statsmodels networkx graphviz scipy pytest`
- `brew install graphviz` if `dot -V` fails
- Create `scripts/causal/`, `results/causal/`, `results/causal/diagnostics/`, `tests/causal/`
- Add `results/causal/` to `.gitignore`
- Create `scripts/causal/__init__.py`
- Run `uv pip check` to verify no dependency conflicts (causal-learn pulls in torch/sklearn)
Verify: All imports succeed, dot -V exits 0, no dependency conflicts.
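The S1 verification can be scripted. A minimal sketch follows; the import names in `REQUIRED_MODULES` are assumptions (causal-learn installs as the `causallearn` package), and `uv pip check` is left as a separate shell step.

```python
"""Sketch of the S1 'Verify' step: imports succeed and `dot -V` exits 0."""
import importlib
import shutil
import subprocess

# Assumed import names for the packages added with `uv add`.
REQUIRED_MODULES = ["causallearn", "statsmodels", "networkx", "graphviz", "scipy", "pytest"]


def missing_modules(names):
    """Return the subset of module names that fail to import."""
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    return missing


def dot_available():
    """True if the Graphviz `dot` binary is on PATH and `dot -V` exits 0."""
    if shutil.which("dot") is None:
        return False
    result = subprocess.run(["dot", "-V"], capture_output=True)
    return result.returncode == 0


if __name__ == "__main__":
    print("missing modules:", missing_modules(REQUIRED_MODULES) or "none")
    print("dot -V ok:", dot_available())
```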
S2 | Complexity: Medium | Deps: S1
Prerequisite: `python -c "import train_gpt_mlx; print(train_gpt_mlx.GPT)"` must (1) exit 0, (2) complete in <5s, and (3) print the class reference (no FileNotFoundError/ImportError). If it fails, extract the GPT/Block/CausalSelfAttention/MLP/Rotary classes into scripts/causal/model.py (Medium-complexity fallback). Note: train_gpt_mlx.py defines its classes at module level (lines 1-1100) and has an `if __name__ == "__main__":` guard at line 1103, so the import should be safe.
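The three prerequisite conditions can be checked in a fresh interpreter. This is a sketch only; `check_import` is a hypothetical helper name, and it should be run from the repository root so that `train_gpt_mlx` is importable.

```python
import subprocess
import sys
import time


def check_import(module="train_gpt_mlx", attr="GPT", limit_s=5.0):
    """Run `import module; print(module.attr)` in a new process; return (ok, detail)."""
    cmd = [sys.executable, "-c", f"import {module}; print({module}.{attr})"]
    start = time.monotonic()
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=limit_s)
    except subprocess.TimeoutExpired:
        return False, f"import took longer than {limit_s}s"
    elapsed = time.monotonic() - start
    if proc.returncode != 0:
        # Surfaces FileNotFoundError / ImportError / AttributeError raised at import time.
        lines = proc.stderr.strip().splitlines()
        return False, lines[-1] if lines else f"exit code {proc.returncode}"
    if "<class" not in proc.stdout:
        return False, f"unexpected output: {proc.stdout.strip()!r}"
    return True, f"{proc.stdout.strip()} in {elapsed:.2f}s"


if __name__ == "__main__":
    print(check_import())
```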
Tests (tests/causal/test_common.py): One test per function — load_submission_json, load_model, compute_bpb, paired_ttest, holm_bonferroni, decision_gate, log_experiment, get_cycle_dir, dag_diff.
Implement (scripts/causal/common.py):
- `load_submission_json(path)`: parse submission.json
- `load_model(checkpoint_path, config_overrides)`: import GPT from train_gpt_mlx (safe thanks to the `__name__` guard, TD-7), load weights, load tokenizer
- `compute_bpb(model, val_tokens, sp_model)`: reuse the eval_val logic
- `paired_ttest(treatment, control)`: scipy.stats.ttest_rel + bootstrap CI
- `holm_bonferroni(p_values, alpha)`: wrap statsmodels multipletests
- `decision_gate(effect_size, p_value, mde=0.002)`: returns "confirmed" / "suggestive" / "null"
- `log_experiment(path, entry)`: append to experiment_log.json
- `get_cycle_dir(base_path, cycle)`: create `results/causal/cycle_N/`
- `dag_diff(old_dag, new_dag)`: edges added/removed/strengthened
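A sketch of the three statistical helpers, under stated assumptions. `improvement` is defined as control minus treatment (positive means lower BPB). Holm is written out by hand so it can double as a test oracle; the plan's `common.py` would instead wrap `statsmodels.stats.multitest.multipletests(..., method="holm")`, which should give the same adjusted p-values. The `alpha` parameter and the "suggestive" rule in `decision_gate` are placeholders, not thresholds from this plan.

```python
import numpy as np
from scipy import stats


def paired_ttest(treatment, control, n_boot=10_000, seed=0):
    """Paired t-test plus percentile-bootstrap 95% CI on per-seed BPB pairs.

    improvement = control - treatment; the two-sided p-value does not depend on sign.
    """
    t = np.asarray(treatment, dtype=float)
    c = np.asarray(control, dtype=float)
    diffs = c - t
    p_value = float(stats.ttest_rel(t, c).pvalue)
    # Resample whole seed pairs so the pairing is preserved.
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(diffs), size=(n_boot, len(diffs)))
    lo, hi = np.percentile(diffs[idx].mean(axis=1), [2.5, 97.5])
    return {"improvement": float(diffs.mean()), "p_value": p_value,
            "ci95": (float(lo), float(hi))}


def holm_bonferroni(p_values, alpha=0.05):
    """Holm step-down adjustment; returns (reject mask, adjusted p-values)."""
    p = np.asarray(p_values, dtype=float)
    m = len(p)
    adjusted = np.empty(m)
    running_max = 0.0
    for rank, i in enumerate(np.argsort(p)):
        # Multiply the k-th smallest p by (m - k), then enforce monotonicity.
        running_max = max(running_max, (m - rank) * p[i])
        adjusted[i] = min(1.0, running_max)
    return adjusted <= alpha, adjusted


def decision_gate(effect_size, p_value, mde=0.002, alpha=0.05):
    """Placeholder classification; effect_size is the BPB improvement."""
    if effect_size >= mde and p_value < alpha:
        return "confirmed"
    if effect_size > 0 and p_value < 2 * alpha:  # hypothetical "suggestive" band
        return "suggestive"
    return "null"
```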
S3 | Complexity: Medium | Deps: S2
Tests (tests/causal/test_extract.py): Each parser format with test fixtures drawn from actual record READMEs (not synthetic). Field coverage computation. --append-experiment with mock data. Unknown format fallback test.
Implement:
- submission.json pass: 6 core fields from all discovered records
- Format A parser: `| Change | ... | Impact |` tables (~2 records)
- Format B parser: `| | Base | This |` tables (~8 records), compute delta
- Format C fallback: headings, bullets, blurb (~10 records)
- Cross-reference pass: compute total delta from base_bpb
- `--append-experiment` mode: convert experiment results to submission format
- Field coverage metric
Verify: field_coverage ≥ 0.90 on actual records. All discovered submissions parsed.
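A sketch of the Format B parser and the coverage metric. It assumes a Format B table has an empty first header cell followed by `Base` and `This`, that the first number in each cell is the value to difference, and that coverage is the fraction of non-null (record, field) slots; all of these are illustrative assumptions to check against the real record READMEs.

```python
import re

# Header of a Format B table: empty first cell, then "Base" and "This".
HEADER_RE = re.compile(r"^\|\s*\|\s*Base\s*\|\s*This\s*\|$", re.IGNORECASE)
NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+")


def _cells(row):
    """Split one markdown table row into stripped cell strings."""
    return [cell.strip() for cell in row.strip().strip("|").split("|")]


def parse_format_b(readme_text):
    """Return {row_name: {"base", "this", "delta"}} for the first Format B table."""
    lines = readme_text.splitlines()
    for start, line in enumerate(lines):
        if HEADER_RE.match(line.strip()):
            break
    else:
        return {}  # no Format B table; the caller falls through to Format C
    rows = {}
    for line in lines[start + 1:]:
        if not line.strip().startswith("|"):
            break  # first non-table line ends the table
        cells = _cells(line)
        if len(cells) != 3 or all(set(cell) <= set("-: ") for cell in cells):
            continue  # separator row or malformed row
        name, base, this = cells
        b, t = NUMBER_RE.search(base), NUMBER_RE.search(this)
        delta = float(t.group()) - float(b.group()) if b and t else None
        rows[name] = {"base": base, "this": this, "delta": delta}
    return rows


def field_coverage(records, fields):
    """Fraction of (record, field) slots that hold a non-null value."""
    if not records:
        return 0.0
    filled = sum(1 for r in records for f in fields if r.get(f) is not None)
    return filled / (len(records) * len(fields))
```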
S4 | Complexity: Complex | Deps: S3
Tests (tests/causal/test_estimate_dag.py): Expert skeleton edges. Binary encoding dimensions. FCI on synthetic data. Near-degenerate case (n=20, correlated binary cols). LinAlgError caught → expert fallback. Edge tagging. --previous-dag diff. DOT renders.
Implement:
- Expert-guided skeleton: hardcode known causal relationships, tagged as `expert_imposed`
- Binary encoding of interventions (presence/absence matrix)
- FCI validation (causal-learn, Fisher-Z, alpha=0.01) wrapped in try/except for numerical errors (e.g. LinAlgError → fall back to the expert DAG)
- Degenerate detection (empty or fully connected FCI graph → keep expert edges only)
- Edge tagging (`expert_imposed` → `data_confirmed` / `data_contradicted` / `uncertain`)
- Marginal effect estimation per node
- `next_intervention` recommendation (highest expected BPB improvement among uncertain edges)
- `--previous-dag` mode (`dag_diff`, edge stability tracking)
- DOT visualization via graphviz
- Write `scripts/causal/README.md`: cycle protocol and CLI usage
Verify: ≥5 nodes. dag.png renders. next_intervention non-null. Edge tags valid.
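The degenerate check, edge tagging, FCI fallback, and `next_intervention` pick can be sketched without causal-learn itself. In this sketch the FCI call (roughly `fci(data, fisherz, 0.01)` in causal-learn) is injected as `run_fci`, and converting its graph into the `adjacent_pairs` / `oriented_edges` sets is not shown. The tagging rule (oriented as expected → confirmed, reversed or absent → contradicted, adjacent but unoriented → uncertain) and the gain-based pick are assumptions, not the plan's fixed definitions.

```python
import numpy as np

TAGS = {"expert_imposed", "data_confirmed", "data_contradicted", "uncertain"}


def is_degenerate(adjacent_pairs, n_nodes):
    """FCI output is degenerate if it has no edges or every pair is adjacent."""
    max_pairs = n_nodes * (n_nodes - 1) // 2
    return len(adjacent_pairs) == 0 or len(adjacent_pairs) == max_pairs


def tag_edges(expert_edges, adjacent_pairs, oriented_edges, n_nodes):
    """Tag each expert edge (a, b) against FCI output (None = FCI failed).

    adjacent_pairs: set of frozenset({x, y}) adjacencies found by FCI.
    oriented_edges: set of (x, y) pairs that FCI oriented as x -> y.
    """
    if adjacent_pairs is None or is_degenerate(adjacent_pairs, n_nodes):
        return {edge: "expert_imposed" for edge in expert_edges}
    tags = {}
    for a, b in expert_edges:
        if (a, b) in oriented_edges:
            tags[(a, b)] = "data_confirmed"
        elif (b, a) in oriented_edges or frozenset((a, b)) not in adjacent_pairs:
            tags[(a, b)] = "data_contradicted"
        else:
            tags[(a, b)] = "uncertain"  # adjacent, orientation ambiguous
    return tags


def run_fci_safely(run_fci, data):
    """Call an FCI routine; return None on numerical failure (expert fallback)."""
    try:
        return run_fci(data)
    except (np.linalg.LinAlgError, ValueError):
        return None


def next_intervention(tags, expected_gain):
    """Uncertain edge whose cause node has the largest expected BPB gain."""
    uncertain = [edge for edge, tag in tags.items() if tag == "uncertain"]
    if not uncertain:
        return None
    return max(uncertain, key=lambda edge: expected_gain.get(edge[0], 0.0))
```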
S5 | Complexity: Simple | Deps: S3, S4
Tests (tests/causal/test_identifiability.py): Synthetic interventions with known counts. Proceed/skip threshold. Combination enumeration.
Implement:
- Count single-variable and multi-variable records
- Identifiability score (fraction of testable edges)
- Confounded pairs (always co-occurring interventions)
- Proceed/skip recommendation (>50% multi-variable → skip)
- Unexplored combinations with expected effects × interaction priors
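A sketch of the counting parts of S5, assuming each record is reduced to the set of intervention names it contains. "Confounded" is taken to mean two interventions with identical occurrence sets, and "unexplored" to mean pairs never seen together. The identifiability score and the expected-effect × interaction-prior ranking are not shown, since "testable edge" and the priors are defined elsewhere.

```python
from itertools import combinations


def identifiability_summary(records, skip_threshold=0.5):
    """records: list of sets, each holding the interventions in one submission."""
    single = sum(1 for r in records if len(r) == 1)
    multi = sum(1 for r in records if len(r) > 1)

    # Which records does each intervention appear in?
    occurrences = {}
    for i, rec in enumerate(records):
        for name in rec:
            occurrences.setdefault(name, set()).add(i)
    names = sorted(occurrences)

    # Always co-occurring -> their effects cannot be separated.
    confounded = [(a, b) for a, b in combinations(names, 2)
                  if occurrences[a] == occurrences[b]]

    seen_together = {frozenset(p) for r in records for p in combinations(sorted(r), 2)}
    unexplored = [p for p in combinations(names, 2) if frozenset(p) not in seen_together]

    frac_multi = multi / len(records) if records else 0.0
    return {
        "single_variable": single,
        "multi_variable": multi,
        "confounded_pairs": confounded,
        "unexplored_pairs": unexplored,
        "recommendation": "skip" if frac_multi > skip_threshold else "proceed",
    }
```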
S6 | Complexity: Complex | Deps: S2
Tests (tests/causal/test_experiment_runner.py): Config validation. Stdout BPB parsing with mock output. Partial failure (1/3 seeds). Timeout handling.
Implement:
- Config loading + validation (env_overrides schema)
- Subprocess invocation with SEED env var, timeout = wallclock + 120s
- Parse the LAST occurrence of `val_bpb:<float>` from the complete stdout (not just the final line); capture stderr
- Fallback: parse from the training log file at `logs/{run_id}/train.log` if stdout parsing fails
- Error handling: crash → partial result with an error field; 1/3 seeds fail → reduced_power flag; 2+/3 fail → condition failed
- Per-run JSON-lines metrics capture
- Checkpoint/log path capture
- 3 seeds × 2 conditions → raw_runs.json
- Append to experiment_log.json
Verify: Dry-run (ITERATIONS=10, MAX_WALLCLOCK_SECONDS=30). Valid raw_runs.json schema.
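A sketch of one seed's run and the per-condition failure rule. The result dictionary layout is a hypothetical schema, `cmd` is whatever launches the training script, and the log-file fallback is omitted. Taking the last regex match over the full stdout follows the plan's "LAST occurrence" rule.

```python
import os
import re
import subprocess
import sys

VAL_BPB_RE = re.compile(r"val_bpb:\s*([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_last_val_bpb(text):
    """Return the LAST val_bpb:<float> anywhere in the output, or None."""
    matches = VAL_BPB_RE.findall(text)
    return float(matches[-1]) if matches else None


def run_seed(cmd, env_overrides, seed, wallclock_s):
    """Run one seed with SEED set; timeout = wallclock + 120 s."""
    env = {**os.environ, **env_overrides, "SEED": str(seed)}
    try:
        proc = subprocess.run(cmd, env=env, capture_output=True, text=True,
                              timeout=wallclock_s + 120)
    except subprocess.TimeoutExpired:
        return {"seed": seed, "val_bpb": None, "error": "timeout"}
    bpb = parse_last_val_bpb(proc.stdout)
    result = {"seed": seed, "val_bpb": bpb, "stderr_tail": proc.stderr[-2000:]}
    if proc.returncode != 0 or bpb is None:
        result["error"] = f"exit code {proc.returncode}, val_bpb found: {bpb is not None}"
    return result


def condition_status(runs, n_seeds=3):
    """1 failed seed -> reduced_power; 2+ failed seeds -> failed."""
    failures = sum(1 for r in runs if "error" in r) + max(0, n_seeds - len(runs))
    if failures == 0:
        return "ok"
    return "reduced_power" if failures == 1 else "failed"


if __name__ == "__main__":
    print(run_seed([sys.executable, "-c", "print('val_bpb:1.25')"], {}, seed=0, wallclock_s=5))
```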
S7 | Complexity: Medium | Deps: S2, S6 output
Tests (tests/causal/test_statistical.py): Synthetic data with known effect. CI contains true effect. Holm-Bonferroni adjusts upward. Decision gate classification.
Implement:
- Load raw_runs.json, extract per-seed BPB pairs (handle partial failures)
- Paired differences, mean effect, bootstrapped 95% CI
- Paired t-test p-value
- Holm-Bonferroni correction
- Decision gate classification
- Platform transfer coefficient if MLX + H100 data
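Handling partial failures amounts to keeping only seeds with a valid BPB in both conditions. This sketch assumes a hypothetical raw_runs.json layout of `{"treatment": [...], "control": [...]}`, each entry carrying `seed`, `val_bpb`, and an optional `error` field.

```python
import json


def paired_bpb(raw_runs):
    """Return (seeds, treatment_bpb, control_bpb) for seeds valid in BOTH conditions."""
    def by_seed(runs):
        # Drop crashed or unparsed runs (error field or missing BPB).
        return {r["seed"]: r["val_bpb"] for r in runs
                if r.get("val_bpb") is not None and "error" not in r}

    treatment = by_seed(raw_runs["treatment"])
    control = by_seed(raw_runs["control"])
    seeds = sorted(treatment.keys() & control.keys())
    return seeds, [treatment[s] for s in seeds], [control[s] for s in seeds]


def load_pairs(path):
    """Load raw_runs.json and extract the paired per-seed BPB values."""
    with open(path) as f:
        return paired_bpb(json.load(f))
```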
S8 | Complexity: Medium | Deps: S2 | Prereq: Saved checkpoint exists
Tests (tests/causal/test_token_loss.py): Decomposition check with mock output. Frequency bucketing. BPB contribution summation.
Implement:
- Load model + validation data via common.py
- Forward pass with reduction='none' for per-token losses
- Decomposition verification: mean(per_token) matches aggregate within 1e-6
- Frequency buckets (top-100, 100-500, 500-1024)
- Boundary vs. mid-sequence classification
- Per-category statistics
Verify: decomposition_check.passed. Buckets sum to total. BPB contributions sum to aggregate (±0.001).
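A sketch of the bucketing and the two sum checks, using NumPy arrays as a stand-in for the MLX per-token losses. It assumes aggregate BPB has the bits-per-byte form (mean nats per token / ln 2 × tokens per byte, i.e. total nats / ln 2 / total bytes), which is what makes per-bucket contributions sum exactly to the aggregate. `bpb_breakdown`, `token_counts`, and `total_bytes` are illustrative names. Boundary vs. mid-sequence classification is not shown.

```python
import math

import numpy as np


def bpb_breakdown(per_token_loss, token_ids, token_counts, total_bytes, aggregate_loss,
                  edges=(100, 500, 1024)):
    """Split BPB into frequency-rank buckets (top-100, 100-500, 500-1024).

    per_token_loss: nats per target token (from reduction='none').
    token_counts: corpus frequency per vocab id, used to rank tokens (rank 0 = most frequent).
    total_bytes: bytes covered by the validation tokens.
    """
    loss = np.asarray(per_token_loss, dtype=np.float64)
    ids = np.asarray(token_ids)
    # Decomposition check: the per-token mean must reproduce the aggregate loss.
    passed = abs(loss.mean() - aggregate_loss) < 1e-6

    counts = np.asarray(token_counts)
    rank = np.empty(len(counts), dtype=np.int64)
    rank[np.argsort(-counts, kind="stable")] = np.arange(len(counts))
    token_rank = rank[ids]

    buckets, lo = {}, 0
    for hi in edges:
        mask = (token_rank >= lo) & (token_rank < hi)
        buckets[f"{lo}-{hi}"] = {
            "n_tokens": int(mask.sum()),
            # Normalise by ALL bytes so bucket contributions add up to the total.
            "bpb": float(loss[mask].sum() / math.log(2) / total_bytes),
        }
        lo = hi
    total_bpb = float(loss.sum() / math.log(2) / total_bytes)
    return {"decomposition_check": {"passed": bool(passed)},
            "buckets": buckets, "total_bpb": total_bpb}
```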
S9 | Complexity: Medium | Deps: S2 | Prereq: Saved checkpoint exists
Tests (tests/causal/test_quant_gap.py): Gap computation with mock BPB. Threshold check logic.
Implement:
- Load model, eval pre-quant BPB
- Quantize → dequantize → eval post-quant BPB (reuse train_gpt_mlx.py functions)
- Gap and threshold check (gap > 3× largest training effect)
- Optional: per-token category comparison pre/post quant
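The gap and threshold arithmetic is small but easy to get backwards. In this sketch the gap is post-quant minus pre-quant BPB, and "largest training effect" is assumed to mean the largest absolute BPB effect measured so far; `quant_gap_check` is a hypothetical name.

```python
def quant_gap_check(pre_bpb, post_bpb, training_effects, multiple=3.0):
    """Flag when the quantization gap exceeds `multiple` x the largest training effect."""
    gap = post_bpb - pre_bpb  # positive = quantization hurts
    largest = max((abs(e) for e in training_effects), default=0.0)
    threshold = multiple * largest
    return {"gap": gap, "largest_training_effect": largest,
            "threshold": threshold, "gap_dominates": gap > threshold}
```

For example, a 0.010 gap against a largest training effect of 0.003 gives a threshold of 0.009, so the gap dominates.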
S10 | Complexity: Medium | Deps: S2 | Prereq: Saved checkpoint + training shards
Tests (tests/causal/test_influence.py): Dot product with small mock gradients. CV calculation. Skip threshold.
Implement:
- Load model via common.py
- Validation gradient via plain nn.value_and_grad (trainable params only)
- Memory check after first shard — warn and reduce --max-shards if needed
- Iterate shards (4096 tokens/shard): per-shard gradient, then its dot product with the validation gradient
- `mx.eval()` after each shard (hard requirement)
- Sort scores, compute CV, recommend skipping if CV < 0.1
Verify: --max-shards 5, scores sorted, CV non-negative.
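The scoring and CV logic can be sketched with NumPy arrays standing in for MLX gradient trees; in the real probe each shard gradient would come from `nn.value_and_grad` and be materialised with `mx.eval()` before moving on. Assumptions: gradients arrive as `{name: array}` dicts, scores are sorted descending, and CV is the population standard deviation over the absolute mean (which keeps it non-negative).

```python
import numpy as np


def flatten(grads):
    """Concatenate a {name: array} gradient dict into one vector (keys sorted)."""
    return np.concatenate([np.ravel(grads[k]) for k in sorted(grads)])


def influence_scores(val_grad, shard_grads, skip_cv=0.1):
    """Dot each shard gradient with the validation gradient; report CV and skip advice."""
    v = flatten(val_grad)
    scores = np.array([float(np.dot(v, flatten(g))) for g in shard_grads])
    order = np.argsort(-scores, kind="stable")  # highest influence first
    mean = scores.mean()
    cv = float(scores.std() / abs(mean)) if mean != 0 else float("inf")
    return {
        "scores": [(int(i), float(scores[i])) for i in order],
        "cv": cv,
        "recommend_skip": cv < skip_cv,
    }
```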
S11 | Complexity: Complex | Deps: S1, S2
Tests (tests/causal/test_gradient_attr.py): LAST occurrence targeting. Sentinel validation on current source. Sentinel fails on modified source. Patched file syntax check: ast.parse() on the instrumented output to verify it's valid Python. JSON-lines parsing.
Implement:
- Verify sentinel strings exist in current train_gpt_mlx.py
- Find the LAST occurrence of `accumulate_flat_grads`; validate the dual sentinel (`train_loss` and `lr_mul` within ±5 lines)
- Insert gradient-norm logging and write `train_gpt_mlx_instrumented.py`
- Execute via subprocess (short training for test)
- Parse JSON-lines, compute phase boundaries from lr_mul transitions
- Per-phase correlations between layer norms and val_loss
S12 | Complexity: Medium | Deps: Confirmed effect from cycle OR engineering fallback
Tests (tests/causal/test_submission.py): submission.json schema. Artifact size ≤ 16MB. README.md required sections. All files present.
Implement:
- Map causal findings to train_gpt.py code changes
- Verify artifact ≤ 16MB, training ≤ 10min on 8×H100
- 3-seed validation on H100
- README.md with ablation table
- submission.json with metadata
- Engineering fallback (R5.3) if no causal findings
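The mechanical parts of the S12 tests can be checked with a small script. Everything named here is an assumption to replace with the challenge's actual rules: the required file list, the README section headings, and the 16 MB limit read as decimal bytes (16,000,000) summed over the listed artifact files.

```python
import json
from pathlib import Path

ARTIFACT_LIMIT_BYTES = 16_000_000  # assumption: decimal 16 MB
REQUIRED_FILES = ("README.md", "submission.json", "train_gpt.py")  # hypothetical list
REQUIRED_SECTIONS = ("## Results", "## Ablations")  # hypothetical headings


def check_submission(folder, artifact_files):
    """Return a list of problems; an empty list means the checks passed."""
    folder = Path(folder)
    problems = [f"missing {name}" for name in REQUIRED_FILES
                if not (folder / name).is_file()]

    size = sum((folder / f).stat().st_size for f in artifact_files if (folder / f).is_file())
    if size > ARTIFACT_LIMIT_BYTES:
        problems.append(f"artifact {size} bytes > {ARTIFACT_LIMIT_BYTES}")

    sub = folder / "submission.json"
    if sub.is_file():
        try:
            json.loads(sub.read_text())
        except json.JSONDecodeError:
            problems.append("submission.json is not valid JSON")

    readme = folder / "README.md"
    if readme.is_file():
        text = readme.read_text()
        problems += [f"README missing {s}" for s in REQUIRED_SECTIONS if s not in text]
    return problems
```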
S1 → S2 → S3 → S4 → S5
     │
     ├→ S6 → S7
     │       └→ S7 feeds back to S3 → S4 (cycle)
     │
     ├→ S8, S9, S10 (parallel diagnostics)
     └→ S11
S12 ← confirmed effect OR time gate
The cycle is manual and researcher-driven; the protocol is documented in `scripts/causal/README.md`.
- Cycle 0: S3 → S4 → S5, with S8-S11 in parallel → read next_intervention → create configs
- Cycle 1+: S6 → S7 → decision gate → S3 --append-experiment → S4 --previous-dag → repeat
- Stop at the first of: a confirmed effect, 3 consecutive null results, 4 cycles, or the 2-day time gate
- Final: S12 with the best findings
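The stop rule can be encoded so it is applied the same way every cycle. This sketch assumes one decision-gate outcome per completed Cycle 1+ (Cycle 0 produces none and does not count toward the 4-cycle cap), and reads "3 null streak" as three consecutive "null" outcomes; `should_stop` is a hypothetical name.

```python
def should_stop(decisions, elapsed_days, max_cycles=4, null_streak=3, day_gate=2.0):
    """decisions: decision-gate outcome per completed cycle, oldest first."""
    if "confirmed" in decisions:
        return True, "confirmed effect"
    if len(decisions) >= null_streak and all(d == "null" for d in decisions[-null_streak:]):
        return True, f"{null_streak} consecutive nulls"
    if len(decisions) >= max_cycles:
        return True, f"{max_cycles} cycles completed"
    if elapsed_days >= day_gate:
        return True, "time gate"
    return False, "continue"
```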
| Risk | Mitigation |
|---|---|
| FCI degenerate (n=20) | Expert DAG primary. FCI validation only. |
| Records lack structured tables | Three-tier parser. Coverage computed, not assumed. |
| MLX→H100 transfer fails | S6 --platform flag. S7 transfer coefficient. |
| Low influence variance | S10 CV check. CV < 0.1 → documented null. |
| Patch site drifts | S11 LAST occurrence + dual sentinel. |
| All causal angles null | S12 engineering fallback. Null results = scientific contribution. |