QAT + Architecture Exploration (Non-Record) #20
New file, +21 lines:
# QAT + Architecture Exploration (Non-Record)

This is a non-record submission demonstrating quantization-aware training (QAT) applied to the baseline architecture.

## Approach

**Quantization-Aware Training:** We add simulated int8 per-row quantization to all `CastedLinear` layers during training using a straight-through estimator (STE). The model learns weights that are robust to int8 rounding, reducing the gap between pre-quantization and post-quantization val_bpb.
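As a toy illustration of what one weight row goes through (a sketch with made-up numbers, not values from any run):

```python
# One row with max |w| = 0.5 gets scale = 0.5 / 127 ~= 0.003937 (stored as fp16).
scale = 0.5 / 127.0
w = 0.2
w_q = round(w / scale) * scale  # round(50.8) = 51 -> 51 * scale ~= 0.200787
# The ~8e-4 round-trip error is the rounding noise the STE exposes the model to.
```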
The 4-hour baseline shows this quantization gap growing from 0.0072 BPB (10-minute run) to 0.0325 BPB (4-hour run), making QAT increasingly valuable as training quality improves.

**Implementation:** A `fake_quantize_per_row()` function simulates the exact quantization pipeline used in `quantize_state_dict_int8` -- per-row int8 with fp16 scales. A training flag is threaded through the model so evaluation uses unquantized weights; the call pattern is sketched below.
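For concreteness, the intended call pattern (a sketch based on this PR's diff, not text from the README):

```python
# Training: loss() calls self(input_ids, training=True), so every CastedLinear
# matmul sees fake-quantized weights and the STE passes gradients through.
loss = model.loss(input_ids, target_ids)

# Evaluation: training defaults to False, so val_bpb is measured on the raw
# fp32 weights (the "pre-quantization" side of the gap).
logits = model(input_ids)
```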
## Status

Work in progress. Local MLX smoke tests on Apple Silicon are too resource-intensive for meaningful validation -- training runs saturate compute on consumer hardware, which is precisely why H100 access is needed. Leaderboard metrics are pending H100 validation.

## Planned Extensions

- SEQUENCE parameter sharing (Takase & Kiyono 2023) for more effective depth
- BitNet b1.58 ternary exploration for a 5x parameter budget
- NorMuon optimizer, value embeddings, vocabulary tuning
New file, +9 lines:

```json
{
  "author": "Matt Leonard",
  "github_id": "mattleonard16",
  "name": "QAT + Architecture Exploration",
  "blurb": "Quantization-aware training with STE to close the int8 quantization gap. Non-record exploration submission pending H100 validation.",
  "date": "2026-03-18T00:00:00Z",
  "track": "non-record-exploration",
  "status": "in-progress"
}
```
Modified file:

```diff
@@ -273,13 +273,26 @@ def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]:
 # MODEL BLOCKS
 # ==============================================================================
 
+def fake_quantize_per_row(w: mx.array) -> mx.array:
+    """Simulate int8 per-row quantization during training (STE).
+    Forward: returns quantized-then-dequantized weights.
+    Backward: straight-through (gradient flows as if no quantization)."""
+    w32 = w.astype(mx.float32)
+    row_max = mx.max(mx.abs(w32), axis=1, keepdims=True)
+    scale = mx.maximum(row_max / 127.0, mx.array(1.0 / 127.0))
+    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
+    # STE: forward uses w_q, backward gradient flows through w32
+    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
+
+
 class CastedLinear(nn.Module):
     def __init__(self, in_dim: int, out_dim: int):
         super().__init__()
         self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32)
 
-    def __call__(self, x: mx.array) -> mx.array:
-        return x @ self.weight.astype(x.dtype).T
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+        w = fake_quantize_per_row(self.weight) if training else self.weight
+        return x @ w.astype(x.dtype).T
 
 
 class RMSNormNoWeight(nn.Module):
```

> Review comment (on the fake-quant branch in `CastedLinear.__call__`): The new QAT path only wraps …

> Review comment: For smaller architecture sweeps, this branch fake-quantizes every …
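A property check one could run against `fake_quantize_per_row` (a hypothetical snippet, not part of this PR): under the STE, gradients should flow exactly as if the function were the identity.

```python
import mlx.core as mx

w = mx.random.normal((4, 8))
# Forward uses the quantized weights, but mx.stop_gradient routes the backward
# pass through w itself, so d/dw sum(fake_quantize_per_row(w)) is all ones.
g = mx.grad(lambda w_: fake_quantize_per_row(w_).sum())(w)
assert mx.allclose(g, mx.ones_like(w))
```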
```diff
@@ -320,18 +333,18 @@ def __init__(
         self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base)
         self.scale = self.head_dim ** -0.5
 
-    def __call__(self, x: mx.array) -> mx.array:
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
         bsz, seqlen, dim = x.shape
-        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
-        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
-        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        q = self.c_q(x, training=training).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+        k = self.c_k(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        v = self.c_v(x, training=training).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
 
         q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE))
         k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE))
         q = q * self.q_gain.astype(q.dtype)[None, :, None, None]
         y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
         y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim)
-        return self.proj(y)
+        return self.proj(y, training=training)
 
 
 class MLP(nn.Module):
```
```diff
@@ -342,9 +355,9 @@ def __init__(self, dim: int, mlp_mult: int):
         self.fc = CastedLinear(dim, hidden)
         self.proj = CastedLinear(hidden, dim)
 
-    def __call__(self, x: mx.array) -> mx.array:
-        x = nn.relu(self.fc(x))
-        return self.proj(x * x)
+    def __call__(self, x: mx.array, training: bool = False) -> mx.array:
+        x = nn.relu(self.fc(x, training=training))
+        return self.proj(x * x, training=training)
 
 
 class Block(nn.Module):
```
```diff
@@ -366,12 +379,12 @@ def __init__(
         self.mlp_scale = mx.ones((dim,), dtype=mx.float32)
         self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32))))
 
-    def __call__(self, x: mx.array, x0: mx.array) -> mx.array:
+    def __call__(self, x: mx.array, x0: mx.array, training: bool = False) -> mx.array:
         mix = self.resid_mix.astype(x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
+        attn_out = self.attn(self.attn_norm(x), training=training)
         x = x + self.attn_scale.astype(x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        x = x + self.mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x), training=training)
         return x
```
```diff
@@ -411,27 +424,27 @@ def softcap(self, logits: mx.array) -> mx.array:
         c = self.logit_softcap
         return c * mx.tanh(logits / c)
 
-    def __call__(self, input_ids: mx.array) -> mx.array:
+    def __call__(self, input_ids: mx.array, training: bool = False) -> mx.array:
         x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
         x0 = x
         skips: list[mx.array] = []
 
         for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
+            x = self.blocks[i](x, x0, training=training)
             skips.append(x)
         for i in range(self.num_decoder_layers):
             # Odd layer counts have one more decoder block than encoder block. The baseline only
             # applies a skip connection when one exists, then runs the remaining decoder block(s)
             # without an added skip.
             if skips:
                 x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+            x = self.blocks[self.num_encoder_layers + i](x, x0, training=training)
         return self.final_norm(x)
 
     def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array:
         # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful
         # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE).
-        x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1])
+        x = self(input_ids, training=True).reshape(-1, self.tok_emb.weight.shape[1])
         y = target_ids.reshape(-1)
         if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens:
             logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T
```

> Review comment (on the `training=True` call in `loss`): Because …
> Review comment: `quantize_float_array()` clips each 2D row at `INT8_CLIP_Q` before computing its int8 scale, but this new training path derives `scale` from the absolute row max and never clips. When a row contains even one large outlier (a realistic case for embeddings and projection weights), QAT trains against a much coarser quantizer than `quantize_state_dict_int8()` will actually serialize, so the measured pre/post-quantization gap no longer reflects the deployed model. If this experiment is intended to validate deployment-aware training, the fake-quant path needs to mirror the exporter's clipping rule.
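A sketch of one way to mirror that rule (assumptions: `INT8_CLIP_Q` acts as a per-row quantile of |w|, and `clip_q` below is a hypothetical stand-in; the exporter's actual rule in `quantize_float_array()` should be checked before adopting this):

```python
def fake_quantize_per_row_clipped(w: mx.array, clip_q: float = 0.999) -> mx.array:
    """Like fake_quantize_per_row, but derives each row's scale from a high
    quantile of |w| rather than the raw max, so a lone outlier saturates to
    +/-127 instead of coarsening the whole row's quantization grid."""
    w32 = w.astype(mx.float32)
    a = mx.sort(mx.abs(w32), axis=1)                   # ascending |w| per row
    idx = min(int(clip_q * (a.shape[1] - 1)), a.shape[1] - 1)
    row_clip = a[:, idx:idx + 1]                       # per-row clip threshold
    scale = mx.maximum(row_clip / 127.0, mx.array(1.0 / 127.0))
    w_q = mx.round(mx.clip(w32 / scale, -127, 127)) * scale
    return (w32 + mx.stop_gradient(w_q - w32)).astype(w.dtype)
```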