Add per-iteration LoRA adapters for depth recurrence #1

Closed

Meirzhan05 wants to merge 1 commit into main from worktree-agent-a211f0bb

Conversation


Meirzhan05 (Owner) commented Apr 24, 2026

Summary

  • Adds AttentionLoRA class providing low-rank corrections to attention Q/K/V/proj projections
  • Each extra loop iteration (blocks 3-5 looped via num_loops=2) gets unique LoRA adapters, recovering expressivity lost from weight sharing
  • Controlled by LORA_RANK env var (default 0 = disabled, zero overhead) and LORA_LR (default 0.01)
  • LoRA schedule is precomputed at init time; adapters use Adam optimizer (not Muon); params kept in fp32

Test plan

  • Syntax check passes (ast.parse)
  • Full training run on 8xH100 (not available locally, needs cluster)

Generated with Claude Code

Summary by CodeRabbit

  • New Features
    • Complete training and evaluation framework for transformer models with distributed support
    • Configurable model architectures with multiple attention mechanism options
    • Post-training quantization with compression for efficient storage and deployment
    • Advanced evaluation modes including test-time training capabilities

When LORA_RANK>0, each extra loop iteration through blocks 3-5 gets
unique low-rank corrections to attention Q/K/V/proj, recovering
expressivity lost from weight sharing. Zero overhead when disabled.
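
For readers unfamiliar with the pattern, here is a minimal sketch of how a per-visit adapter enters the attention projections. The (dim, rank) x (rank, out) shapes follow the AttentionLoRA class reviewed below, but the attn.qkv / attn.core / attn.proj helpers and the wiring are illustrative stand-ins, not the actual train_gpt_mtp.py code.

import torch

def lora_delta(x, A, B):
    # x: (batch, seq, dim); A: (dim, rank); B: (rank, out) -> low-rank correction
    return (x @ A) @ B

def attention_with_lora(x, attn, lora=None):
    # attn.* are hypothetical helpers standing in for the shared (weight-tied) block.
    q, k, v = attn.qkv(x)
    if lora is not None:  # only extra loop visits (LORA_RANK > 0) receive an adapter
        q = q + lora_delta(x, lora.q_A, lora.q_B)
        k = k + lora_delta(x, lora.k_A, lora.k_B)
        v = v + lora_delta(x, lora.v_A, lora.v_B)
    y = attn.core(q, k, v)
    out = attn.proj(y)
    if lora is not None:
        out = out + lora_delta(y, lora.proj_A, lora.proj_B)
    return out
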
coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough


This pull request introduces train_gpt_mtp.py, a comprehensive end-to-end training and evaluation script for a configurable GPT-style transformer model. The script implements distributed training with a custom Muon optimizer, GPTQ-style post-training quantization, and multiple evaluation modes including sliding-window and test-time training variants.

Changes

Training Pipeline & Model Architecture (train_gpt_mtp.py)
Adds complete GPT transformer implementation with RoPE positional encoding, flash-attention, optional xsa attention modification, block-parallel residual mixing, and per-loop attention LoRA adapters. Includes custom Muon optimizer combining AdamW and matrix parameter optimization with distributed gradient synchronization, gradient accumulation, and EMA weight averaging. Implements distributed training loop with warmup scheduling and optional wallclock-based early stopping.

Quantization & Serialization (train_gpt_mtp.py)
Implements GPTQ-style post-training quantization pipeline: collects Hessian Gram matrices via forward hooks, quantizes selected linear layers with per-row scales, mixes quantized and passthrough parameters, and compresses via byte-shuffle plus lzma/brotli codec. Includes dequantization logic for model loading and evaluation.

Evaluation Framework (train_gpt_mtp.py)
Adds three evaluation modes: fixed-length validation, sliding-window evaluation for longer sequences, and test-time training (TTT) over chunked windows. Supports multi-rank distributed evaluation with synchronized metrics reduction and optional logfile output.

Data Loading & Infrastructure (train_gpt_mtp.py)
Provides shuffled sequence loader with configurable batch generation, SentencePiece vocabulary utilities, validation data handling, and distributed-aware token loading. Includes hyperparameter management via environment variables and structured logging across all training/evaluation phases.
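
The Hessian-collection step described above follows the standard GPTQ recipe: a forward hook on each quantization-target nn.Linear accumulates the Gram matrix XᵀX of its inputs over calibration batches. A rough sketch of that pattern (the layer selection and exact bookkeeping in train_gpt_mtp.py may differ):

import torch
import torch.nn as nn

def collect_hessians(model, calibration_batches, target_layers):
    # target_layers: {name: nn.Linear} chosen for quantization (selection is repo-specific)
    hessians = {name: torch.zeros(m.in_features, m.in_features, dtype=torch.float32, device=m.weight.device)
                for name, m in target_layers.items()}
    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].reshape(-1, inputs[0].shape[-1]).float()  # (tokens, in_features)
            hessians[name] += x.t() @ x                             # accumulate Gram matrix X^T X
        return hook
    handles = [m.register_forward_hook(make_hook(name)) for name, m in target_layers.items()]
    model.eval()
    with torch.no_grad():
        for x in calibration_batches:
            model(x)
    for h in handles:
        h.remove()
    return hessians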

Sequence Diagram(s)

sequenceDiagram
    participant Main as train_and_eval()
    participant Load as Data Loading
    participant Build as Model Building
    participant Train as Training Loop
    participant Val as Validation
    participant Quant as GPTQ Quantization
    participant Eval as Evaluation

    Main->>Load: load_validation_tokens()
    Load-->>Main: val_data
    Main->>Build: GPT(h)
    Build-->>Main: model
    Main->>Train: train_model(h, val_data)
    Train->>Train: Forward pass, compute loss
    Train->>Train: Backward, optimizer step
    Train->>Val: eval_val() periodic
    Val-->>Train: validation loss
    Train-->>Main: trained model
    Main->>Quant: collect_hessians()
    Quant->>Quant: Forward hooks gather Gram matrices
    Quant-->>Main: hessian dict
    Main->>Quant: gptq_mixed_quantize()
    Quant->>Quant: Quantize weights, compress payload
    Quant-->>Main: quantized state_dict
    Main->>Eval: eval_val()
    Eval->>Eval: Fixed-length validation
    Eval-->>Main: final metrics
    Main->>Eval: eval_val_sliding()
    Eval->>Eval: Sliding window evaluation
    Eval-->>Main: sliding metrics
    Main->>Eval: eval_val_ttt()
    Eval->>Eval: Test-time training evaluation
    Eval-->>Main: TTT metrics

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 A transformer rises, layer by layer,
With ropes to rotate and LoRA to layer,
Muon steps wisely through matrix terrain,
GPTQ squeezes the weights, less brain,
Then eval dances in three different ways—
hop hop training's here to stay! 🚀

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Title check ⚠️ Warning
Explanation: The PR title describes a specific feature (per-iteration LoRA adapters for depth recurrence) that is a real component of the changeset, but the title does not capture the main point of this massive 534-line addition, which is a complete end-to-end training/evaluation script with many integrated features.
Resolution: The title should reflect the primary scope: adding a complete training script with configurable GPT transformer, distributed training, quantization, and evaluation capabilities. Consider "Add end-to-end GPT training and evaluation script with configurable transformer features" or similar.

Docstring Coverage ⚠️ Warning
Explanation: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%.
Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
Linked Issues check ✅ Passed: check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed: check skipped because no linked issues were found for this pull request.


coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
train_gpt_mtp.py (1)

143-165: LoRA adapter allocation & schedule — logic is correct; one subtlety worth a comment.

I traced the schedule with num_loops=2, loop_start=3, loop_end=5, num_layers=11:

  • all_indices = [0,1,2, 3,4,5, 3,4,5, 3,4,5, 6,7,8,9,10] (17 entries) — the loop segment appears num_loops+1 = 3 times, so the number of extra visits per looped block is num_loops = 2, matching len(self.lora_adapters) = h.num_loops. ✓
  • num_enc = 8, so encoder_indices = [0,1,2,3,4,5,3,4], decoder_indices = [5,3,4,5,6,7,8,9,10].
  • In _build_sched, block_iter is captured by closure and shared across the two calls, so visit counts carry over from encoder to decoder. This gives, e.g., block 5 → c=0 in encoder (None), c=1 in decoder (lora_adapters[0][2]), c=2 in decoder (lora_adapters[1][2]). Block 3 → None, then lora_adapters[0][0], then lora_adapters[1][0]. That matches the stated intent ("each extra iteration receives unique LoRA adapters"). ✓
  • Inner index bi - h.loop_start stays in [0, loop_end-loop_start], matching the inner ModuleList length. ✓

The cross-encoder/decoder closure sharing of block_iter is correct but very easy to misread as a bug during future refactors (e.g., someone moving block_iter={} inside _build_sched would silently break the "third visit uses adapter[1]" semantics for blocks whose first two visits straddle the split). A short comment would pay for itself:

📝 Suggested clarification (optional)
-		if self.lora_adapters is not None:
-			loop_range=set(range(h.loop_start,h.loop_end+1));block_iter={}
+		if self.lora_adapters is not None:
+			# block_iter is intentionally shared across the two _build_sched calls so that
+			# visit counts for looped blocks carry over from encoder to decoder.
+			# Visit 0 -> None (baseline shared weights), visit k>=1 -> lora_adapters[k-1][bi-loop_start].
+			loop_range=set(range(h.loop_start,h.loop_end+1));block_iter={}

Otherwise the construction looks good.
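
To make the trace above easy to re-derive, here is a standalone sketch of the same bookkeeping (num_loops=2, loop_start=3, loop_end=5, num_layers=11, num_enc=8). The names mirror the review trace, with adapters represented as (loop_index, inner_index) tuples; this is illustrative, not the repository code.

loop_start, loop_end, num_loops, num_layers, num_enc = 3, 5, 2, 11, 8
looped = list(range(loop_start, loop_end + 1))
all_indices = list(range(loop_start)) + looped * (num_loops + 1) + list(range(loop_end + 1, num_layers))
encoder_indices, decoder_indices = all_indices[:num_enc], all_indices[num_enc:]

loop_range = set(looped)
block_iter = {}  # shared across both calls, exactly as in the reviewed code

def build_sched(indices):
    sched = []
    for bi in indices:
        if bi not in loop_range:
            sched.append(None)
            continue
        c = block_iter.get(bi, 0)
        block_iter[bi] = c + 1
        # visit 0 -> None (shared weights only); visit k >= 1 -> adapter (k-1, bi - loop_start)
        sched.append(None if c == 0 else (c - 1, bi - loop_start))
    return sched

enc_sched = build_sched(encoder_indices)  # block 3's second visit -> (0, 0)
dec_sched = build_sched(decoder_indices)  # block 5's later visits -> (0, 2), then (1, 2)
print(enc_sched)
print(dec_sched)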

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@train_gpt_mtp.py` around lines 143 - 165, Add a short clarifying comment
where block_iter is declared (used by _build_sched) explaining that block_iter
is intentionally shared across the two calls to _build_sched so visit counts
persist between building encoder_lora_schedule and decoder_lora_schedule;
mention that this sharing is required so successive visits map to entries in
lora_adapters (and that moving block_iter inside _build_sched would break the
intended mapping between extra loop visits and lora_adapters indices).
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 93836c35-5819-40a0-bf34-063edfec8b6c

📥 Commits

Reviewing files that changed from the base of the PR and between ea6e666 and 0889c00.

📒 Files selected for processing (1)
  • train_gpt_mtp.py

Comment thread train_gpt_mtp.py
Comment on lines +108 to +115
class AttentionLoRA(nn.Module):
	def __init__(self,dim,kv_dim,rank):
		super().__init__()
		self.q_A=nn.Parameter(torch.empty(dim,rank));self.q_B=nn.Parameter(torch.zeros(rank,dim))
		self.k_A=nn.Parameter(torch.empty(dim,rank));self.k_B=nn.Parameter(torch.zeros(rank,kv_dim))
		self.v_A=nn.Parameter(torch.empty(dim,rank));self.v_B=nn.Parameter(torch.zeros(rank,kv_dim))
		self.proj_A=nn.Parameter(torch.empty(dim,rank));self.proj_B=nn.Parameter(torch.zeros(rank,dim))
		for n in('q_A','k_A','v_A','proj_A'):nn.init.kaiming_uniform_(getattr(self,n),a=math.sqrt(5))

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

In PyTorch, for a 2D tensor of shape (A, B) passed to nn.init.kaiming_uniform_ with mode='fan_in', which dimension is treated as fan_in — A (size(0)) or B (size(1))?

💡 Result:

In PyTorch, for a 2D tensor of shape (A, B) passed to nn.init.kaiming_uniform_ with mode='fan_in', the fan_in dimension is B, which is tensor.size(1).

🏁 Script executed:

# Check if train_gpt_mtp.py exists and examine lines 108-115
find . -name "train_gpt_mtp.py" -type f

Repository: Meirzhan05/parameter-golf

Length of output: 87


🏁 Script executed:

sed -n '108,115p' ./train_gpt_mtp.py

Repository: Meirzhan05/parameter-golf

Length of output: 634


Non-standard Kaiming fan_in for LoRA A matrices.

The A parameters are stored transposed relative to the original/PEFT LoRA convention: here q_A has shape (dim, rank), whereas Microsoft LoRA / HuggingFace PEFT store lora_A as (rank, in_features). nn.init.kaiming_uniform_ with the default mode='fan_in' reads tensor.size(1), so this code uses fan_in = rank (e.g. 8), not fan_in = dim (e.g. 512). The resulting init bound is roughly sqrt(dim/rank) ≈ 8× larger than the standard choice.

Because B is zero-initialized, the initial forward output is still zero, so this is not a correctness bug. However, the first gradients on B are proportional to A, so an oversized A makes the effective LoRA learning rate ~8× higher than the configured lora_lr for the first several steps. Worth confirming this is intentional; if not, the standard fix is either to shape A as (rank, dim) or to pass fan_in explicitly:

♻️ Proposed fix (match standard LoRA init scale)
-		for n in('q_A','k_A','v_A','proj_A'):nn.init.kaiming_uniform_(getattr(self,n),a=math.sqrt(5))
+		# A is (dim, rank); kaiming_uniform_ uses size(1) as fan_in, which would give fan_in=rank.
+		# Standard LoRA init uses fan_in=dim, so fill with U(-bound, bound), bound = sqrt(1/dim).
+		bound=math.sqrt(1./dim)
+		for n in('q_A','k_A','v_A','proj_A'):nn.init.uniform_(getattr(self,n),-bound,bound)
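
For scale, the difference in init magnitude is easy to check numerically; this is a standalone illustration of the fan_in point above, not repository code:

import math, torch, torch.nn as nn

dim, rank = 512, 8
A = torch.empty(dim, rank)
nn.init.kaiming_uniform_(A, a=math.sqrt(5))  # for a 2D tensor, fan_in = A.size(1) = rank

bound_here = 1 / math.sqrt(rank)   # ~0.354: the bound kaiming_uniform_(a=sqrt(5)) actually uses with these shapes
bound_peft = 1 / math.sqrt(dim)    # ~0.044: the bound with A stored as (rank, dim), PEFT-style
print(A.abs().max().item())        # close to bound_here
print(bound_here / bound_peft)     # sqrt(dim / rank) = 8
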
🧰 Tools
🪛 Ruff (0.15.11)

[error] 111-111: Multiple statements on one line (semicolon)

(E702)


[error] 112-112: Multiple statements on one line (semicolon)

(E702)


[error] 113-113: Multiple statements on one line (semicolon)

(E702)


[error] 114-114: Multiple statements on one line (semicolon)

(E702)


[error] 115-115: Multiple statements on one line (colon)

(E701)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@train_gpt_mtp.py` around lines 108 - 115, The LoRA "A" matrices (q_A, k_A,
v_A, proj_A) are shaped (dim, rank) causing kaiming_uniform_ to use fan_in=rank
instead of dim and produce an oversized init; change the A parameter shapes to
(rank, dim) to match standard LoRA/PEFT convention (and adjust any downstream
multiplications accordingly) so nn.init.kaiming_uniform_ computes fan_in=dim as
intended, or alternatively if you must keep (dim, rank) compute and pass the
correct fan_in explicitly when initializing those tensors; update q_A, k_A, v_A,
proj_A initialization to the chosen convention and ensure q_B/k_B/v_B/proj_B
shapes and any forward logic are adjusted to remain consistent.

Comment thread train_gpt_mtp.py
Comment on lines +309 to +311
	model.eval()
	with torch.no_grad():
		for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

Does torch.matmul on CUDA auto-promote when one operand is torch.bfloat16 and the other is torch.float32, or does it raise RuntimeError ("expected mat1 and mat2 to have the same dtype")?

💡 Result:

torch.matmul on CUDA raises RuntimeError ("expected mat1 and mat2 to have the same dtype") when one operand is torch.bfloat16 and the other is torch.float32. It does not auto-promote. PyTorch matrix multiplication operations, including torch.matmul, require the two input tensors (mat1 and mat2) to have the same dtype. This is enforced in the underlying cuBLAS GEMM calls and meta checks, as evidenced by multiple error reports and discussions. While PyTorch supports dtype promotion in elementwise arithmetic operations (add, mul, etc.) per tensor attribute docs, matrix multiplication does not auto-promote between bfloat16 and float32. Users must explicitly cast one tensor to match the other's dtype, e.g., a.to(b.dtype) @ b. This behavior is consistent across CUDA backends, with errors like "mat1 and mat2 must have the same dtype", "self and mat2 must have the same dtype, but got Half and Float", and explicit checks added in recent versions for compile-time detection. Official docs for torch.matmul do not mention promotion between these dtypes, and all examples use matching dtypes. Mixed dtype matmul requires manual handling, often via torch.amp autocast for training.

Calibration crashes when LoRA is active: bf16 × fp32 matmul raises dtype error.

When LORA_RANK>0 and num_loops>0, the end of training leaves base_model.looping_active=True (set at line 494). serialize then calls collect_hessians(base_model, …) which runs model.forward_logits(x) inside torch.no_grad() but without any torch.autocast. The model is bf16 (line 445), but restore_fp32_params keeps AttentionLoRA in fp32 (line 283), so at line 102 the code executes x @ lora.q_A with x in bf16 and lora.q_A in fp32. PyTorch's matmul raises RuntimeError: expected mat1 and mat2 to have the same dtype (BFloat16 vs Float) — aborting GPTQ. During training/eval this is masked because everything runs under torch.autocast(..., dtype=torch.bfloat16).

Two options (pick one):

🛡️ Option A — wrap calibration in bf16 autocast (minimal, matches other forward call sites)
-	model.eval()
-	with torch.no_grad():
-		for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
+	model.eval()
+	with torch.no_grad(),torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):
+		for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
🛡️ Option B — make LoRA dtype-safe like CastedLinear (future-proof)

Cast the fp32 LoRA tensors to x.dtype at the call site (same pattern CastedLinear already uses for its weights):

-		if lora is not None:q=q+(x@lora.q_A)@lora.q_B;k=k+(x@lora.k_A)@lora.k_B;v=v+(x@lora.v_A)@lora.v_B
+		if lora is not None:
+			qA=lora.q_A.to(x.dtype);qB=lora.q_B.to(x.dtype);kA=lora.k_A.to(x.dtype);kB=lora.k_B.to(x.dtype);vA=lora.v_A.to(x.dtype);vB=lora.v_B.to(x.dtype)
+			q=q+(x@qA)@qB;k=k+(x@kA)@kB;v=v+(x@vA)@vB
...
-		if lora is not None:out=out+(y@lora.proj_A)@lora.proj_B
+		if lora is not None:pA=lora.proj_A.to(y.dtype);pB=lora.proj_B.to(y.dtype);out=out+(y@pA)@pB
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-	model.eval()
-	with torch.no_grad():
-		for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
+	model.eval()
+	with torch.no_grad(),torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):
+		for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
🧰 Tools
🪛 Ruff (0.15.11)

[error] 311-311: Multiple statements on one line (colon)

(E701)


[error] 311-311: Multiple statements on one line (semicolon)

(E702)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@train_gpt_mtp.py` around lines 309 - 311, Calibration runs
model.forward_logits under torch.no_grad() without autocast, causing bf16 × fp32
matmul errors when AttentionLoRA (restore_fp32_params) keeps LoRA tensors in
fp32 (e.g., lora.q_A) while the model is bf16; fix by wrapping the calibration
block (the code around model.eval() / with torch.no_grad(): for _ in
range(n_calibration_batches): ... model.forward_logits(x)) in the same
torch.autocast context used elsewhere (torch.autocast(device_type=<appropriate
device type>, dtype=torch.bfloat16)) so inputs and fp32 LoRA tensors are cast
consistently during collect_hessians/serialize; alternatively if you prefer
Option B, change the AttentionLoRA call site (where x @ lora.q_A happens) to
cast lora tensors to x.dtype (same pattern as CastedLinear) so matmuls are
dtype-safe.

Meirzhan05 closed this Apr 24, 2026
Meirzhan05 added a commit that referenced this pull request Apr 28, 2026
Comparable in magnitude to recent merged record gaps:
- #1 -> openai#2: 0.0012 BPB
- openai#2 -> openai#3: 0.0006 BPB
- openai#3 -> openai#4: 0.0007 BPB