Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Asymmetric Encoder-Decoder Split + Systematic Exploration

## Key Finding: Asymmetric 1/10 Encoder-Decoder Split

The hourglass architecture uses `num_encoder_layers = num_layers // 2` by default (5 encoder, 6 decoder for 11 layers). We discovered that **shifting almost all layers to the decoder** monotonically improves performance. One-line change: `self.num_encoder_layers = 1`.

### Asymmetric Split Sweep (RTX 5090, 11 layers, 300 steps, baseline code)

| Encoder/Decoder Split | int8_bpb | vs Default (5/6) |
|----------------------|----------|-------------------|
| 5/6 (default) | 1.5455 | — |
| 3/8 | 1.5421 | -0.0034 |
| 2/9 | 1.5369 | -0.0086 |
| **1/10** | **1.5298** | **-0.0157** |

Monotonic improvement across all 4 configurations. The pattern suggests the decoder benefits more from additional layers than the encoder.

### Validated on SOTA Code (RTX 5090, 300 steps)

Applied the 1/10 split to the current SOTA submission (PR #549 stack). With SWA disabled for short-run comparability:

| Split | val_bpb (pre-quant) | vs Default |
|-------|---------------------|------------|
| 5/6 (default) | 1.8070 | — |
| 1/10 | 1.8034 | -0.0036 |

The improvement transfers to SOTA code, confirming it's not an artifact of the baseline architecture.

### 8xH100 Partial Run

Applied the 1/10 split to the SOTA code (PR #549 stack) and ran on 8xH100 SXM with full competition settings. Flash Attention 3 was unavailable as a pip package, so we used FA2 as a fallback.

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | **1.1492** |
| Steps completed | 5666/9000 |
| Step avg | 105.9ms (vs 83.3ms with FA3) |
| Wall clock | 600s (10 min cap) |

**The run was handicapped by FA2's slower speed** (105ms vs 83ms/step), reaching only 5666 of 9000 steps. The pod crashed during the post-training eval phase (TTT + quantization), so we could not obtain the final int8 quantized score. With FA3 and all 9000 steps, we estimate this would place in the top 10 on the leaderboard.

**We ran out of GPU credits and H100 availability before we could complete a clean run with FA3.**

## Background: 27 Experiments on M4 MacBook

Before the GPU runs, we conducted 27 systematic experiments over ~30 hours on an Apple M4 MacBook (16GB, MLX backend) to identify promising techniques.

### Hardware & Setup

- **Hardware:** Apple M4 MacBook, 16GB unified memory
- **Backend:** MLX, bfloat16 compute, ~9K tok/s peak throughput
- **Data:** 10 training shards from fineweb10B_sp1024 (~1B tokens)
- **Training:** 300 steps per experiment (wall-clock limited)
- **Validation:** Full fineweb_val split, both fp32 and int8 quantized roundtrip

### Deep Supervision (Novel Technique)

Added auxiliary loss at the encoder-decoder boundary, forcing encoder layers to learn predictive representations earlier. Zero extra parameters.

| Batch Size | Baseline | +DeepSup(0.03) | Effect |
|-----------|----------|----------------|--------|
| 8K | 2.168 | 2.118 | -0.050 (helps) |
| 16K | 2.037 | 2.037 | 0.000 (neutral) |
| 64K | 1.767 | 1.774 | +0.006 (neutral) |

Acts as a regularizer whose benefit scales inversely with batch size. Not useful at competition scale.

### Optimization Tuning (128K batch, M4)

| Config | int8_bpb | vs 128K Baseline (1.667) |
|--------|----------|--------------------------|
| **Matrix LR 0.08** | **1.6414** | **-0.025** |
| Matrix LR 0.06 | 1.6431 | -0.024 |
| Grad clip norm=1.0 | 1.6473 | -0.019 |
| MLP mult 3 | 1.6596 | -0.007 |
| 10 layers | 1.6613 | -0.006 |

Default LR 0.04 is too conservative for short training runs. Gradient clipping stabilizes large-batch training.

### Convergence Techniques (64K batch, M4)

EMA, SWA, Partial RoPE, and longer sequences all hurt at 300 steps — these need thousands of steps to help, which is consistent with them appearing in the SOTA submissions that train for 9000 steps.

## Summary

| Phase | Hardware | Key Finding |
|-------|----------|-------------|
| M4 exploration (27 exps) | M4 MacBook | LR 0.08 optimal for short runs, deep supervision helps at small batch only |
| Asymmetric split sweep | RTX 5090 | 1/10 encoder-decoder split gives -0.016 BPB monotonically |
| SOTA validation | RTX 5090 | 1/10 split gives -0.004 on SOTA code |
| 8xH100 record attempt | 8xH100 SXM | 1.1492 pre-quant BPB at step 5666/9000 (FA2 speed bottleneck) |

## What Additional GPU Credits Would Enable

With FA3 properly built from source on H100, we could:
1. **Complete the record run** — reach all 9000 steps (vs 5666 with FA2), potentially placing top 10
2. **Test asymmetric split + LR tuning combined** — our two best findings modify different aspects of training
3. **Run multiple seeds** for statistical confidence

The asymmetric split is a one-line change that improves performance across baseline and SOTA code, across different GPUs, and at different training scales. A clean FA3 run would determine its true contribution at competition scale.

## Reproduce

```bash
# Asymmetric split (one-line change in any train_gpt.py)
# Change: self.num_encoder_layers = num_layers // 2
# To: self.num_encoder_layers = 1

# RTX 5090 / H100 baseline test
pip install sentencepiece huggingface-hub datasets tiktoken flash-attn
python data/cached_challenge_fineweb.py --variant sp1024
ITERATIONS=300 torchrun --standalone --nproc_per_node=1 train_gpt.py

# M4 MacBook
pip install sentencepiece mlx
MATRIX_LR=0.08 ITERATIONS=300 TRAIN_BATCH_TOKENS=131072 GRAD_ACCUM_STEPS=16 \
VAL_BATCH_SIZE=524288 python train_gpt_mlx.py
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"author": "Rana Usman",
"github_id": "ranausmanai",
"name": "Asymmetric 1/10 Encoder-Decoder Split + Systematic Exploration",
"blurb": "One-line change: setting num_encoder_layers=1 (instead of num_layers//2) monotonically improves BPB. Validated on baseline code (-0.016 BPB across 4 split ratios on RTX 5090) and SOTA code (-0.004 BPB). 8xH100 partial run reached 1.1492 pre-quant BPB at step 5666/9000 (FA2 speed bottleneck prevented full training). Background: 27 systematic experiments on M4 MacBook exploring deep supervision, LR tuning, batch scaling, and architecture.",
"date": "2026-04-03T00:00:00Z",
"track": "non-record-16mb",
"val_loss": 2.7714,
"val_bpb": 1.6414,
"pre_quant_val_loss": 2.7683,
"pre_quant_val_bpb": 1.6395,
"h100_pre_quant_val_bpb": 1.1492,
"h100_steps_completed": 5666,
"h100_steps_target": 9000,
"step_stop": 300,
"wallclock_seconds": 4125,
"bytes_total": 9815427,
"bytes_model_int8_zlib": 9815427,
"bytes_code": null,
"gpu": "Apple M4 (MLX), RTX 5090 (PyTorch), 8xH100 SXM (PyTorch)"
}
Loading