21 changes: 21 additions & 0 deletions records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md
@@ -0,0 +1,21 @@
# QAT + Architecture Exploration (Non-Record)

This is a non-record submission demonstrating quantization-aware training (QAT) applied to the baseline architecture.

## Approach

**Quantization-Aware Training:** We add simulated int8 per-row quantization to all `CastedLinear` layers during training using a straight-through estimator (STE). The model learns weights that are robust to int8 rounding, reducing the gap between pre-quantization and post-quantization val_bpb.

Baseline runs show this quantization gap growing from 0.0072 BPB (10-minute run) to 0.0325 BPB (4-hour run), so QAT becomes increasingly valuable as training quality improves.

**Implementation:** A `fake_quantize_per_row()` function simulates the exact quantization pipeline used in `quantize_state_dict_int8` -- per-row int8 with fp16 scales. The training flag is threaded through the model so evaluation uses unquantized weights.
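The STE itself reduces to a one-line identity. A minimal sketch (the helper below is illustrative; the full per-row implementation lives in `train_gpt_mlx.py`):

```python
import mlx.core as mx

def ste_fake_quant(w: mx.array, quantize) -> mx.array:
    # Forward pass behaves like quantize(w); backward pass sees the identity,
    # so gradients flow to the full-precision weights.
    return w + mx.stop_gradient(quantize(w) - w)
```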

## Status

Work in progress. Local MLX smoke tests run on Apple Silicon, but meaningful validation is too resource-intensive for consumer hardware -- training runs saturate local compute, which is precisely why H100 access is needed. Leaderboard metrics are pending H100 validation.

## Planned Extensions

- SEQUENCE parameter sharing (Takase & Kiyono 2023) for more effective depth
- BitNet b1.58 ternary exploration for 5x parameter budget
- NorMuon optimizer, value embeddings, vocabulary tuning
@@ -0,0 +1,9 @@
{
"author": "Matt Leonard",
"github_id": "mattleonard16",
"name": "QAT + Architecture Exploration",
"blurb": "Quantization-aware training with STE to close the int8 quantization gap. Non-record exploration submission pending H100 validation.",
"date": "2026-03-18T00:00:00Z",
"track": "non-record-exploration",
"status": "in-progress"
}
47 changes: 30 additions & 17 deletions train_gpt_mlx.py
@@ -273,13 +273,26 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
# MODEL BLOCKS
# ==============================================================================

def fake_quantize_per_row(w: mx.array) -> mx.array:
"""Simulate int8 per-row quantization during training (STE).
Forward: returns quantized-then-dequantized weights.
Backward: straight-through (gradient flows as if no quantization)."""
w32 = w.astype(mx.float32)
row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))
w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
Comment on lines +281 to +283
P1: Use exporter clipping in fake_quantize_per_row

quantize_float_array() clips each 2D row at INT8_CLIP_Q before computing its int8 scale, but this new training path derives scale from the absolute row max and never clips. When a row contains even one large outlier (a realistic case for embeddings and projection weights), QAT trains against a much coarser quantizer than quantize_state_dict_int8() will actually serialize, so the measured pre/post-quantization gap no longer reflects the deployed model. If this experiment is intended to validate deployment-aware training, the fake-quant path needs to mirror the exporter’s clipping rule.
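A rough sketch of what mirroring the exporter might look like, assuming INT8_CLIP_Q is a per-row quantile of |w| (the constant value and exact rule are assumptions here; copy them from quantize_float_array() rather than from this sketch):

```python
import mlx.core as mx

INT8_CLIP_Q = 0.9995  # placeholder; reuse the exporter's actual constant

def fake_quantize_per_row_clipped(w: mx.array) -> mx.array:
    w32 = w.astype(mx.float32)
    # Clip each row at the same quantile the exporter uses before deriving the scale.
    sorted_abs = mx.sort(mx.abs(w32), axis=1)
    idx = int(INT8_CLIP_Q * (w32.shape[1] - 1))
    clip_val = sorted_abs[:, idx:idx + 1]
    w_clipped = mx.clip(w32, -clip_val, clip_val)
    scale = mx.maximum(mx.max(mx.abs(w_clipped), axis=1, keepdims=True) / 127.0,
                       mx.array(1.0 / 127.0))
    w_q = mx.round(mx.clip(w_clipped / scale, -127, 127)) * scale
    # Same STE as the existing helper: forward uses w_q, gradients flow through w32.
    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
```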


# STE: forward uses w_q, backward gradient flows through w32
return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)


class CastedLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)

def __call__(self, x: mx.array) -> mx.array:
return x @ self.weight.astype(x.dtype).T
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
w = fake_quantize_per_row(self.weight) if training else self.weight
P2: Fake-quantize the tied embedding matrix too

The new QAT path only wraps CastedLinear.weight, but the tied tok_emb.weight is still used in full precision for both the input embedding lookup and the LM-head projection. quantize_state_dict_int8() will nevertheless export that tensor as int8 whenever it exceeds INT8_KEEP_FLOAT_MAX_NUMEL (the default 1024×512 table already does), so a large chunk of the deployed quantization error is never present in the training objective. In practice the post-export gap can stay dominated by the embedding table even though the linear layers were trained with QAT.
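If the tied table should be covered too, one possible shape (hypothetical helpers; `fake_quantize_per_row` is the function added in this diff, and the weight is the tied `tok_emb.weight`):

```python
import mlx.core as mx

def tied_embed(weight: mx.array, input_ids: mx.array, training: bool = False) -> mx.array:
    # Input-side lookup against the (optionally fake-quantized) tied table.
    w = fake_quantize_per_row(weight) if training else weight
    return w[input_ids]

def tied_lm_head(weight: mx.array, x: mx.array, training: bool = False) -> mx.array:
    # Output-side projection with the same (optionally fake-quantized) weights.
    w = fake_quantize_per_row(weight) if training else weight
    return x @ w.astype(x.dtype).T
```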


P2: Skip fake-quantizing matrices that stay float at export

For smaller architecture sweeps, this branch fake-quantizes every CastedLinear weight unconditionally, but quantize_state_dict_int8() only converts tensors larger than INT8_KEEP_FLOAT_MAX_NUMEL and writes smaller matrices to passthrough. In configurations like MODEL_DIM<=256, several attention/MLP weights fall under that cutoff, so QAT trains against noise that will never exist in the serialized model. That can hurt convergence and make QAT-vs-baseline comparisons misleading for the exact exploration runs this commit introduces.
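One way to express that gating, assuming INT8_KEEP_FLOAT_MAX_NUMEL can be imported from (or mirrored to match) the exporter:

```python
def maybe_fake_quantize(w: mx.array) -> mx.array:
    # Tensors the exporter keeps in float are left untouched; only weights that
    # will actually be serialized as int8 see simulated quantization noise.
    if w.size <= INT8_KEEP_FLOAT_MAX_NUMEL:
        return w
    return fake_quantize_per_row(w)
```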


return x @ w.astype(x.dtype).T


class RMSNormNoWeight(nn.Module):
@@ -320,18 +333,18 @@ def __init__(
self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
self.scale = self.head_dim ** -0.5

def __call__(self, x: mx.array) -> mx.array:
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
bsz, seqlen, dim = x.shape
q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
return self.proj(y)
return self.proj(y, training=training)


class MLP(nn.Module):
@@ -342,9 +355,9 @@ def __init__(self, dim: int, mlp_mult: int):
self.fc = CastedLinear(dim, hidden)
self.proj = CastedLinear(hidden, dim)

def __call__(self, x: mx.array) -> mx.array:
x = nn.relu(self.fc(x))
return self.proj(x * x)
def __call__(self, x: mx.array, training: bool = False) -> mx.array:
x = nn.relu(self.fc(x, training=training))
return self.proj(x * x, training=training)


class Block(nn.Module):
@@ -366,12 +379,12 @@ def __init__(
self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))

def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array:
mix = self.resid_mix.astype(x.dtype)
x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
attn_out = self.attn(self.attn_norm(x))
attn_out = self.attn(self.attn_norm(x), training=training)
x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training)
return x


@@ -411,27 +424,27 @@ def softcap(self, logits: mx.array) -> mx.array:
c = self.logit_softcap
return c * mx.tanh(logits / c)

def __call__(self, input_ids: mx.array) -> mx.array:
def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array:
x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
x0 = x
skips: list[mx.array] = []

for i in range(self.num_encoder_layers):
x = self.blocks[i](x, x0)
x = self.blocks[i](x, x0, training=training)
skips.append(x)
for i in range(self.num_decoder_layers):
# Odd layer counts have one more decoder block than encoder block. The baseline only
# applies a skip connection when one exists, then runs the remaining decoder block(s)
# without an added skip.
if skips:
x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
x = self.blocks[self.num_encoder_layers + i](x, x0)
x = self.blocks[self.num_encoder_layers + i](x, x0, training=training)
return self.final_norm(x)

def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
# Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
# memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
x = self(input_ids, training=True).reshape(-1, self.tok_emb.weight.shape[1])
P1: Split the fake-quant training loss from evaluation

Because compiled_loss is the function that eval_val() and the final final_int8_zlib_roundtrip check call, hard-wiring training=True here makes every validation pass run through fake_quantize_per_row() too. That means the logged val_loss/val_bpb no longer represent the full-precision model during training, and after reloading the dequantized int8 checkpoint the round-trip check fake-quantizes those weights a second time. If the intent is “train with QAT, evaluate without it”, this needs a separate eval path.
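One way to split the paths, as a sketch only: give loss() a training keyword (default False) and build two thin wrappers, so eval_val() and the round-trip check keep seeing full-precision weights. The wrapper names below are illustrative:

```python
import mlx.core as mx

# Assumes loss() is changed to accept `training` and forward it into self(...):
#     x = self(input_ids, training=training).reshape(...)
def make_loss_fns(model):
    def train_loss(input_ids: mx.array, target_ids: mx.array) -> mx.array:
        return model.loss(input_ids, target_ids, training=True)   # QAT path

    def eval_loss(input_ids: mx.array, target_ids: mx.array) -> mx.array:
        return model.loss(input_ids, target_ids, training=False)  # full precision

    return train_loss, eval_loss
```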


y = target_ids.reshape(-1)
if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T