21 changes: 21 additions & 0 deletions records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md
@@ -0,0 +1,21 @@
# QAT + Architecture Exploration (Non-Record)

This is a non-record submission demonstrating quantization-aware training (QAT) applied to the baseline architecture.

## Approach

**Quantization-Aware Training:** We add simulated int8 per-row quantization to all `CastedLinear` layers during training using a straight-through estimator (STE). The model learns weights that are robust to int8 rounding, reducing the gap between pre-quantization and post-quantization val_bpb.
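
The straight-through estimator is what makes this trainable: the forward pass sees rounded weights while gradients behave as if no rounding happened. A minimal sketch of the mechanism (illustrative only; `ste_round` stands in for the full per-row quantize-dequantize in `fake_quantize_per_row`):

```python
import mlx.core as mx

def ste_round(w: mx.array) -> mx.array:
    w_q = mx.round(w)                     # stand-in for quantize-dequantize
    return w + mx.stop_gradient(w_q - w)  # forward: w_q, backward: identity

w = mx.array([0.2, 1.7, -0.6])
print(mx.grad(lambda v: mx.sum(mx.round(v) ** 2))(w))   # [0, 0, 0] -- rounding kills the gradient
print(mx.grad(lambda v: mx.sum(ste_round(v) ** 2))(w))  # [0, 4, -2] -- STE lets it flow through
```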

Baseline measurements show this quantization gap growing from 0.0072 BPB (10-minute run) to 0.0325 BPB (4-hour run), making QAT increasingly valuable as training quality improves.

**Implementation:** A `fake_quantize_per_row()` function simulates the exact quantization pipeline used in `quantize_state_dict_int8` -- per-row int8 with fp16 scales. The training flag is threaded through the model so evaluation uses unquantized weights.
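
For intuition, a minimal sketch of the per-row int8 scheme being simulated (illustrative; the authoritative recipe is `fake_quantize_per_row` in the diff below, which mirrors `quantize_state_dict_int8`, and the exporter additionally stores the scales in fp16):

```python
import mlx.core as mx

w = mx.random.normal((4, 8))                        # one scale per output row
scale = mx.maximum(mx.max(mx.abs(w), axis=1, keepdims=True) / 127.0, 1.0 / 127.0)
w_int8 = mx.round(mx.clip(w / scale, -127, 127))    # integer codes the exporter would store
w_deq = w_int8 * scale                              # weights the deployed model computes with
print(mx.max(mx.abs(w - w_deq)))                    # per-row rounding error of at most scale/2
```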

## Status

Work in progress. Local MLX smoke tests on Apple Silicon cannot provide meaningful validation -- full training runs saturate compute on consumer hardware, which is precisely why H100 access is needed. Leaderboard metrics are pending H100 validation.

## Planned Extensions

- SEQUENCE parameter sharing (Takase & Kiyono 2023) for more effective depth
- BitNet b1.58 ternary exploration for 5x parameter budget
- NorMuon optimizer, value embeddings, vocabulary tuning
@@ -0,0 +1,9 @@
{
"author": "Matt Leonard",
"github_id": "mattleonard16",
"name": "QAT + Architecture Exploration",
"blurb": "Quantization-aware training with STE to close the int8 quantization gap. Non-record exploration submission pending H100 validation.",
"date": "2026-03-18T00:00:00Z",
"track": "non-record-exploration",
"status": "in-progress"
}
51 changes: 51 additions & 0 deletions test_qat_guards.py
@@ -0,0 +1,51 @@
"""Tests for numel-gated fake quantization in QAT training."""
import mlx.core as mx
from train_gpt_mlx import (
CastedLinear,
INT8_KEEP_FLOAT_MAX_NUMEL,
fake_quantize_per_row,
)


def test_fake_quantize_per_row_roundtrip():
"""fake_quantize_per_row returns a different tensor (not bitwise identical)."""
w = mx.random.normal((128, 64))
w_q = fake_quantize_per_row(w)
mx.eval(w_q)
assert w.shape == w_q.shape
assert not mx.array_equal(w, w_q), "Quantized weight should differ from original"


def test_casted_linear_skips_small_weights():
"""CastedLinear with numel <= threshold should NOT fake-quantize during training."""
# 64x64 = 4096 << 65536 threshold
layer = CastedLinear(64, 64)
assert layer.weight.size <= INT8_KEEP_FLOAT_MAX_NUMEL
x = mx.random.normal((1, 4, 64))
out_train = layer(x, training=True)
out_eval = layer(x, training=False)
mx.eval(out_train, out_eval)
assert mx.allclose(out_train, out_eval, atol=1e-6), (
"Small-weight CastedLinear should produce identical train/eval output"
)


def test_casted_linear_quantizes_large_weights():
"""CastedLinear with numel > threshold SHOULD fake-quantize during training."""
# 512x512 = 262144 > 65536 threshold
layer = CastedLinear(512, 512)
assert layer.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL
x = mx.random.normal((1, 4, 512))
out_train = layer(x, training=True)
out_eval = layer(x, training=False)
mx.eval(out_train, out_eval)
assert not mx.allclose(out_train, out_eval, atol=1e-6), (
"Large-weight CastedLinear should produce different train/eval output"
)


if __name__ == "__main__":
test_fake_quantize_per_row_roundtrip()
test_casted_linear_skips_small_weights()
test_casted_linear_quantizes_large_weights()
print("All tests passed")
61 changes: 38 additions & 23 deletions train_gpt_mlx.py
@@ -273,13 +273,26 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.arra
# MODEL BLOCKS
# ==============================================================================

def fake_quantize_per_row(w: mx.array) -> mx.array:
"""Simulate int8 per-row quantization during training (STE).
Forward: returns quantized-then-dequantized weights.
Backward: straight-through (gradient flows as if no quantization)."""
w32 = w.astype(mx.float32)
row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))
w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
Comment on lines +281 to +283

P1: Use exporter clipping in fake_quantize_per_row

quantize_float_array() clips each 2D row at INT8_CLIP_Q before computing its int8 scale, but this new training path derives scale from the absolute row max and never clips. When a row contains even one large outlier (a realistic case for embeddings and projection weights), QAT trains against a much coarser quantizer than quantize_state_dict_int8() will actually serialize, so the measured pre/post-quantization gap no longer reflects the deployed model. If this experiment is intended to validate deployment-aware training, the fake-quant path needs to mirror the exporter’s clipping rule.
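
A rough sketch of what a clip-aware fake-quant path could look like, assuming `INT8_CLIP_Q` is a per-row quantile of |w| (the actual rule in `quantize_float_array()` may differ; the 0.995 placeholder and helper name are illustrative only):

```python
import mlx.core as mx

INT8_CLIP_Q = 0.995  # placeholder -- the real constant and its semantics live in the exporter

def fake_quantize_per_row_clipped(w: mx.array, clip_q: float = INT8_CLIP_Q) -> mx.array:
    w32 = w.astype(mx.float32)
    abs_sorted = mx.sort(mx.abs(w32), axis=1)
    idx = min(int(clip_q * (w32.shape[1] - 1)), w32.shape[1] - 1)
    row_clip = abs_sorted[:, idx:idx + 1]                         # per-row clip value
    w_clipped = mx.clip(w32, -row_clip, row_clip)                 # mirror exporter-style clipping
    scale = mx.maximum(row_clip / 127.0, mx.array(1.0 / 127.0))
    w_q = mx.round(mx.clip(w_clipped / scale, -127, 127)) * scale
    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)    # same STE as the original
```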

# STE: forward uses w_q, backward gradient flows through w32
return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)


class CastedLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

def __call__(self, x: mx.array) -> mx.array:
return x @ self.weight.astype(x.dtype).T
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
w = fake_quantize_per_row(self.weight) if (training and self.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.weight
return x @ w.astype(x.dtype).T


class RMSNormNoWeight(nn.Module):
@@ -320,18 +333,18 @@ def __init__(
self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
self.scale = self.head_dim ** -0.5

def __call__(self, x: mx.array) -> mx.array:
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
bsz, seqlen, dim = x.shape
q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
return self.proj(y)
return self.proj(y, training=training)


class MLP(nn.Module):
@@ -342,9 +355,9 @@ def __init__(self, dim: int, mlp_mult: int):
self.fc = CastedLinear(dim, hidden)
self.proj = CastedLinear(hidden, dim)

def __call__(self, x: mx.array) -> mx.array:
x = nn.relu(self.fc(x))
return self.proj(x * x)
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
x = nn.relu(self.fc(x, training=training))
return self.proj(x * x, training=training)


class Block(nn.Module):
@@ -366,12 +379,12 @@ def __init__(
self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array:
mix = self.resid_mix.astype(x.dtype)
x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
attn_out = self.attn(self.attn_norm(x))
attn_out = self.attn(self.attn_norm(x), training=training)
x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training)
return x


@@ -411,38 +424,40 @@ def softcap(self, logits: mx.array) -> mx.array:
c = self.logit_softcap
return c * mx.tanh(logits / c)

def __call__(self, input_ids: mx.array) -> mx.array:
x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array:
emb_w = fake_quantize_per_row(self.tok_emb.weight) if (training and self.tok_emb.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.tok_emb.weight
x = rms_norm(emb_w[input_ids].astype(COMPUTE_DTYPE))
x0 = x
skips: list[mx.array] = []

for i in range(self.num_encoder_layers):
x = self.blocks[i](x, x0)
x = self.blocks[i](x, x0, training=training)
skips.append(x)
for i in range(self.num_decoder_layers):
# Odd layer counts have one more decoder block than encoder block. The baseline only
# applies a skip connection when one exists, then runs the remaining decoder block(s)
# without an added skip.
if skips:
x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
x = self.blocks[self.num_encoder_layers + i](x, x0)
x = self.blocks[self.num_encoder_layers + i](x, x0, training=training)
return self.final_norm(x)

def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
def loss(self, input_ids: mx.array, target_ids: mx.array, training: bool = False) -> mx.array:
# Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
# memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
x = self(input_ids, training=training).reshape(-1, self.tok_emb.weight.shape[1])
y = target_ids.reshape(-1)
lm_weight = fake_quantize_per_row(self.tok_emb.weight) if (training and self.tok_emb.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL) else self.tok_emb.weight
if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T
logits_proj = x @ lm_weight.astype(x.dtype).T
logits = self.softcap(logits_proj)
return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean")

loss_sum = mx.array(0.0, dtype=mx.float32)
n = int(x.shape[0])
for s in range(0, n, self.logit_chunk_tokens):
e = min(s + self.logit_chunk_tokens, n)
logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T
logits_proj = x[s:e] @ lm_weight.astype(x.dtype).T
logits = self.softcap(logits_proj)
loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum")
return loss_sum / float(n)
@@ -895,9 +910,9 @@ def log(msg: str, console: bool = True) -> None:
# inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs".
# Compiling the model-bound functions and capturing the full model state fixes that while still
# returning gradients only for trainable parameters via nn.value_and_grad(...).
compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state)
compiled_loss = mx.compile(lambda x, y: model.loss(x, y, training=False), inputs=model.state, outputs=model.state)
compiled_loss_and_grad = mx.compile(
nn.value_and_grad(model, lambda x, y: model.loss(x, y)),
nn.value_and_grad(model, lambda x, y: model.loss(x, y, training=True)),
inputs=model.state,
outputs=model.state,
)