QAT + Architecture Exploration (Non-Record) #20
Status: Closed

Commits (5):
- `14e54ee` feat: add quantization-aware training (QAT) with STE to MLX script (mattleonard16)
- `c879d0d` docs: add QAT exploration non-record submission (mattleonard16)
- `795412d` fix: split QAT train/eval paths and cover LM-head projection (mattleonard16)
- `4bcdea4` fix: add numel guards to all fake-quantization sites during QAT (mattleonard16)
- `bda4981` test: add coverage for numel-gated fake quantization (mattleonard16)
### records/track_non_record_16mb/2026-03-18_QAT_Exploration/README.md (21 additions, 0 deletions)
```markdown
# QAT + Architecture Exploration (Non-Record)

This is a non-record submission demonstrating quantization-aware training (QAT) applied to the baseline architecture.

## Approach

**Quantization-Aware Training:** We add simulated int8 per-row quantization to all `CastedLinear` layers during training using a straight-through estimator (STE). The model learns weights that are robust to int8 rounding, reducing the gap between pre-quantization and post-quantization val_bpb.

The 4-hour baseline shows this quantization gap grows from 0.0072 BPB (10-min run) to 0.0325 BPB (4-hour run), making QAT increasingly valuable as training quality improves.

**Implementation:** A `fake_quantize_per_row()` function simulates the exact quantization pipeline used in `quantize_state_dict_int8` -- per-row int8 with fp16 scales. The training flag is threaded through the model so evaluation uses unquantized weights.

## Status

Work in progress. Local MLX smoke tests on Apple Silicon are too resource-intensive for meaningful validation -- training runs saturate compute on consumer hardware, which is precisely why H100 access is needed. Pending H100 validation for leaderboard metrics.

## Planned Extensions

- SEQUENCE parameter sharing (Takase & Kiyono 2023) for more effective depth
- BitNet b1.58 ternary exploration for 5x parameter budget
- NorMuon optimizer, value embeddings, vocabulary tuning
```
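The per-row fake quantization the README describes can be sketched as follows. This is a NumPy illustration under assumptions, not the repository's MLX implementation: it assumes a symmetric int8 grid with an abs-max scale per row, and the STE shown in the comment is the standard trick for passing gradients through rounding, not code lifted from the PR.

```python
import numpy as np

def fake_quantize_per_row(w: np.ndarray) -> np.ndarray:
    """Simulate per-row int8 quantization: derive a scale from each row's
    absolute max, round onto the int8 grid, then dequantize back to float."""
    scale = np.abs(w).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0.0, 1.0, scale)   # guard all-zero rows
    q = np.clip(np.round(w / scale), -127, 127)  # int8 values
    return q * scale                             # dequantized weights

# In an actual training step, a straight-through estimator would pass
# gradients through the non-differentiable rounding, schematically:
#   w_ste = w + stop_gradient(fake_quantize_per_row(w) - w)
```

The round-trip error per entry is at most half the row's scale, which is what makes the quantization "visible" to the loss during training while keeping the forward pass close to the float weights.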
### records/track_non_record_16mb/2026-03-18_QAT_Exploration/submission.json (9 additions, 0 deletions)
```json
{
  "author": "Matt Leonard",
  "github_id": "mattleonard16",
  "name": "QAT + Architecture Exploration",
  "blurb": "Quantization-aware training with STE to close the int8 quantization gap. Non-record exploration submission pending H100 validation.",
  "date": "2026-03-18T00:00:00Z",
  "track": "non-record-exploration",
  "status": "in-progress"
}
```
### (test file; filename not captured in this view) (51 additions, 0 deletions)
```python
"""Tests for numel-gated fake quantization in QAT training."""
import mlx.core as mx

from train_gpt_mlx import (
    CastedLinear,
    INT8_KEEP_FLOAT_MAX_NUMEL,
    fake_quantize_per_row,
)


def test_fake_quantize_per_row_roundtrip():
    """fake_quantize_per_row returns a different tensor (not bitwise identical)."""
    w = mx.random.normal((128, 64))
    w_q = fake_quantize_per_row(w)
    mx.eval(w_q)
    assert w.shape == w_q.shape
    assert not mx.array_equal(w, w_q), "Quantized weight should differ from original"


def test_casted_linear_skips_small_weights():
    """CastedLinear with numel <= threshold should NOT fake-quantize during training."""
    # 64x64 = 4096 << 65536 threshold
    layer = CastedLinear(64, 64)
    assert layer.weight.size <= INT8_KEEP_FLOAT_MAX_NUMEL
    x = mx.random.normal((1, 4, 64))
    out_train = layer(x, training=True)
    out_eval = layer(x, training=False)
    mx.eval(out_train, out_eval)
    assert mx.allclose(out_train, out_eval, atol=1e-6), (
        "Small-weight CastedLinear should produce identical train/eval output"
    )


def test_casted_linear_quantizes_large_weights():
    """CastedLinear with numel > threshold SHOULD fake-quantize during training."""
    # 512x512 = 262144 > 65536 threshold
    layer = CastedLinear(512, 512)
    assert layer.weight.size > INT8_KEEP_FLOAT_MAX_NUMEL
    x = mx.random.normal((1, 4, 512))
    out_train = layer(x, training=True)
    out_eval = layer(x, training=False)
    mx.eval(out_train, out_eval)
    assert not mx.allclose(out_train, out_eval, atol=1e-6), (
        "Large-weight CastedLinear should produce different train/eval output"
    )


if __name__ == "__main__":
    test_fake_quantize_per_row_roundtrip()
    test_casted_linear_skips_small_weights()
    test_casted_linear_quantizes_large_weights()
    print("All tests passed")
```
**Review comment:** `quantize_float_array()` clips each 2D row at `INT8_CLIP_Q` before computing its int8 scale, but this new training path derives `scale` from the absolute row max and never clips. When a row contains even one large outlier (a realistic case for embeddings and projection weights), QAT trains against a much coarser quantizer than `quantize_state_dict_int8()` will actually serialize, so the measured pre/post-quantization gap no longer reflects the deployed model. If this experiment is intended to validate deployment-aware training, the fake-quant path needs to mirror the exporter's clipping rule.
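The fix the reviewer suggests could look like the sketch below. This is a hedged NumPy illustration, not the repository's code: it assumes `INT8_CLIP_Q` behaves as a per-row quantile of `|w|` at which values are clipped before the scale is derived (the actual rule lives in `quantize_float_array()`, which this would need to mirror exactly), and the `0.995` value is hypothetical.

```python
import numpy as np

INT8_CLIP_Q = 0.995  # hypothetical quantile; the real constant lives in the exporter

def fake_quantize_per_row_clipped(w: np.ndarray) -> np.ndarray:
    """Clip each row at the INT8_CLIP_Q quantile of |w| before computing the
    int8 scale, so a single outlier cannot coarsen the whole row's quantizer."""
    clip = np.quantile(np.abs(w), INT8_CLIP_Q, axis=-1, keepdims=True)
    clip = np.maximum(clip, 1e-12)          # guard all-zero rows
    w_clipped = np.clip(w, -clip, clip)     # outliers saturate at the clip point
    scale = clip / 127.0
    q = np.clip(np.round(w_clipped / scale), -127, 127)
    return q * scale
```

With the training-time quantizer clipping the same way the exporter does, the pre/post-quantization gap measured during training would track what `quantize_state_dict_int8()` actually serializes, which is the property the review comment says is currently lost.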