diff --git a/records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md b/records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md
new file mode 100644
index 0000000000..5e193e0f0d
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md
@@ -0,0 +1,33 @@
+# QAT + Architecture Exploration (Non-Record)
+
+This is a non-record submission demonstrating quantization-aware training (QAT) applied to the baseline architecture.
+
+## Approach
+
+**Quantization-Aware Training:** We add simulated int8 per-row quantization to `CastedLinear` layers during training using a straight-through estimator (STE); weights at or below the `INT8_KEEP_FLOAT_MAX_NUMEL` threshold stay in float, matching the export pipeline. The model learns weights that are robust to int8 rounding, reducing the gap between pre-quantization and post-quantization val_bpb.
+
+On the baseline, this quantization gap grows from 0.0072 BPB (10-minute run) to 0.0325 BPB (4-hour run), so QAT becomes increasingly valuable as training quality improves.
+
+**Implementation:** A `fake_quantize_per_row()` function simulates the exact quantization pipeline used in `quantize_state_dict_int8` -- per-row int8 with fp16 scales. A `training` flag is threaded through the model so that evaluation uses unquantized weights.
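+
+A minimal sketch of the fake-quantize step (per-row absmax int8 with a straight-through estimator; `fake_quantize_per_row()` in `train_gpt_mlx.py` is the exact version used):
+
+```python
+import mlx.core as mx
+
+def fake_quantize_sketch(w: mx.array) -> mx.array:
+    w32 = w.astype(mx.float32)
+    scale = mx.maximum(mx.max(mx.abs(w32), axis=1, keepdims=True) / 127.0, mx.array(1.0 / 127.0))
+    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale  # quantize, then dequantize
+    return w32 + mx.stop_gradient(w_q - w32)  # forward sees w_q; gradient flows straight through
+```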
+
+## Status
+
+Work in progress. Local MLX smoke tests on Apple Silicon are too resource-intensive for meaningful validation: full training runs saturate consumer hardware, which is why H100 access is needed. Leaderboard metrics are pending H100 validation.
+
+## Planned Extensions
+
+- SEQUENCE parameter sharing (Takase & Kiyono 2023) for more effective depth
+- BitNet b1.58 ternary exploration for a 5x parameter budget
+- NorMuon optimizer, value embeddings, vocabulary tuning
diff --git a/records/track_non_record_16mb/2026-03-18_QAT_Exploration/submission.json b/records/track_non_record_16mb/2026-03-18_QAT_Exploration/submission.json
new file mode 100644
index 0000000000..c0e27502f3
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-18_QAT_Exploration/submission.json
@@ -0,0 +1,9 @@
+{
+  "author": "Matt Leonard",
+  "github_id": "mattleonard16",
+  "name": "QAT + Architecture Exploration",
+  "blurb": "Quantization-aware training with STE to close the int8 quantization gap. Non-record exploration submission pending H100 validation.",
+  "date": "2026-03-18T00:00:00Z",
+  "track": "non-record-exploration",
+  "status": "in-progress"
+}
diff --git a/test_qat_guards.py b/test_qat_guards.py
new file mode 100644
index 0000000000..8045e1175c
--- /dev/null
+++ b/test_qat_guards.py
@@ -0,0 +1,51 @@
+"""Tests for numel-gated fake quantization in QAT training."""
+import mlx.core as mx
+from train_gpt_mlx import (
+    CastedLinear,
+    INT8_KEEP_FLOAT_MAX_NUMEL,
+    fake_quantize_per_row,
+)
+
+
+def test_fake_quantize_per_row_roundtrip():
+    """fake_quantize_per_row returns a different tensor (not bitwise identical)."""
+    w = mx.random.normal((128, 64))
+    w_q = fake_quantize_per_row(w)
+    mx.eval(w_q)
+    assert w.shape == w_q.shape
+    assert not mx.array_equal(w, w_q), "Quantized weight should differ from original"
+
+
+def test_casted_linear_skips_small_weights():
+    """CastedLinear with numel <= threshold should NOT fake-quantize during training."""
+    # 64x64 = 4096 << 65536 threshold
+    layer = CastedLinear(64, 64)
+    assert layer.weight.size <= INT8_KEEP_FLOAT_MAX_NUMEL
+    x = mx.random.normal((1, 4, 64))
+    out_train = layer(x, training=True)
+    out_eval = layer(x, training=False)
+    mx.eval(out_train, out_eval)
+    assert mx.allclose(out_train, out_eval, atol=1e-6), (
+        "Small-weight CastedLinear should produce identical train/eval output"
+    )
+
+
+def test_casted_linear_quantizes_large_weights():
+    """CastedLinear with numel > threshold SHOULD fake-quantize during training."""
+    # 512x512 = 262144 > 65536 threshold
+    layer = CastedLinear(512, 512)
+    assert layer.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL
+    x = mx.random.normal((1, 4, 512))
+    out_train = layer(x, training=True)
+    out_eval = layer(x, training=False)
+    mx.eval(out_train, out_eval)
+    assert not mx.allclose(out_train, out_eval, atol=1e-6), (
+        "Large-weight CastedLinear should produce different train/eval output"
+    )
+
+
+if __name__ == "__main__":
+    test_fake_quantize_per_row_roundtrip()
+    test_casted_linear_skips_small_weights()
+    test_casted_linear_quantizes_large_weights()
+    print("All tests passed")
diff --git a/train_gpt_mlx.py b/train_gpt_mlx.py
index bf7c7d1b8c..10a1381446 100644
--- a/train_gpt_mlx.py
+++ b/train_gpt_mlx.py
@@ -273,13 +273,26 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.arra
 # MODEL BLOCKS
 # ==============================================================================

+def fake_quantize_per_row(w: mx.array) -> mx.array:
+    """Simulate int8 per-row quantization during training (STE).
+    Forward: returns quantized-then-dequantized weights.
+    Backward: straight-through (gradient flows as if no quantization)."""
+    w32 = w.astype(mx.float32)
+    row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
+    scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))
+    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
+    # STE: forward uses w_q, backward gradient flows through w32
+    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
+
+
 class CastedLinear(nn.Module):
     def __init__(self, in_dim: int, out_dim: int):
         super().__init__()
         self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

-    def __call__(self, x: mx.array) -> mx.array:
-        return x @ self.weight.astype(x.dtype).T
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+        w = fake_quantize_per_row(self.weight) if (training and self.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.weight
+        return x @ w.astype(x.dtype).T


 class RMSNormNoWeight(nn.Module):
@@ -320,18 +333,18 @@ def __init__(
         self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
         self.scale = self.head_dim ** -0.5

-    def __call__(self, x: mx.array) -> mx.array:
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
         bsz, seqlen, dim = x.shape
-        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
-        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
-        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+        k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
         q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
         k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
         q = q * self.q_gain.astype(q.dtype)[None, :, None, None]

         y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
         y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
-        return self.proj(y)
+        return self.proj(y, training=training)


 class MLP(nn.Module):
@@ -342,9 +355,9 @@ def __init__(self, dim: int, mlp_mult: int):
         self.fc = CastedLinear(dim, hidden)
         self.proj = CastedLinear(hidden, dim)

-    def __call__(self, x: mx.array) -> mx.array:
-        x = nn.relu(self.fc(x))
-        return self.proj(x * x)
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+        x = nn.relu(self.fc(x, training=training))
+        return self.proj(x * x, training=training)


 class Block(nn.Module):
@@ -366,12 +379,12 @@ def __init__(
         self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
         self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

-    def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
+    def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array:
         mix = self.resid_mix.astype(x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
+        attn_out = self.attn(self.attn_norm(x), training=training)
         x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training)
         return x


@@ -411,13 +424,14 @@ def softcap(self, logits: mx.array) -> mx.array:
         c = self.logit_softcap
         return c * mx.tanh(logits / c)

-    def __call__(self, input_ids: mx.array) -> mx.array:
-        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
+    def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array:
+        emb_w = fake_quantize_per_row(self.tok_emb.weight) if (training and self.tok_emb.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.tok_emb.weight
+        x = rms_norm(emb_w[input_ids].astype(COMPUTE_DTYPE))
         x0 = x
         skips: list[mx.array] = []

         for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
+            x = self.blocks[i](x, x0, training=training)
             skips.append(x)
         for i in range(self.num_decoder_layers):
             # Odd layer counts have one more decoder block than encoder block. The baseline only
@@ -425,16 +439,17 @@ def __call__(self, input_ids: mx.array) -> mx.array:
             # without an added skip.
             if skips:
                 x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+            x = self.blocks[self.num_encoder_layers + i](x, x0, training=training)
         return self.final_norm(x)

-    def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
+    def loss(self, input_ids: mx.array, target_ids: mx.array, training: bool = False) -> mx.array:
         # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
         # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
-        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+        x = self(input_ids, training=training).reshape(-1, self.tok_emb.weight.shape[1])
         y = target_ids.reshape(-1)
+        lm_weight = fake_quantize_per_row(self.tok_emb.weight) if (training and self.tok_emb.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.tok_emb.weight
         if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
-            logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T
+            logits_proj = x @ lm_weight.astype(x.dtype).T
             logits = self.softcap(logits_proj)
             return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")

@@ -442,7 +457,7 @@ def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
         n = int(x.shape[0])
         for s in range(0, n, self.logit_chunk_tokens):
             e = min(s + self.logit_chunk_tokens, n)
-            logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T
+            logits_proj = x[s:e] @ lm_weight.astype(x.dtype).T
             logits = self.softcap(logits_proj)
             loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum")
         return loss_sum / float(n)
@@ -895,9 +910,9 @@ def log(msg: str, console: bool = True) -> None:
     # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs".
     # Compiling the model-bound functions and capturing the full model state fixes that while still
     # returning gradients only for trainable parameters via nn.value_and_grad(...).
-    compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
+    compiled_loss = mx.compile(lambda x, y: model.loss(x, y, training=False), inputs=model.state, outputs=model.state)
     compiled_loss_and_grad = mx.compile(
-        nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
+        nn.value_and_grad(model, lambda x, y: model.loss(x, y, training=True)),
         inputs=model.state, outputs=model.state,
     )