Commits
61 commits
7326a9a
pd: specify review iteration 1
Mar 24, 2026
e317f9e
pd: specify review iteration 2
Mar 24, 2026
86c5b7a
pd: specify phase-review iteration 1
Mar 24, 2026
e7ca2dc
pd: design review iteration 1
Mar 24, 2026
d662200
pd: design review iteration 2
Mar 24, 2026
fc3d771
pd: design handoff review iteration 1
Mar 24, 2026
5892207
pd: design revision - discovery-adjust cycle, observability, schema f…
Mar 24, 2026
1b54988
pd: design review iteration 1 (post-revision) - fix cycle versioning,…
Mar 24, 2026
d61d570
pd: design review iteration 2 (post-revision) - fix I4 path, I2 appen…
Mar 24, 2026
b6d9222
pd: design - fix I5 versioned path, C5 parsing/overrides, cycle_statu…
Mar 24, 2026
7847d9c
pd: design - fix sentinel docs, I6 versioned path, cycle_status forma…
Mar 24, 2026
14c9f27
pd: design - address final phase-review warnings: spec overrides, sto…
Mar 24, 2026
8b34eb3
pd: design - use pyenv+uv for dependency management, remove scripts/c…
Mar 24, 2026
9164a54
pd: plan review iteration 1 - add TDD ordering, import safety check, …
Mar 24, 2026
82217b5
pd: plan review iteration 2 - S12 tests, import safety criteria, chec…
Mar 24, 2026
80be95c
pd: plan phase-review iteration 1 - fix S5 deps, S12 ordering, S4a RE…
Mar 24, 2026
f220d76
pd: plan simplification - merge S4a/S4b, remove verbose rationale, ti…
Mar 24, 2026
4d003c5
pd: plan review fixes - checkpoint prereq, import criteria, syntax te…
Mar 24, 2026
3d1bed6
pd: tasks review iteration 1 - split T4/T11, add platform flag, funct…
Mar 24, 2026
ec6e45d
pd: tasks - fix T4→T4b deps, add R4.2 consolidation note to T17
Mar 24, 2026
d4f0c3a
pd: implement T1-T4b — project setup, common.py with tests
Mar 24, 2026
7b38523
Implement T14, T15, T16: statistical analysis, token loss decompositi…
Mar 24, 2026
c8a973c
pd: implement T5-T18 — all causal analysis scripts with tests
Mar 24, 2026
01c6aba
pd: implement T10+T11 — README and identifiability_check with tests
Mar 24, 2026
089e49e
pd: fix review issues — import path, spec consolidation, phase correl…
Mar 24, 2026
fb617eb
pd: fix checkpoint test — use .npz format matching train_gpt_mlx.py o…
Mar 25, 2026
9d394c3
pd: finish feature — mark T1-T18 complete, T19 deferred
Mar 25, 2026
c51e123
pd: implement T19 — submission assembly with dry-run support
Mar 25, 2026
d9f7fa5
pd: mark T19 complete (except H100 validation)
Mar 25, 2026
7cdcc93
Merge feature/1-causal-inference-training into clthuang-dev
Mar 25, 2026
b8293bb
pd: add auto research pipeline + fix subprocess import paths
Mar 25, 2026
3d20bc7
Add in-process MLX training runner to avoid subprocess warmup overhead
Mar 25, 2026
1c815ed
pd: add in-process trainer with warmup caching
Mar 25, 2026
9324789
pd: integrate in-process trainer into pipeline + upgrade to Python 3.12
Mar 25, 2026
eca41ee
fix: add sys.path for in-process trainer import in run_pipeline.py
Mar 25, 2026
ed383c7
fix: load_model test prefers default_ckpt to avoid shape mismatches
Mar 25, 2026
32c64df
pd: refactor + quality improvements across pipeline
Mar 25, 2026
1b80bed
perf: reduce screening warmup from 20 steps to 1, add GRAD_ACCUM_STEPS=1
Mar 25, 2026
a7eb5cb
perf: add progress logging + validate only at end for screening
Mar 25, 2026
3461485
pd: track per-step losses + screening_mode + loss curve plotting
Mar 25, 2026
96a3c27
feat: add activation function variants to screening pipeline
Mar 25, 2026
4dbb2ab
fix: remove broken FAN activation + add Priority 3 search space sweep
Mar 25, 2026
5ddd803
feat: add Rho-1 selective loss and Adaptive-K multi-token prediction …
Mar 26, 2026
6033599
feat: add --fast flag and balanced reduction for relative screening
Mar 26, 2026
f85970e
perf: bump --fast iteration preset from 50 to 300 for longer-term signal
Mar 26, 2026
21faf34
fix: move _logger creation before activation patch to avoid NameError
Mar 26, 2026
98d65cf
feat: add --cooldown flag to prevent GPU thermal throttling
chefterryhuang Mar 26, 2026
94ca331
Non-record submission: sin^2 activation + causal screening pipeline
chefterryhuang Mar 26, 2026
1c64387
Merge remote-tracking branch 'origin/main' into clthuang-dev
chefterryhuang Apr 4, 2026
862e656
feat: Phase 2 experiment matrix — 20 experiments with full screening …
chefterryhuang Apr 5, 2026
faaa0f0
docs: comprehensive experiment results documentation
chefterryhuang Apr 5, 2026
d4b2124
feat: add submission-ready R2-11 corrupted context training
chefterryhuang Apr 5, 2026
1d85c16
feat: switch submission to int6 GPTQ-lite + lzma compression
chefterryhuang Apr 5, 2026
552a56b
feat: V2 submission with SOTA training fundamentals
chefterryhuang Apr 6, 2026
78c32e3
fix: revert batch to 524K (130ms/step too slow), bigram_dim 112 for <…
chefterryhuang Apr 6, 2026
4d77a8c
feat: per-group-64 int6 bit-packed quantization (8.3MB artifact)
chefterryhuang Apr 6, 2026
24e5820
feat: increase MLP 3x→4x to use compression headroom (32.7M params, 1…
chefterryhuang Apr 6, 2026
4d428fe
docs: update README and submission.json for V3 submission
chefterryhuang Apr 6, 2026
756497a
feat: add AWQ activation-aware quantization with alpha sweep
chefterryhuang Apr 6, 2026
4617467
feat: Full Hessian GPTQ with Cholesky error compensation
chefterryhuang Apr 6, 2026
efff560
feat: Phase C fork - paper GPTQ-lite script + corrupted context
chefterryhuang Apr 8, 2026
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -8,4 +8,11 @@ data/manifest.json
data/docs_selected.jsonl
.mypy_cache/
.venv
logs/
logs/

# clthuang
local_docs
local-docs/

# Causal inference results (generated)
results/causal/
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.12
272 changes: 272 additions & 0 deletions docs/brainstorms/20260324-000000-causal-inference-training.prd.md

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions docs/features/1-causal-inference-training/.meta.json
@@ -0,0 +1,15 @@
{
"id": "1",
"slug": "causal-inference-training",
"mode": "standard",
"status": "active",
"created": "2026-03-24T12:57:20.718072+00:00",
"branch": "feature/1-causal-inference-training",
"brainstorm_source": "docs/brainstorms/20260324-000000-causal-inference-training.prd.md",
"lastCompletedPhase": null,
"phases": {
"brainstorm": {
"started": "2026-03-24T12:57:20.717802+00:00"
}
}
}
785 changes: 785 additions & 0 deletions docs/features/1-causal-inference-training/design.md

Large diffs are not rendered by default.

261 changes: 261 additions & 0 deletions docs/features/1-causal-inference-training/plan.md
@@ -0,0 +1,261 @@
# Plan: Causal Inference for Parameter-Efficient LM Training

## Implementation Order

Each step follows **test-first**: write pytest tests → implement until tests pass.

**Prerequisites before starting**:
- A saved MLX checkpoint must exist for diagnostic probes (S8-S10). If none exists, run `python train_gpt_mlx.py` with `MAX_WALLCLOCK_SECONDS=300 ITERATIONS=5000` to produce one (~5 min).
- Training data shards at `data/datasets/fineweb10B_sp1024/` (see `data/README.md` for download).

```
Phase A: Foundation
S1 → S2

Phase B: DAG Discovery
S3 → S4 → S5

Phase C: Experiment Pipeline (parallel with B after S2)
S6 → S7

Phase D: Diagnostic Probes (parallel with B+C after S2)
S8, S9, S10, S11

Phase E: Integration
S12
```

## Steps

### S1: Project Setup
**Complexity**: Simple | **Deps**: None

1. `uv add causal-learn statsmodels networkx graphviz scipy pytest`
2. `brew install graphviz` if `dot -V` fails
3. Create `scripts/causal/`, `results/causal/`, `results/causal/diagnostics/`, `tests/causal/`
4. Add `results/causal/` to `.gitignore`
5. Create `scripts/causal/__init__.py`
6. Run `uv pip check` to verify no dependency conflicts (causal-learn pulls torch/sklearn)

**Verify**: All imports succeed, `dot -V` exits 0, no dependency conflicts.

---

### S2: common.py (C1)
**Complexity**: Medium | **Deps**: S1

**Prerequisite**: `python -c "import train_gpt_mlx; print(train_gpt_mlx.GPT)"` must (1) exit 0, (2) complete in <5s, and (3) print the class reference (no FileNotFoundError/ImportError). If any check fails, extract the GPT/Block/CausalSelfAttention/MLP/Rotary classes into `scripts/causal/model.py` (Medium-complexity fallback). Note: train_gpt_mlx.py defines its classes at module level (lines 1-1100) and guards execution with `if __name__ == "__main__":` at line 1103, so importing should be safe.

**Tests** (`tests/causal/test_common.py`): One test per function — load_submission_json, load_model, compute_bpb, paired_ttest, holm_bonferroni, decision_gate, log_experiment, get_cycle_dir, dag_diff.

**Implement** (`scripts/causal/common.py`):
1. `load_submission_json(path)` — parse submission.json
2. `load_model(checkpoint_path, config_overrides)` — import GPT via `__name__` guard (TD-7), load weights, load tokenizer
3. `compute_bpb(model, val_tokens, sp_model)` — reuse eval_val logic
4. `paired_ttest(treatment, control)` — scipy.stats.ttest_rel + bootstrap CI
5. `holm_bonferroni(p_values, alpha)` — wrap statsmodels multipletests
6. `decision_gate(effect_size, p_value, mde=0.002)` — "confirmed"/"suggestive"/"null"
7. `log_experiment(path, entry)` — append to experiment_log.json
8. `get_cycle_dir(base_path, cycle)` — create `results/causal/cycle_N/`
9. `dag_diff(old_dag, new_dag)` — edges added/removed/strengthened
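
The statistical helpers in items 4-6 might be sketched as below. The function names and the `mde=0.002` default come from the plan; the bootstrap iteration count, seed, and `alpha=0.05` are illustrative assumptions.

```python
import numpy as np
from scipy import stats

def paired_ttest(treatment, control, n_boot=10_000, seed=0):
    """Paired t-test plus a bootstrap 95% CI on the mean difference.

    n_boot and seed are illustrative defaults, not specified by the plan.
    """
    diffs = np.asarray(treatment) - np.asarray(control)
    _, p_value = stats.ttest_rel(treatment, control)
    rng = np.random.default_rng(seed)
    # Resample paired differences and take the mean of each resample.
    boots = rng.choice(diffs, size=(n_boot, len(diffs)), replace=True).mean(axis=1)
    ci_low, ci_high = np.percentile(boots, [2.5, 97.5])
    return {"effect": float(diffs.mean()), "p": float(p_value),
            "ci95": (float(ci_low), float(ci_high))}

def decision_gate(effect_size, p_value, mde=0.002, alpha=0.05):
    # "confirmed": significant and at least the minimum detectable effect;
    # "suggestive": meets the MDE but not significance; otherwise "null".
    if p_value < alpha and abs(effect_size) >= mde:
        return "confirmed"
    if abs(effect_size) >= mde:
        return "suggestive"
    return "null"
```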

---

### S3: extract_interventions.py (C2)
**Complexity**: Medium | **Deps**: S2

**Tests** (`tests/causal/test_extract.py`): Each parser format with test fixtures drawn from actual record READMEs (not synthetic). Field coverage computation. --append-experiment with mock data. Unknown format fallback test.

**Implement**:
1. submission.json pass — 6 core fields from all discovered records
2. Format A parser — `| Change | ... | Impact |` tables (~2 records)
3. Format B parser — `| | Base | This |` tables (~8 records), compute delta
4. Format C fallback — headings, bullets, blurb (~10 records)
5. Cross-reference pass — compute total delta from base_bpb
6. `--append-experiment` mode — convert experiment results to submission format
7. Field coverage metric

**Verify**: field_coverage ≥ 0.90 on actual records. All discovered submissions parsed.
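
The Format B parser (item 3) could look roughly like this. The header text and 3-column layout are assumptions about how Format B records are laid out, which is exactly why the tests draw fixtures from actual record READMEs rather than synthetic ones.

```python
def parse_format_b(readme_text):
    """Parse `| | Base | This |` comparison tables and compute deltas.

    The header wording and column order are assumptions for illustration.
    """
    rows, in_table = {}, False
    for line in readme_text.splitlines():
        if not line.strip().startswith("|"):
            in_table = False
            continue
        cells = [c.strip() for c in line.strip().strip("|").split("|")]
        if len(cells) != 3:
            in_table = False
            continue
        if cells[1].lower() == "base" and cells[2].lower() == "this":
            in_table = True          # header row found
        elif in_table and all(set(c) <= set("-: ") for c in cells):
            continue                 # markdown separator row
        elif in_table:
            try:
                base, this = float(cells[1]), float(cells[2])
            except ValueError:
                continue             # non-numeric row, e.g. notes
            rows[cells[0]] = {"base": base, "this": this,
                              "delta": round(this - base, 6)}
    return rows
```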

---

### S4: estimate_dag.py (C3)
**Complexity**: Complex | **Deps**: S3

**Tests** (`tests/causal/test_estimate_dag.py`): Expert skeleton edges. Binary encoding dimensions. FCI on synthetic data. Near-degenerate case (n=20, correlated binary cols). LinAlgError caught → expert fallback. Edge tagging. --previous-dag diff. DOT renders.

**Implement**:
1. Expert-guided skeleton — hardcode known causal relationships, tag as `expert_imposed`
2. Binary encoding of interventions (presence/absence matrix)
3. FCI validation (causal-learn, Fisher-Z, alpha=0.01) with try/except for numerical errors
4. Degenerate detection (empty/fully-connected → keep expert only)
5. Edge tagging (expert_imposed → data_confirmed/data_contradicted/uncertain)
6. Marginal effect estimation per node
7. `next_intervention` recommendation (highest expected BPB improvement among uncertain edges)
8. `--previous-dag` mode (dag_diff, edge stability tracking)
9. DOT visualization via graphviz
10. Write `scripts/causal/README.md` — cycle protocol and CLI usage

**Verify**: ≥5 nodes. dag.png renders. next_intervention non-null. Edge tags valid.
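
Items 1 and 5 (expert skeleton plus edge tagging) might be sketched as below. The node names are hypothetical placeholders, and FCI output is reduced to a simple direction map purely for illustration; the real script works with causal-learn's graph objects.

```python
# Hypothetical node names for illustration; the real skeleton encodes
# known causal relationships extracted from the records.
EXPERT_SKELETON = {
    ("learning_rate_schedule", "val_bpb"): "expert_imposed",
    ("quantization_bits", "artifact_size"): "expert_imposed",
}

def tag_edges(skeleton, fci_edges):
    """Re-tag expert edges by FCI agreement.

    fci_edges maps (u, v) -> +1 (FCI found the same direction) or
    -1 (FCI found the reverse); absent edges stay "uncertain".
    """
    tagged = {}
    for edge in skeleton:
        direction = fci_edges.get(edge)
        if direction == +1:
            tagged[edge] = "data_confirmed"
        elif direction == -1:
            tagged[edge] = "data_contradicted"
        else:
            tagged[edge] = "uncertain"
    return tagged
```

Uncertain edges are the interesting ones: `next_intervention` (item 7) picks among them by expected BPB improvement.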

---

### S5: identifiability_check.py (C4)
**Complexity**: Simple | **Deps**: S3, S4

**Tests** (`tests/causal/test_identifiability.py`): Synthetic interventions with known counts. Proceed/skip threshold. Combination enumeration.

**Implement**:
1. Count single-variable and multi-variable records
2. Identifiability score (fraction of testable edges)
3. Confounded pairs (always co-occurring interventions)
4. Proceed/skip recommendation (>50% multi-variable → skip)
5. Unexplored combinations with expected effects × interaction priors
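
A minimal sketch of the identifiability logic, approximating the score at the variable level (the plan defines it over testable edges, which needs the DAG from S4). Each record is modeled as the set of interventions it toggles; the >50% multi-variable threshold is from the plan.

```python
def identifiability_report(records):
    """records: list of sets, each holding the interventions one record toggles."""
    single = [r for r in records if len(r) == 1]
    multi = [r for r in records if len(r) > 1]
    all_vars = set().union(*records) if records else set()
    # Variables observed in isolation at least once are cleanly testable.
    testable = {next(iter(r)) for r in single}
    score = len(testable) / len(all_vars) if all_vars else 0.0
    skip = len(multi) > len(records) / 2  # >50% multi-variable -> skip
    return {"score": score,
            "confounded_share": len(multi) / max(len(records), 1),
            "recommendation": "skip" if skip else "proceed"}
```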

---

### S6: experiment_runner.py (C5)
**Complexity**: Complex | **Deps**: S2

**Tests** (`tests/causal/test_experiment_runner.py`): Config validation. Stdout BPB parsing with mock output. Partial failure (1/3 seeds). Timeout handling.

**Implement**:
1. Config loading + validation (env_overrides schema)
2. Subprocess invocation with SEED env var, timeout = wallclock + 120s
3. Parse LAST occurrence of `val_bpb:<float>` from complete stdout (not just final line), capture stderr
4. Fallback: parse from training log file at `logs/{run_id}/train.log` if stdout parsing fails
5. Error handling: crash → partial result with error field; 1/3 fail → reduced_power flag; 2+/3 fail → condition failed
6. Per-run JSON-lines metrics capture
7. Checkpoint/log path capture
8. 3 seeds × 2 conditions → raw_runs.json
9. Append to experiment_log.json

**Verify**: Dry-run (ITERATIONS=10, MAX_WALLCLOCK_SECONDS=30). Valid raw_runs.json schema.
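
The stdout parsing in item 3 reduces to a small helper: scan the entire capture and take the LAST match, so trailing shutdown noise after the final eval print cannot shadow the metric.

```python
import re

VAL_BPB_RE = re.compile(r"val_bpb:([0-9]*\.?[0-9]+)")

def parse_val_bpb(stdout_text):
    """Return the LAST `val_bpb:<float>` in the full stdout, or None."""
    matches = VAL_BPB_RE.findall(stdout_text)
    return float(matches[-1]) if matches else None
```

A `None` return is what triggers the item 4 fallback to the training log file.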

---

### S7: statistical_analysis.py (C6)
**Complexity**: Medium | **Deps**: S2, S6 output

**Tests** (`tests/causal/test_statistical.py`): Synthetic data with known effect. CI contains true effect. Holm-Bonferroni adjusts upward. Decision gate classification.

**Implement**:
1. Load raw_runs.json, extract per-seed BPB pairs (handle partial failures)
2. Paired differences, mean effect, bootstrapped 95% CI
3. Paired t-test p-value
4. Holm-Bonferroni correction
5. Decision gate classification
6. Platform transfer coefficient if MLX + H100 data
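
For item 4, the plan wraps statsmodels `multipletests(..., method="holm")`; a plain-numpy version is shown here only to make the step-down logic concrete. Adjusted p-values are always at least the raw ones, which is what the "adjusts upward" test checks.

```python
import numpy as np

def holm_bonferroni(p_values, alpha=0.05):
    """Step-down Holm correction: returns (adjusted p-values, reject flags)."""
    p = np.asarray(p_values, dtype=float)
    m = len(p)
    order = np.argsort(p)
    adjusted = np.empty(m)
    running_max = 0.0
    for rank, idx in enumerate(order):
        # Multiply the k-th smallest p by (m - k), enforcing monotonicity.
        running_max = max(running_max, (m - rank) * p[idx])
        adjusted[idx] = min(1.0, running_max)
    return adjusted, adjusted < alpha
```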

---

### S8: token_loss_decompose.py (C7)
**Complexity**: Medium | **Deps**: S2 | **Prereq**: Saved checkpoint exists

**Tests** (`tests/causal/test_token_loss.py`): Decomposition check with mock output. Frequency bucketing. BPB contribution summation.

**Implement**:
1. Load model + validation data via common.py
2. Forward pass with reduction='none' for per-token losses
3. Decomposition verification: mean(per_token) matches aggregate within 1e-6
4. Frequency buckets (top-100, 100-500, 500-1024)
5. Boundary vs. mid-sequence classification
6. Per-category statistics

**Verify**: decomposition_check.passed. Buckets sum to total. BPB contributions sum to aggregate (±0.001).
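
The decomposition check (item 3) and frequency bucketing (item 4) could be sketched as below; `ranks` mapping token id to vocabulary frequency rank is an assumed input for illustration, with bucket thresholds taken from the plan.

```python
import numpy as np

def decomposition_check(per_token_losses, aggregate_loss, tol=1e-6):
    """Mean of per-token losses must reproduce the aggregate loss within tol."""
    return bool(abs(np.mean(per_token_losses) - aggregate_loss) < tol)

def bucket_by_frequency(token_ids, ranks):
    """Assign tokens to the plan's buckets by vocabulary frequency rank.

    ranks: token id -> rank (0 = most frequent). Thresholds per the plan:
    top-100, 100-500, 500-1024.
    """
    buckets = {"top100": [], "100-500": [], "500-1024": []}
    for tid in token_ids:
        r = ranks[tid]
        key = "top100" if r < 100 else "100-500" if r < 500 else "500-1024"
        buckets[key].append(tid)
    return buckets
```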

---

### S9: quant_gap_analysis.py (C8)
**Complexity**: Medium | **Deps**: S2 | **Prereq**: Saved checkpoint exists

**Tests** (`tests/causal/test_quant_gap.py`): Gap computation with mock BPB. Threshold check logic.

**Implement**:
1. Load model, eval pre-quant BPB
2. Quantize → dequantize → eval post-quant BPB (reuse train_gpt_mlx.py functions)
3. Gap and threshold check (gap > 3× largest training effect)
4. Optional: per-token category comparison pre/post quant
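
The threshold check in item 3 is a one-liner worth pinning down: the gap matters when it dwarfs anything training improvements can buy back.

```python
def quant_gap_check(pre_bpb, post_bpb, largest_training_effect):
    """Flag a quantization gap larger than 3x the biggest training effect."""
    gap = post_bpb - pre_bpb
    return {"gap": gap, "dominates": gap > 3 * abs(largest_training_effect)}
```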

---

### S10: influence_proxy.py (C9)
**Complexity**: Medium | **Deps**: S2 | **Prereq**: Saved checkpoint + training shards

**Tests** (`tests/causal/test_influence.py`): Dot product with small mock gradients. CV calculation. Skip threshold.

**Implement**:
1. Load model via common.py
2. Validation gradient via plain nn.value_and_grad (trainable params only)
3. Memory check after first shard — warn and reduce --max-shards if needed
4. Iterate shards (4096 tokens/shard), per-shard gradient + dot product
5. `mx.eval()` after each (hard requirement)
6. Sort scores, compute CV, skip recommendation if CV < 0.1

**Verify**: --max-shards 5, scores sorted, CV non-negative.
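
The score and skip logic (items 4-6) reduces to a dot product per shard plus a coefficient-of-variation check. Plain numpy is used here for clarity; the real script keeps MLX arrays and calls `mx.eval()` after every shard. Normalizing by the mean absolute score (to dodge sign cancellation) is an illustrative choice, not from the plan.

```python
import numpy as np

def influence_scores(val_grad, shard_grads):
    """Influence proxy: dot product of each shard gradient with the val gradient."""
    return np.array([float(np.dot(val_grad, g)) for g in shard_grads])

def cv_and_skip(scores, cv_threshold=0.1):
    # Low variance across shards means data selection has little leverage:
    # CV < 0.1 -> documented null, skip the angle.
    mean = np.mean(np.abs(scores))
    cv = float(np.std(scores) / mean) if mean > 0 else 0.0
    return cv, cv < cv_threshold
```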

---

### S11: gradient_attribution.py (C10)
**Complexity**: Complex | **Deps**: S1, S2

**Tests** (`tests/causal/test_gradient_attr.py`): LAST occurrence targeting. Sentinel validation on current source. Sentinel fails on modified source. **Patched file syntax check**: `ast.parse()` on the instrumented output to verify it's valid Python. JSON-lines parsing.

**Implement**:
1. Verify sentinel strings exist in current train_gpt_mlx.py
2. Find LAST `accumulate_flat_grads`, validate dual sentinel (train_loss + lr_mul within ±5 lines)
3. Insert gradient norm logging, write `train_gpt_mlx_instrumented.py`
4. Execute via subprocess (short training for test)
5. Parse JSON-lines, compute phase boundaries from lr_mul transitions
6. Per-phase correlations between layer norms and val_loss

---

### S12: Submission Assembly
**Complexity**: Medium | **Deps**: Confirmed effect from cycle OR engineering fallback

**Tests** (`tests/causal/test_submission.py`): submission.json schema. Artifact size ≤ 16MB. README.md required sections. All files present.

**Implement**:
1. Map causal findings to train_gpt.py code changes
2. Verify artifact ≤ 16MB, training ≤ 10min on 8×H100
3. 3-seed validation on H100
4. README.md with ablation table
5. submission.json with metadata
6. Engineering fallback (R5.3) if no causal findings
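
A minimal pre-flight for item 2 and the test suite's file/size checks. The artifact glob patterns and file names are assumptions for illustration; the authoritative schema lives in `tests/causal/test_submission.py`.

```python
from pathlib import Path

MAX_ARTIFACT_BYTES = 16 * 1024 * 1024  # 16MB cap from the plan

def check_submission(submission_dir):
    """Return a list of problems: missing required files, oversized artifacts."""
    d = Path(submission_dir)
    problems = []
    for name in ("submission.json", "README.md"):
        if not (d / name).exists():
            problems.append(f"missing {name}")
    # Artifact extensions are a guess; adjust to the real packaging format.
    artifacts = list(d.glob("*.bin")) + list(d.glob("*.npz"))
    for a in artifacts:
        if a.stat().st_size > MAX_ARTIFACT_BYTES:
            problems.append(f"{a.name} exceeds 16MB")
    return problems
```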

## Dependency Graph

```
S1 → S2 → S3 → S4 → S5
├→ S6 → S7 (S7 feeds back into S3 → S4, forming the cycle)
├→ S8, S9, S10 (parallel diagnostics)
└→ S11

S12 ← confirmed effect OR time gate
```

## Discovery-Adjust Cycle

**Manual, researcher-driven.** Protocol documented in `scripts/causal/README.md`.

```
Cycle 0: S3 → S4 → S5 + S8-S11 in parallel → read next_intervention → create configs
Cycle 1+: S6 → S7 → decision gate → S3 --append → S4 --previous-dag → repeat
Stop: confirmed effect | 3 null streak | 4 cycles max | 2-day gate
Final: S12 with best findings
```
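
The stop rule above can be made mechanical; this sketch covers the confirmed-effect, null-streak, and cycle-cap conditions, while the 2-day wall-clock gate is assumed to be enforced by the researcher outside this function.

```python
def should_stop(cycle, outcomes, max_cycles=4, null_streak_limit=3):
    """Return the stop reason, or None to continue the cycle.

    outcomes: decision-gate results per experiment so far,
    e.g. ["null", "suggestive", "null"].
    """
    if "confirmed" in outcomes:
        return "confirmed effect"
    streak = 0
    for o in reversed(outcomes):  # count trailing consecutive nulls
        if o == "null":
            streak += 1
        else:
            break
    if streak >= null_streak_limit:
        return "null streak"
    if cycle >= max_cycles:
        return "cycle cap"
    return None
```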

## Risks

| Risk | Mitigation |
|------|-----------|
| FCI degenerate (n=20) | Expert DAG primary. FCI validation only. |
| Records lack structured tables | Three-tier parser. Coverage computed, not assumed. |
| MLX→H100 transfer fails | S6 --platform flag. S7 transfer coefficient. |
| Low influence variance | S10 CV check. CV < 0.1 → documented null. |
| Patch site drifts | S11 LAST occurrence + dual sentinel. |
| All causal angles null | S12 engineering fallback. Null results = scientific contribution. |