
Commit c4cf4b2

Non-record: BitNet b1.58 Ternary v2 — 68M params, val_bpb=1.1770 (sliding window)
BitNet b1.58 ternary quantization with full-training STE. 68M params in 15.88MB via base-3 packing (1.6 bits/param). Near-lossless roundtrip (0.0016 BPB gap).

Systematic analysis of why the standard competition stack breaks for ternary:
- XSA, weight decay, grad clipping: cause training plateau at 2.4
- SmearGate, BigramHash, OrthoInit: hurt or no effect
- EMA/SWA: fundamentally incompatible
- TTT: no improvement on ternary models

What works: higher LR (0.04), wider MLP, fp16 scale simulation, longer warmdown. Improves on PR openai#139 (1.2029 → 1.1770).
1 parent 0f51451 commit c4cf4b2

5 files changed

Lines changed: 3205 additions & 0 deletions


Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# BitNet b1.58 Ternary v2 — 68M Params, val_bpb=1.1770

**Non-record submission.** Explores ternary quantization (1.58-bit weights) as an alternative to the standard int6 stack. Improves on our earlier PR #139 (1.2029 → 1.1770) and documents systematic findings on what works and what breaks for ternary models.

## Results
| Metric | Value |
|---|---|
| Standard val_bpb | 1.1983 |
| Ternary roundtrip val_bpb | 1.1999 |
| **Sliding window val_bpb (stride=64)** | **1.1770** |
| Roundtrip gap | 0.0016 |
| Artifact size | 15,882,274 bytes (15.88 MB) |
| Training steps | 4,591 in 600s (130.7 ms/step) |
| Model params | 68M ternary |
| Peak memory | 23.2 GB per GPU |

Improvement over PR #139: **1.2029 → 1.1770** (−0.026 BPB).
## Architecture

- 12 transformer layers, 768 dim, 12 heads, 6 KV heads (GQA)
- MLP 3.25× (hidden 2496), relu² activation
- BitLinear for all attention + MLP projections (ternary {-1, 0, 1} with per-group absmax STE; see the sketch after this list)
- U-Net skip connections, tied fp16 embeddings
- RoPE base 200,000, logit softcap 30.0
- Muon optimizer (LR=0.04, momentum 0.99) + Adam for embeddings/scalars
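
The BitLinear implementation itself is not excerpted in this diff, so the following is only a minimal PyTorch sketch of what the bullet above describes: ternary {-1, 0, 1} weights with a per-group absmax scale and a straight-through estimator. The group size, the absmax/2 rounding threshold, and the class layout are assumptions, not the repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """Sketch of a linear layer with ternary weights and an STE."""

    def __init__(self, in_features, out_features, group_size=256, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        self.group_size = group_size  # assumed; must divide in_features * out_features

    def quantize(self, w):
        # One scale per group; absmax/2 is a guessed threshold so weights split
        # roughly across {-1, 0, +1} rather than collapsing to 0.
        g = w.view(-1, self.group_size)
        scale = g.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5) / 2
        q = (g / scale).round().clamp(-1, 1)   # ternary values
        return (q * scale).view_as(w)          # dequantized weights

    def forward(self, x):
        w = self.weight
        w_q = self.quantize(w)
        # Straight-through estimator: forward pass uses the ternary weights,
        # backward pass flows into the continuous fp32 shadow weights.
        w_ste = w + (w_q - w).detach()
        return F.linear(x, w_ste, self.bias)
```
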
## What Worked for Ternary

### Wider MLP (2304 → 2496)
Ternary packing at 1.6 bits/param makes extra parameters almost free in the artifact. Widening the MLP from 3× to 3.25× added 4.4M params but only 0.8MB to the artifact. This gave a direct improvement: 1.2029 → 1.1983 (standard eval).

### Higher Learning Rate (0.04 vs competition's 0.02–0.025)
Ternary STE gradients are inherently noisy — the quantization function snaps continuous weights to just 3 levels, and the STE passes gradients through as if this didn't happen. Higher LR helps the optimizer punch through this noise floor. The competition consensus LR (0.02–0.025) was tuned for int6 models, where STE noise is 20× smaller.

### fp16 Scale Simulation During Training
The critical insight from v1: training computes scales in fp32, but serialization stores them in fp16. Different precision → different rounding → different ternary values → 0.05 BPB gap. Training with `.half().float()` on scales simulates fp16 precision during training, closing the gap to **0.0016 BPB**. This is the single most important technique for ternary roundtrip fidelity.
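
A self-contained toy version of that one-line fix (group size and the absmax/2 threshold are the same assumptions as in the BitLinear sketch above):

```python
import torch

def ternary_dequant(w: torch.Tensor, group_size: int = 256, fp16_scales: bool = True):
    """Quantize to ternary and dequantize, optionally simulating fp16 scales."""
    g = w.view(-1, group_size)
    scale = g.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5) / 2
    if fp16_scales:
        # Round-trip the scales through fp16 so training sees exactly the values
        # that serialization will store (dtype stays fp32 for the math).
        scale = scale.half().float()
    q = (g / scale).round().clamp(-1, 1)
    return (q * scale).view_as(w)

w = torch.randn(768, 768)
gap = (ternary_dequant(w, fp16_scales=True) - ternary_dequant(w, fp16_scales=False)).abs().max()
print(f"max weight difference, fp16 vs fp32 scales: {gap.item():.3e}")
```
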
### Longer Warmdown
The final 25% of training (LR decaying linearly to 0) was the most productive phase. In v1, the last 500 steps alone improved BPB by 0.03. At low LR, the continuous shadow weights converge to values that round cleanly to {-1, 0, 1}, reducing quantization error. Ternary benefits from warmdown more than fp16 models because the quantization grid is so coarse.
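
The scheduler itself is not shown in this excerpt; the shape described (short warmup, flat plateau, linear warmdown to zero) can be sketched as below, using the `LR_WARMUP_STEPS=50` and `WARMDOWN_ITERS=1200` values from the run script and assuming the loop knows, or estimates, its total step count.

```python
def lr_multiplier(step: int, total_steps: int, warmup: int = 50, warmdown: int = 1200) -> float:
    """Trapezoidal LR schedule: linear warmup, constant plateau, linear warmdown to 0."""
    if step < warmup:
        return step / warmup
    if step > total_steps - warmdown:
        return max(0.0, (total_steps - step) / warmdown)
    return 1.0

total = 4591  # steps completed within the 600 s budget in this run
for s in (0, 25, 2000, 3391, 4000, 4591):
    print(s, round(lr_multiplier(s, total), 3))
```
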
### Base-3 Packing (1.6 bits/param)
Ternary values {-1, 0, 1} are encoded as base-3: 5 trits per byte (3⁵ = 243 ≤ 256 byte values). This achieves 1.6 bits/param — near the information-theoretic minimum of log₂(3) ≈ 1.585 bits. Combined with LZMA compression, 68M ternary params + fp16 scales fit in 15.82MB.
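
A minimal numpy sketch of that encoding (not the repository's packer; padding and layout details are assumptions):

```python
import numpy as np

def pack_ternary(q: np.ndarray) -> bytes:
    """Pack ternary values {-1, 0, 1} at 5 trits per byte (3**5 = 243 codes)."""
    t = q.astype(np.int64).ravel() + 1                    # map {-1,0,1} -> {0,1,2}
    pad = (-len(t)) % 5                                   # pad to a multiple of 5
    t = np.concatenate([t, np.zeros(pad, dtype=np.int64)]).reshape(-1, 5)
    return (t @ (3 ** np.arange(5))).astype(np.uint8).tobytes()

def unpack_ternary(buf: bytes, n: int) -> np.ndarray:
    """Inverse of pack_ternary; n is the original number of values."""
    b = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)[:, None]
    trits = (b // (3 ** np.arange(5))) % 3                # decode 5 trits per byte
    return (trits.ravel()[:n] - 1).astype(np.int8)

q = np.random.randint(-1, 2, size=1_000_000)
packed = pack_ternary(q)
assert (unpack_ternary(packed, q.size) == q).all()
print(f"{8 * len(packed) / q.size:.2f} bits/param")       # -> 1.60
```
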
### 68M Parameters in 15.88MB
At 1.6 bits/param, ternary packs 2.4× more parameters than int6 (68M vs 27M) in the same artifact budget. While each ternary param is less expressive, the sheer parameter count provides more model capacity than might be expected.

## What Broke the Network

These techniques caused the model to plateau at val_loss ≈ 2.4 (vs normal convergence to ≈ 2.0):
### XSA (Exclusive Self-Attention)
XSA subtracts the self-value projection from attention output, removing the "self-attention bias" where tokens attend too much to themselves. For int6 models this gives +0.002 BPB. For ternary, it causes a **complete training plateau at 2.4**. Ternary attention weights are too coarse to exhibit self-attention bias in the first place — the attention patterns are already sparse and noisy. Removing the self-value component removes signal, not noise.

### Weight Decay (0.04)
Decoupled weight decay shrinks continuous weights toward zero. For int6, this improves quantization robustness and generalization. For ternary, it causes **training collapse**. Shrinking weights toward zero means more weights quantize to 0 instead of ±1, making the model progressively sparser until it loses capacity. The ternary STE already provides implicit regularization — adding explicit WD on top is destructive.

### Gradient Clipping (0.3–1.0)
The Muon optimizer normalizes gradients via Newton-Schulz orthogonalization. Adding gradient clipping on top double-normalizes the signal. With 68M params, the global gradient norm is naturally large. Clipping at 1.0 caused training to stall after ~500 steps — the model learned basic patterns, then the clipped gradients couldn't push through to learn more.

## What Didn't Help (But Didn't Break)

### SmearGate + BigramHash
SmearGate blends each token's embedding with the previous token's (learned per-dim gate, 768 params). BigramHash adds a hash-table embedding for token pairs. Together they give +0.01 BPB for int6 models. For ternary, they **hurt by 0.02 BPB** regardless of initialization. The ternary weights can't adapt to the modified input distribution — the quantization is too coarse to learn the subtle adjustments needed.
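
Neither module's code appears in this excerpt; the sketch below is one plausible reading of the two descriptions above (gate placement, blend formula, hash function, and bucket count are all assumptions):

```python
import torch
import torch.nn as nn

class SmearGate(nn.Module):
    """Blend each token's embedding with the previous token's via a per-dim gate."""
    def __init__(self, dim: int = 768):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim))          # 768 learned params

    def forward(self, x):                                    # x: (batch, seq, dim)
        prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return x + torch.sigmoid(self.gate) * prev           # smear previous token in

class BigramHash(nn.Module):
    """Hash-table embedding keyed on (previous, current) token pairs."""
    def __init__(self, num_buckets: int = 1 << 16, dim: int = 768):
        super().__init__()
        self.table = nn.Embedding(num_buckets, dim)
        self.num_buckets = num_buckets

    def forward(self, tokens):                               # tokens: (batch, seq) ints
        prev = torch.cat([torch.zeros_like(tokens[:, :1]), tokens[:, :-1]], dim=1)
        h = (prev * 1_000_003 + tokens) % self.num_buckets   # cheap bigram hash
        return self.table(h)
```
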
### OrthoInit
Orthogonal weight initialization is required for SmearGate to work in int6 models. For ternary, the carefully constructed orthogonal structure is **destroyed on the first forward pass** when weights snap to {-1, 0, 1}. The STE gradient tries to maintain orthogonality in the continuous shadow weights, but the actual computation uses ternary values that bear no resemblance to an orthogonal matrix. Default Kaiming initialization works equally well.

### Test-Time Training (TTT)
Causal SGD adaptation during eval (evaluate a chunk, record its scores, train on that chunk, repeat). Gives +0.01–0.03 BPB for int6 models. For ternary: **no improvement** in any variant tested:

- Frozen ternary (only update norms/scales/gates): 1.176 vs 1.177 sliding — negligible
- All params (update dequantized weights): no improvement

The dequantized weights (q × scale) sit in a region of weight space the model was never trained to operate in with continuous values. The loss landscape around ternary-structured weights isn't smooth enough for few-step SGD adaptation.
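
For reference, the causal TTT loop being described looks roughly like this (placeholder API: the actual eval harness is not in this excerpt, `model(inputs, targets)` returning a mean loss is an assumption, and the LR/epoch values come from the run script):

```python
import copy
import torch

def ttt_eval(model, val_chunks, lr=3e-4, epochs=3):
    """Causal test-time training: score each chunk first, then adapt on it."""
    model = copy.deepcopy(model)                 # keep the trained weights pristine
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    total_loss, total_tokens = 0.0, 0
    for inputs, targets in val_chunks:           # chunks in document order
        model.eval()
        with torch.no_grad():                    # 1) score the chunk before adapting
            loss = model(inputs, targets)
        total_loss += loss.item() * targets.numel()
        total_tokens += targets.numel()
        model.train()
        for _ in range(epochs):                  # 2) then train on that same chunk
            opt.zero_grad()
            model(inputs, targets).backward()
            opt.step()
    return total_loss / total_tokens
```
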
## QAT Insight

Our full-training STE (quantize every forward from step 0) achieves near-zero roundtrip gap but hurts convergence — the model trains slower because every gradient is corrupted by quantization noise. The competition's "Late QAT" approach (enable STE only in the last 15% of training) is likely optimal: full-precision convergence for 85% of training, then STE to close the quantization gap.
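
Wiring that in is a small change to the weight path of a quantized linear layer; a hedged sketch (the `qat_start_frac` knob is illustrative, not a flag in this repo, and `quantize` is any per-group quantizer such as the ternary sketch above):

```python
import torch

def maybe_quantized_weight(w: torch.Tensor, quantize, step: int, total_steps: int,
                           qat_start_frac: float = 0.85) -> torch.Tensor:
    """Late QAT: full-precision weights for the first `qat_start_frac` of training,
    then quantize-every-forward with a straight-through estimator."""
    if step < qat_start_frac * total_steps:
        return w                                  # full-precision phase
    w_q = quantize(w)                             # ternary / int4 grid
    return w + (w_q - w).detach()                 # STE phase
```
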
This suggests an unexplored direction: **int4 with late QAT**. Int4 gives 50% more params than int6 (32M vs 21M) in the same budget. With late QAT, the quantization gap should be near-zero (our ternary full-STE gap is only 0.0016 — int4 with 16 levels would be even smaller). Nobody in the competition has tried this.

## Negative Results Summary

| Technique | int6 effect | Ternary effect | Root cause |
|---|---|---|---|
| XSA | +0.002 | 💀 plateau | Ternary attention too coarse for self-bias |
| Weight decay | +0.003 | 💀 plateau | Fights STE, causes sparsity collapse |
| Grad clipping | standard | 💀 stalls | Double-normalizes with Muon |
| SmearGate | +0.01 | −0.02 | Can't adapt to modified inputs |
| OrthoInit | enables SmearGate | no effect | Destroyed by quantization |
| EMA/SWA | +0.005 | 💀 broken | Averaging ternary = non-ternary |
| TTT | +0.01–0.03 | no effect | Loss landscape not smooth around ternary |
## Run Command

```bash
bash setup.sh
bash run_8xh100.sh
```
## References

- BitNet b1.58: [arXiv:2402.17764](https://arxiv.org/abs/2402.17764); original BitNet: [arXiv:2310.11453](https://arxiv.org/abs/2310.11453)
- Muon optimizer: [kellerjordan.github.io/posts/muon](https://kellerjordan.github.io/posts/muon/)
- Prior submission: PR #139 (val_bpb=1.2029)
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
#!/bin/bash
set -e
# BitNet v2 — exact v1 config + MLP 2496 + sliding/TTT eval
RUN_ID=bitnet_v2_v1mlp \
ITERATIONS=20000 \
NUM_LAYERS=12 \
MODEL_DIM=768 \
NUM_HEADS=12 \
NUM_KV_HEADS=6 \
MLP_MULT=3 \
MLP_HIDDEN=2496 \
TRAIN_BATCH_TOKENS=524288 \
TRAIN_SEQ_LEN=2048 \
VAL_LOSS_EVERY=500 \
VAL_BATCH_SIZE=524288 \
MAX_WALLCLOCK_SECONDS=600 \
TRAIN_LOG_EVERY=50 \
LR_WARMUP_STEPS=50 \
MATRIX_LR=0.04 \
SCALAR_LR=0.04 \
TIED_EMBED_LR=0.04 \
MUON_MOMENTUM=0.99 \
MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 \
MUON_WD=0.0 \
ROPE_BASE=200000 \
EVAL_STRIDE=64 \
EVAL_BATCH_SEQS=32 \
WARMDOWN_ITERS=1200 \
GRAD_CLIP_NORM=0.0 \
XSA_LAST_N=0 \
TTT_ENABLED=1 \
TTT_LR=3e-4 \
TTT_EPOCHS=3 \
SEQ_RAMP_START=2048 \
BATCH_RAMP_START=524288 \
torchrun --standalone --nproc_per_node=8 train_bitnet.py
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
{
  "track": "non_record_16mb",
  "date": "2026-03-21",
  "name": "BitNet b1.58 Ternary v2 — 68M params, Wider MLP, Sliding Window",
  "author": "Etai Zilberman",
  "github_id": "ksang123",
  "blurb": "BitNet b1.58 ternary quantization with full-training STE. 68M params in 15.88MB via base-3 packing (1.6 bits/param). Near-lossless roundtrip (0.0016 BPB gap) from fp16 scale simulation. Systematic analysis of why the standard competition stack (SmearGate, XSA, WD, EMA, OrthoInit) breaks for ternary.",
  "val_loss": 1.98723407,
  "val_bpb": 1.17695507,
  "bytes_total": 15882274,
  "bytes_code": 63242
}
