Add per-iteration LoRA adapters for depth recurrence #1
Meirzhan05 wants to merge 1 commit into main
Conversation
When LORA_RANK>0, each extra loop iteration through blocks 3-5 gets unique low-rank corrections to attention Q/K/V/proj, recovering expressivity lost from weight sharing. Zero overhead when disabled.
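As a rough sketch of the mechanism (illustrative only — the dimensions, the `TinyLoRA` class, and the loop structure below are stand-ins, not the PR's actual `AttentionLoRA` or block code): the shared weights are reused on every pass, and each extra pass adds its own zero-initialized low-rank delta.

```python
import torch
import torch.nn as nn

class TinyLoRA(nn.Module):
    """Illustrative low-rank correction: delta(x) = (x @ A) @ B, with B zero-initialized."""
    def __init__(self, dim: int, rank: int):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, dim))  # zero init -> no change before training
    def forward(self, x):
        return (x @ self.A) @ self.B

dim, rank, num_loops = 64, 4, 2
shared_block = nn.Linear(dim, dim, bias=False)                            # stands in for a weight-shared block
adapters = nn.ModuleList(TinyLoRA(dim, rank) for _ in range(num_loops))   # one adapter per extra iteration

h = torch.randn(8, dim)
h = h + shared_block(h)              # baseline pass: shared weights only
for lora in adapters:                # extra recurrence passes: shared weights + per-pass LoRA delta
    h = h + shared_block(h) + lora(h)
```

With `LORA_RANK=0` the adapter list would simply never be built, which is where the "zero overhead when disabled" claim comes from.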
📝 Walkthrough
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Main as train_and_eval()
participant Load as Data Loading
participant Build as Model Building
participant Train as Training Loop
participant Val as Validation
participant Quant as GPTQ Quantization
participant Eval as Evaluation
Main->>Load: load_validation_tokens()
Load-->>Main: val_data
Main->>Build: GPT(h)
Build-->>Main: model
Main->>Train: train_model(h, val_data)
Train->>Train: Forward pass, compute loss
Train->>Train: Backward, optimizer step
Train->>Val: eval_val() periodic
Val-->>Train: validation loss
Train-->>Main: trained model
Main->>Quant: collect_hessians()
Quant->>Quant: Forward hooks gather Gram matrices
Quant-->>Main: hessian dict
Main->>Quant: gptq_mixed_quantize()
Quant->>Quant: Quantize weights, compress payload
Quant-->>Main: quantized state_dict
Main->>Eval: eval_val()
Eval->>Eval: Fixed-length validation
Eval-->>Main: final metrics
Main->>Eval: eval_val_sliding()
Eval->>Eval: Sliding window evaluation
Eval-->>Main: sliding metrics
Main->>Eval: eval_val_ttt()
Eval->>Eval: Test-time training evaluation
    Eval-->>Main: TTT metrics
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 2
🧹 Nitpick comments (1)
train_gpt_mtp.py (1)
`143-165`: **LoRA adapter allocation & schedule — logic is correct; one subtlety worth a comment.**

I traced the schedule with `num_loops=2`, `loop_start=3`, `loop_end=5`, `num_layers=11`:

- `all_indices = [0,1,2, 3,4,5, 3,4,5, 3,4,5, 6,7,8,9,10]` (17 entries) — the loop segment appears `num_loops+1 = 3` times, so the number of extra visits per looped block is `num_loops = 2`, matching `len(self.lora_adapters) = h.num_loops`. ✓ `num_enc = 8`, so `encoder_indices = [0,1,2,3,4,5,3,4]`, `decoder_indices = [5,3,4,5,6,7,8,9,10]`.
- In `_build_sched`, `block_iter` is captured by closure and shared across the two calls, so visit counts carry over from encoder to decoder. This gives, e.g., block 5 → `c=0` in encoder (None), `c=1` in decoder (`lora_adapters[0][2]`), `c=2` in decoder (`lora_adapters[1][2]`). Block 3 → None, then `lora_adapters[0][0]`, then `lora_adapters[1][0]`. That matches the stated intent ("each extra iteration receives unique LoRA adapters"). ✓
- Inner index `bi - h.loop_start` stays in `[0, loop_end-loop_start]`, matching the inner ModuleList length. ✓

The cross-encoder/decoder closure sharing of `block_iter` is correct but very easy to misread as a bug during future refactors (e.g., someone moving `block_iter={}` inside `_build_sched` would silently break the "third visit uses adapter[1]" semantics for blocks whose first two visits straddle the split). A short comment would pay for itself:

📝 Suggested clarification (optional)

```diff
-        if self.lora_adapters is not None:
-            loop_range=set(range(h.loop_start,h.loop_end+1));block_iter={}
+        if self.lora_adapters is not None:
+            # block_iter is intentionally shared across the two _build_sched calls so that
+            # visit counts for looped blocks carry over from encoder to decoder.
+            # Visit 0 -> None (baseline shared weights), visit k>=1 -> lora_adapters[k-1][bi-loop_start].
+            loop_range=set(range(h.loop_start,h.loop_end+1));block_iter={}
```

Otherwise the construction looks good.
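To make the shared `block_iter` behaviour concrete, here is a standalone re-creation of the visit-counting logic as traced in this comment. The real `_build_sched` is a closure inside the model and returns adapter modules rather than strings; this sketch only mirrors the indexing.

```python
num_loops, loop_start, loop_end = 2, 3, 5
loop_range = set(range(loop_start, loop_end + 1))          # blocks 3-5 are the looped ones

# Layer visit order from the trace above: prefix, num_loops+1 passes over the looped blocks, suffix.
all_indices = [0, 1, 2] + [3, 4, 5] * (num_loops + 1) + [6, 7, 8, 9, 10]
num_enc = 8
encoder_indices, decoder_indices = all_indices[:num_enc], all_indices[num_enc:]

# Placeholder strings stand in for the real modules in lora_adapters[k][bi - loop_start].
lora_adapters = [[f"lora_adapters[{k}][{j}]" for j in range(loop_end - loop_start + 1)]
                 for k in range(num_loops)]

block_iter = {}  # deliberately shared across both calls: visit counts carry over encoder -> decoder

def build_sched(indices):
    sched = []
    for bi in indices:
        if bi not in loop_range:
            sched.append(None)            # non-looped block: never receives an adapter
            continue
        c = block_iter.get(bi, 0)         # how many times this looped block has been visited so far
        block_iter[bi] = c + 1
        # First visit -> shared weights only (None); visit k>=1 -> k-th extra-iteration adapter.
        sched.append(None if c == 0 else lora_adapters[c - 1][bi - loop_start])
    return sched

enc_sched = build_sched(encoder_indices)
dec_sched = build_sched(decoder_indices)
print(enc_sched)  # [None]*6 + ['lora_adapters[0][0]', 'lora_adapters[0][1]']
print(dec_sched)  # ['lora_adapters[0][2]', 'lora_adapters[1][0]', 'lora_adapters[1][1]', 'lora_adapters[1][2]', None, None, None, None, None]
```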
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@train_gpt_mtp.py` around lines 143 - 165, Add a short clarifying comment where block_iter is declared (used by _build_sched) explaining that block_iter is intentionally shared across the two calls to _build_sched so visit counts persist between building encoder_lora_schedule and decoder_lora_schedule; mention that this sharing is required so successive visits map to entries in lora_adapters (and that moving block_iter inside _build_sched would break the intended mapping between extra loop visits and lora_adapters indices).
```python
class AttentionLoRA(nn.Module):
    def __init__(self,dim,kv_dim,rank):
        super().__init__()
        self.q_A=nn.Parameter(torch.empty(dim,rank));self.q_B=nn.Parameter(torch.zeros(rank,dim))
        self.k_A=nn.Parameter(torch.empty(dim,rank));self.k_B=nn.Parameter(torch.zeros(rank,kv_dim))
        self.v_A=nn.Parameter(torch.empty(dim,rank));self.v_B=nn.Parameter(torch.zeros(rank,kv_dim))
        self.proj_A=nn.Parameter(torch.empty(dim,rank));self.proj_B=nn.Parameter(torch.zeros(rank,dim))
        for n in('q_A','k_A','v_A','proj_A'):nn.init.kaiming_uniform_(getattr(self,n),a=math.sqrt(5))
```
🧩 Analysis chain
🌐 Web query:
In PyTorch, for a 2D tensor of shape (A, B) passed to nn.init.kaiming_uniform_ with mode='fan_in', which dimension is treated as fan_in — A (size(0)) or B (size(1))?
💡 Result:
In PyTorch, for a 2D tensor of shape (A, B) passed to nn.init.kaiming_uniform_ with mode='fan_in', the fan_in dimension is B, which is tensor.size(1).
Citations:
- 1: https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
- 2: https://pytorch.org/docs/2.5/_modules/torch/nn/init.html
- 3: a trivial question in torch.nn.init._calculate_fan_in_and_fan_out pytorch/pytorch#29089
🏁 Script executed:

```shell
# Check if train_gpt_mtp.py exists and examine lines 108-115
find . -name "train_gpt_mtp.py" -type f
```

Repository: Meirzhan05/parameter-golf
Length of output: 87

🏁 Script executed:

```shell
sed -n '108,115p' ./train_gpt_mtp.py
```

Repository: Meirzhan05/parameter-golf
Length of output: 634
Non-standard Kaiming fan_in for LoRA A matrices.
The A parameters are stored transposed relative to the original/PEFT LoRA convention: here q_A has shape (dim, rank), whereas Microsoft LoRA / HuggingFace PEFT store lora_A as (rank, in_features). nn.init.kaiming_uniform_ with the default mode='fan_in' reads tensor.size(1), so this code uses fan_in = rank (e.g. 8), not fan_in = dim (e.g. 512). The resulting init bound is roughly sqrt(dim/rank) ≈ 8× larger than the standard choice.
Because B is zero-initialized, the initial forward output is still zero, so this is not a correctness bug. However, the first gradients on B are proportional to A, so an oversized A makes the effective LoRA learning rate ~8× higher than the configured lora_lr for the first several steps. Worth confirming this is intentional; if not, the standard fix is either to shape A as (rank, dim) or to pass fan_in explicitly:
♻️ Proposed fix (match standard LoRA init scale)

```diff
-        for n in('q_A','k_A','v_A','proj_A'):nn.init.kaiming_uniform_(getattr(self,n),a=math.sqrt(5))
+        # A is (dim, rank); kaiming_uniform_ uses size(1) as fan_in, which would give fan_in=rank.
+        # Standard LoRA init uses fan_in=dim, so fill with U(-bound, bound), bound = sqrt(1/dim).
+        bound=math.sqrt(1./dim)
+        for n in('q_A','k_A','v_A','proj_A'):nn.init.uniform_(getattr(self,n),-bound,bound)
```
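The fan_in behaviour and the resulting scale gap are easy to check directly; a small standalone verification (512 and 8 are just example sizes, not necessarily the PR's):

```python
import math
import torch
import torch.nn as nn

dim, rank = 512, 8
A = torch.empty(dim, rank)                      # stored (dim, rank), as in the PR
nn.init.kaiming_uniform_(A, a=math.sqrt(5))     # default mode='fan_in' reads size(1) = rank

# With a=sqrt(5), the uniform bound works out to sqrt(1/fan_in).
print(round(A.abs().max().item(), 3))           # close to sqrt(1/rank)
print(round(math.sqrt(1.0 / rank), 3))          # 0.354 -> bound actually used here
print(round(math.sqrt(1.0 / dim), 3))           # 0.044 -> bound standard LoRA (fan_in=dim) would use
# Ratio of the two bounds is sqrt(dim/rank) = 8 for these sizes, matching the ~8x noted above.
```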
🧰 Tools

🪛 Ruff (0.15.11)
[error] 111-111: Multiple statements on one line (semicolon)
(E702)
[error] 112-112: Multiple statements on one line (semicolon)
(E702)
[error] 113-113: Multiple statements on one line (semicolon)
(E702)
[error] 114-114: Multiple statements on one line (semicolon)
(E702)
[error] 115-115: Multiple statements on one line (colon)
(E701)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@train_gpt_mtp.py` around lines 108 - 115, The LoRA "A" matrices (q_A, k_A,
v_A, proj_A) are shaped (dim, rank) causing kaiming_uniform_ to use fan_in=rank
instead of dim and produce an oversized init; change the A parameter shapes to
(rank, dim) to match standard LoRA/PEFT convention (and adjust any downstream
multiplications accordingly) so nn.init.kaiming_uniform_ computes fan_in=dim as
intended, or alternatively if you must keep (dim, rank) compute and pass the
correct fan_in explicitly when initializing those tensors; update q_A, k_A, v_A,
proj_A initialization to the chosen convention and ensure q_B/k_B/v_B/proj_B
shapes and any forward logic are adjusted to remain consistent.
```python
    model.eval()
    with torch.no_grad():
        for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
```
🧩 Analysis chain
🌐 Web query:
Does torch.matmul on CUDA auto-promote when one operand is torch.bfloat16 and the other is torch.float32, or does it raise RuntimeError ("expected mat1 and mat2 to have the same dtype")?
💡 Result:
torch.matmul on CUDA raises RuntimeError ("expected mat1 and mat2 to have the same dtype") when one operand is torch.bfloat16 and the other is torch.float32. It does not auto-promote. PyTorch matrix multiplication operations, including torch.matmul, require the two input tensors (mat1 and mat2) to have the same dtype. This is enforced in the underlying cuBLAS GEMM calls and meta checks, as evidenced by multiple error reports and discussions. While PyTorch supports dtype promotion in elementwise arithmetic operations (add, mul, etc.) per tensor attribute docs, matrix multiplication does not auto-promote between bfloat16 and float32. Users must explicitly cast one tensor to match the other's dtype, e.g., a.to(b.dtype) @ b. This behavior is consistent across CUDA backends, with errors like "mat1 and mat2 must have the same dtype", "self and mat2 must have the same dtype, but got Half and Float", and explicit checks added in recent versions for compile-time detection. Official docs for torch.matmul do not mention promotion between these dtypes, and all examples use matching dtypes. Mixed dtype matmul requires manual handling, often via torch.amp autocast for training.
Citations:
- 1: https://discuss.pytorch.org/t/runtimeerror-mat1-and-mat2-must-have-the-same-dtype/166759
- 2: Add dtype mismatch check for mm/bmm meta tensors pytorch/pytorch#172194
- 3: https://www.codegenes.net/blog/pytorch-mat1-and-mat2-must-have-the-same-dtype/
- 4: torch.compile fails with dtype mismatch (Half vs Float) in matrix multiplication operations pytorch/pytorch#166647
- 5: https://pytorch.org/docs/1.10/tensor_attributes.html
Calibration crashes when LoRA is active: bf16 × fp32 matmul raises dtype error.
When LORA_RANK>0 and num_loops>0, the end of training leaves base_model.looping_active=True (set at line 494). serialize then calls collect_hessians(base_model, …) which runs model.forward_logits(x) inside torch.no_grad() but without any torch.autocast. The model is bf16 (line 445), but restore_fp32_params keeps AttentionLoRA in fp32 (line 283), so at line 102 the code executes x @ lora.q_A with x in bf16 and lora.q_A in fp32. PyTorch's matmul raises RuntimeError: expected mat1 and mat2 to have the same dtype (BFloat16 vs Float) — aborting GPTQ. During training/eval this is masked because everything runs under torch.autocast(..., dtype=torch.bfloat16).
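A minimal standalone reproduction of the failure mode (shapes and names are illustrative; the real crash path goes through `forward_logits` during calibration): matmul rejects the mixed dtypes outside autocast, and the same product works once the fp32 factor is cast or autocast is active.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 16, device=device, dtype=torch.bfloat16)   # bf16 activations (like the bf16 model)
A = torch.randn(16, 8, device=device, dtype=torch.float32)    # fp32 factor (like lora.q_A after restore_fp32_params)

try:
    _ = x @ A                                   # no autocast: matmul does not auto-promote dtypes
except RuntimeError as e:
    print("raises:", e)

_ = x @ A.to(x.dtype)                           # Option B flavour: cast the fp32 factor explicitly
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    _ = x @ A                                   # Option A flavour: autocast casts matmul inputs for you
```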
Two options (pick one):
🛡️ Option A — wrap calibration in bf16 autocast (minimal, matches other forward call sites)
```diff
-    model.eval()
-    with torch.no_grad():
-        for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
+    model.eval()
+    with torch.no_grad(),torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):
+        for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
```

🛡️ Option B — make LoRA dtype-safe like CastedLinear (future-proof)
Cast the fp32 LoRA tensors to x.dtype at the call site (same pattern CastedLinear already uses for its weights):
```diff
-        if lora is not None:q=q+(x@lora.q_A)@lora.q_B;k=k+(x@lora.k_A)@lora.k_B;v=v+(x@lora.v_A)@lora.v_B
+        if lora is not None:
+            qA=lora.q_A.to(x.dtype);qB=lora.q_B.to(x.dtype);kA=lora.k_A.to(x.dtype);kB=lora.k_B.to(x.dtype);vA=lora.v_A.to(x.dtype);vB=lora.v_B.to(x.dtype)
+            q=q+(x@qA)@qB;k=k+(x@kA)@kB;v=v+(x@vA)@vB
         ...
-        if lora is not None:out=out+(y@lora.proj_A)@lora.proj_B
+        if lora is not None:pA=lora.proj_A.to(y.dtype);pB=lora.proj_B.to(y.dtype);out=out+(y@pA)@pB
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-    model.eval()
-    with torch.no_grad():
-        for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
+    model.eval()
+    with torch.no_grad(),torch.autocast(device_type='cuda',dtype=torch.bfloat16,enabled=True):
+        for _ in range(n_calibration_batches):x,_=train_loader.next_batch(h.train_batch_tokens,h.grad_accum_steps);model.forward_logits(x)
```
🧰 Tools
🪛 Ruff (0.15.11)
[error] 311-311: Multiple statements on one line (colon)
(E701)
[error] 311-311: Multiple statements on one line (semicolon)
(E702)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@train_gpt_mtp.py` around lines 309 - 311, Calibration runs
model.forward_logits under torch.no_grad() without autocast, causing bf16 × fp32
matmul errors when AttentionLoRA (restore_fp32_params) keeps LoRA tensors in
fp32 (e.g., lora.q_A) while the model is bf16; fix by wrapping the calibration
block (the code around model.eval() / with torch.no_grad(): for _ in
range(n_calibration_batches): ... model.forward_logits(x)) in the same
torch.autocast context used elsewhere (torch.autocast(device_type=<appropriate
device type>, dtype=torch.bfloat16)) so inputs and fp32 LoRA tensors are cast
consistently during collect_hessians/serialize; alternatively if you prefer
Option B, change the AttentionLoRA call site (where x @ lora.q_A happens) to
cast lora tensors to x.dtype (same pattern as CastedLinear) so matmuls are
dtype-safe.
Summary
- `AttentionLoRA` class providing low-rank corrections to attention Q/K/V/proj projections
- Each extra loop iteration (`num_loops=2`) gets unique LoRA adapters, recovering expressivity lost from weight sharing
- `LORA_RANK` env var (default 0 = disabled, zero overhead) and `LORA_LR` (default 0.01)

Test plan
- Syntax check (`ast.parse`)

Generated with Claude Code
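As a rough sketch of the env-var gating described in the summary (the actual hyperparameter plumbing and adapter shapes in train_gpt_mtp.py may differ; `nn.Linear` stands in for the real `AttentionLoRA` modules):

```python
import os
import torch.nn as nn

lora_rank = int(os.environ.get("LORA_RANK", "0"))    # default 0 = disabled
lora_lr = float(os.environ.get("LORA_LR", "0.01"))   # default 0.01

dim, num_loops, n_looped_blocks = 512, 2, 3
lora_adapters = None
if lora_rank > 0:
    # One adapter per extra loop iteration per looped block; nothing is allocated when rank is 0,
    # which is what "zero overhead when disabled" refers to.
    lora_adapters = nn.ModuleList(
        nn.ModuleList(nn.Linear(dim, lora_rank, bias=False) for _ in range(n_looped_blocks))
        for _ in range(num_loops)
    )

# LoRA parameters would typically get their own optimizer group so LORA_LR applies only to them.
param_groups = []
if lora_adapters is not None:
    param_groups.append({"params": list(lora_adapters.parameters()), "lr": lora_lr})
```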