Skip to content

Commit 2227ac4

Browse files
mzhong4claude
and committed
experiment: 3x attn balance weight + bal_loss=0.5 to fix expert collapse
Attention expert collapses to 1 expert ([0.029, 0.052, 0.413, 0.506]). Fix: 3x stronger balance loss for the attention router specifically, plus the overall balance weight increased from 0.1 to 0.5.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 35ebd5e commit 2227ac4

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

train_gpt.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,11 +1303,15 @@ def _collect_routing_losses(self, device: torch.device) -> tuple[Tensor, Tensor,
13031303
"""Collect balance, sparsity, and orthogonality losses from all routers."""
13041304
zero = torch.tensor(0.0, device=device)
13051305
bal, spar, ortho = zero, zero, zero
1306-
# All SoftDenseRouters: attn + mlp
1307-
routers = [self.shared_block.attn.attn_router, self.shared_block.mlp.mlp_router]
1308-
for r in routers:
1309-
bal = bal + getattr(r, '_balance_loss', zero)
1310-
spar = spar + getattr(r, '_sparsity_loss', zero)
1306+
# Per-component routing losses with stronger weight for attention (prevents collapse)
1307+
for name, r, bal_weight in [
1308+
("attn", self.shared_block.attn.attn_router, 3.0), # 3x weight for attention
1309+
("mlp", self.shared_block.mlp.mlp_router, 1.0),
1310+
]:
1311+
r_bal = getattr(r, '_balance_loss', zero)
1312+
r_spar = getattr(r, '_sparsity_loss', zero)
1313+
bal = bal + bal_weight * r_bal
1314+
spar = spar + r_spar
13111315
# MoS head routing
13121316
bal = bal + getattr(self.mos_head, '_balance_loss', zero)
13131317
spar = spar + getattr(self.mos_head, '_sparsity_loss', zero)
@@ -1343,7 +1347,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
13431347
# CTP weight scales with refinement steps: at step 0 input is clean one-hot,
13441348
# CTP becomes meaningful only after soft embedding refinement
13451349
ctp_weight = 0.1 * self.num_refinements
1346-
return ntp_loss + ctp_weight * ctp_loss + 0.01 * conv_loss + 0.1 * bal_loss + 0.001 * spar_loss + 0.01 * ortho_loss
1350+
return ntp_loss + ctp_weight * ctp_loss + 0.01 * conv_loss + 0.5 * bal_loss + 0.001 * spar_loss + 0.01 * ortho_loss
13471351

13481352
def forward_logits(self, input_ids: Tensor) -> Tensor:
13491353
x = self._encode(input_ids)

0 commit comments

Comments (0)