This document provides a detailed explanation of a complete implementation of a small transformer-based language model inspired by Qwen3, including custom attention, tokenizer integration, optimizer (Muon), and inference functions.
- Overview
- Core Concepts
- Code Modules
  - 1. Imports and Setup
  - 2. Utilities
  - 3. Configuration
  - 4. Optimizer - Muon
  - 5. Data Loading
  - 6. Dataset Class
  - 7. Rotary Position Embedding (RoPE)
  - 8. Attention (Grouped Query Attention)
  - 9. SwiGLU Feedforward Network
  - 10. Transformer Block
  - 11. Language Model Class
  - 12. Evaluation Function
  - 13. Optimizer Setup
  - 14. Training Loop
  - 15. Inference Functions
- Key Concepts Summary
## Overview

This project implements a mini Transformer-based language model inspired by Qwen3. It covers:
- Custom grouped-query attention with RoPE
- SwiGLU activation in feedforward
- Hybrid optimizer (Muon + AdamW)
- Tokenized dataset with HuggingFace
- Training and evaluation loops
- Text generation and demo
## Core Concepts

- Transformer Architecture: Uses multi-head attention, layer norm, and feedforward layers in blocks.
- Rotary Position Embedding (RoPE): Replaces traditional positional encoding.
- Grouped Query Attention (GQA): Optimizes memory/computation by grouping key-value heads.
- SwiGLU: Efficient activation function combining Swish and GLU.
- Muon Optimizer: A momentum-based optimizer that applies Newton-Schulz iterations to improve convergence.
## Code Modules

### 1. Imports and Setup

Standard imports, including PyTorch, HuggingFace's `datasets` and `transformers`, and `tqdm` for progress visualization.
### 2. Utilities

Ensures reproducibility by seeding all random number generators (Python, NumPy, Torch, CUDA).
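A minimal sketch of such a seeding utility (the name `set_seed` is illustrative; the real script also seeds the Torch and CUDA generators, shown here only as a comment so the sketch stays dependency-light):

```python
import os
import random

import numpy as np


def set_seed(seed: int = 42) -> None:
    """Seed the Python and NumPy RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # The actual training script additionally seeds Torch:
    # torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
```

Calling this once at startup makes data shuffling and weight initialization repeatable across runs.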
### 3. Configuration

Dataclass that stores model, training, and data parameters, including head counts, hidden sizes, and sequence lengths.
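A sketch of what such a config dataclass can look like; all field names and values here are illustrative, not taken from the source:

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Illustrative defaults; the actual script's values may differ.
    vocab_size: int = 32768
    d_model: int = 256
    n_heads: int = 8
    n_kv_heads: int = 2      # fewer KV heads than query heads -> GQA
    n_layers: int = 6
    max_seq_len: int = 512
    batch_size: int = 32
    learning_rate: float = 3e-4
```

A dataclass keeps every hyperparameter in one place, and individual fields can be overridden per experiment (e.g. `ModelConfig(n_layers=2)` for a quick smoke test).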
### 4. Optimizer - Muon

A helper performs approximate orthogonalization of gradient matrices via Newton-Schulz iteration, which stabilizes updates. The Muon optimizer itself combines Nesterov momentum with this orthogonalization step and is applied only to 2D parameters (weight matrices).
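The Newton-Schulz step can be sketched in NumPy as follows. The quintic coefficients below are the ones commonly used with Muon; the real code runs the same iteration on GPU tensors (typically in bfloat16), so treat this as a reference sketch rather than the script's exact implementation:

```python
import numpy as np


def newton_schulz_orthogonalize(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Push the singular values of G toward 1 (approximate orthogonalization)."""
    a, b, c = 3.4445, -4.7750, 2.0315   # quintic iteration coefficients
    X = G / (np.linalg.norm(G) + 1e-7)  # scale so singular values are <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # iterate on the "wide" orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X
```

After a few steps the singular values cluster near 1, so the update direction keeps each matrix's "shape" while equalizing its scale across directions.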
### 5. Data Loading

- Loads the dataset and tokenizer from HuggingFace.
- Tokenizes the text data and caches it.
- Stores the result as a `.pkl` file for quick future reloads.
### 6. Dataset Class

Builds training samples as sliding windows over the token stream, aligning inputs `x` with targets `y` (shifted by one position).
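The windowing logic reduces to a few lines; this is a plain-Python sketch (the function name `make_samples` is illustrative), whereas the real class wraps the same slicing in a `torch.utils.data.Dataset` that returns tensors:

```python
def make_samples(tokens, seq_len, stride):
    """Slice a flat token stream into (x, y) pairs, with y shifted one right."""
    samples = []
    # Stop early enough that y = tokens[start+1 : start+seq_len+1] stays in range.
    for start in range(0, len(tokens) - seq_len, stride):
        x = tokens[start : start + seq_len]
        y = tokens[start + 1 : start + seq_len + 1]
        samples.append((x, y))
    return samples
```

Because `y` is `x` shifted by one token, the model learns to predict the next token at every position in the window.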
### 7. Rotary Position Embedding (RoPE)

Generates sine/cosine embeddings for RoPE to allow extrapolation beyond training sequence lengths.
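Precomputing the tables amounts to one frequency per pair of channels; a NumPy sketch (the real code does the same with torch tensors, and `rope_tables` is an illustrative name):

```python
import numpy as np


def rope_tables(head_dim: int, max_len: int, base: float = 10000.0):
    """Precompute the cos/sin tables used to rotate Q and K in RoPE."""
    # One frequency per pair of channels, geometrically spaced.
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(np.arange(max_len), inv_freq)   # (max_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)
```

At inference time each (query, key) pair of channels is rotated by the angle for its position, so relative offsets are encoded by rotation differences rather than by learned absolute embeddings.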
### 8. Attention (Grouped Query Attention)

- Projects Q, K, and V from the input.
- Applies QK normalization and RoPE.
- Implements Grouped Query Attention, using `repeat_kv` to align key/value heads with query heads.
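The `repeat_kv` helper simply duplicates each KV head so the shapes line up with the query heads; a NumPy sketch of the same logic (the real code uses `torch.Tensor.expand`/`reshape` on GPU tensors):

```python
import numpy as np


def repeat_kv(kv: np.ndarray, n_rep: int) -> np.ndarray:
    """Duplicate each KV head n_rep times along the head axis.

    kv: (batch, n_kv_heads, seq, head_dim)
    ->  (batch, n_kv_heads * n_rep, seq, head_dim)
    """
    if n_rep == 1:
        return kv
    b, h, s, d = kv.shape
    # Insert a repeat axis, broadcast, then fold it into the head axis.
    kv = np.broadcast_to(kv[:, :, None], (b, h, n_rep, s, d))
    return kv.reshape(b, h * n_rep, s, d)
```

With, say, 8 query heads and 2 KV heads, each KV head is shared by 4 query heads, cutting the KV projection and cache size by 4x.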
### 9. SwiGLU Feedforward Network

Implements Swish + Gated Linear Units to enhance activation expressiveness.
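Functionally, SwiGLU is `down(silu(gate(x)) * up(x))`; a NumPy sketch with plain weight matrices (the module in the script uses three `nn.Linear` layers for `gate`, `up`, and `down`):

```python
import numpy as np


def silu(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feedforward: the gate branch modulates the up branch elementwise."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```

The gating lets the network suppress or pass each hidden channel per token, which in practice outperforms a plain ReLU/GELU MLP at equal parameter count.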
### 10. Transformer Block

- Combines `Qwen3Attention` and `SwiGLUFeedForward`
- Uses RMSNorm and residual connections

### 11. Language Model Class

- Embedding + positional dropout
- Stacks multiple Transformer blocks
- Ties weights between the input embedding and the output projection
### 12. Evaluation Function

Calculates cross-entropy loss, token-level accuracy, and perplexity on validation data.
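Perplexity follows directly from the averaged cross-entropy, so the evaluation loop only needs one extra line:

```python
import math


def perplexity(mean_ce_loss: float) -> float:
    """Perplexity is the exponential of the mean cross-entropy (in nats)."""
    return math.exp(mean_ce_loss)
```

A perplexity of 1.0 means the model is certain of every target token; a uniform guess over V tokens gives perplexity V.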
### 13. Optimizer Setup

Splits parameters between Muon (2D weight matrices) and AdamW (everything else), so each parameter type gets an update rule suited to its shape.
### 14. Training Loop

- Gradient accumulation
- Automatic Mixed Precision (AMP)
- Cosine LR scheduler with warmup
- Logging every 10 steps and evaluation every 500 steps
- Saves best and final model checkpoints
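The warmup-plus-cosine schedule above reduces to a small function of the step index; a sketch (the name `lr_at` and the exact warmup shape are illustrative, but this linear-warmup/cosine-decay form is the standard one):

```python
import math


def lr_at(step, max_steps, base_lr, warmup_steps):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Warmup avoids unstable early updates at full learning rate; the cosine tail anneals smoothly to zero by the final step.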
### 15. Inference Functions

- Samples tokens from the model using nucleus (top-p) sampling and top-k filtering.
- Provides a CLI interface for prompting the model.
- Runs fixed-prompt tests after training.
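The filtering step of the sampler can be sketched in NumPy for a single logit vector (the real code applies the same masking to torch tensors with `torch.topk` and a sort; `filter_logits` is an illustrative name):

```python
import numpy as np


def filter_logits(logits: np.ndarray, top_k: int = 50, top_p: float = 0.9) -> np.ndarray:
    """Mask logits outside the top-k set, then outside the nucleus (top-p) set."""
    out = logits.copy()
    # Top-k: drop everything below the k-th largest logit.
    if top_k > 0:
        kth = np.sort(out)[-top_k] if top_k <= out.size else out.min()
        out[out < kth] = -np.inf
    # Top-p: keep the smallest set of tokens whose probabilities reach top_p.
    order = np.argsort(out)[::-1]                 # token ids, most likely first
    probs = np.exp(out[order] - out[order][0])    # stable softmax numerators
    probs /= probs.sum()
    cum = np.cumsum(probs)
    cutoff = np.searchsorted(cum, top_p) + 1      # always keep at least one token
    out[order[cutoff:]] = -np.inf
    return out
```

Sampling then draws from the softmax of the surviving logits; masked tokens have probability zero, which keeps generations out of the low-probability tail.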
## Key Concepts Summary

| Concept | Description |
|---|---|
| RoPE | Position encoding method that uses trigonometric rotations |
| GQA | Grouped Query Attention reduces KV projections for efficiency |
| SwiGLU | Combines Swish activation with Gated Linear Units |
| Muon | Optimizer that improves convergence by using orthogonal gradients |
| Tokenizer | Uses HuggingFace AutoTokenizer for encoding/decoding |
| AMP | Mixed precision training for speed and memory efficiency |
| Grad Accumulation | Allows larger effective batch sizes without increasing memory usage |
- This model is not trained on a large corpus; results are meant for educational or experimental purposes.
- To avoid NaN loss issues, ensure that `vocab_size` covers the full range of target token ids, use `ignore_index=pad_token_id` in `CrossEntropyLoss`, and consider lowering the learning rate.