
Commit a1492e9

mzhong4claude and Claude Opus 4.6 committed
feat: full-dim low-rank experts + MoS pure softmax routing
- Replace split-dim experts (dim//E) with full-dim low-rank experts (dim→rank→dim) so every expert sees all dimensions through a rank bottleneck
- MoS routing: pure softmax convex combination (Mixtape paper); removed sigmoid gates (expert_gate_ctp/ntp_logits)
- Added configurable attn_expert_rank / mlp_expert_rank hyperparameters
- Added MoS eval diagnostics: usage/entropy/balance_cv for CTP+NTP
- Updated metrics plot: expert usage shows min/max/mean/median per component (Attn, MLP, MoS CTP, MoS NTP) for scalability
- Updated CLAUDE.md Constraint #2 with full-dim + Mixtape clarifications
- Result: val_bpb=1.4094, attn_cv 0.32→0.22, artifact 15.4MB

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2227ac4 commit a1492e9

5 files changed

Lines changed: 198 additions & 109 deletions
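To make the commit's core change concrete: every expert now maps the full hidden dimension through a dim→rank→dim bottleneck, and the MoS head mixes experts with a plain softmax (a convex combination summing to 1). Below is a minimal sketch of that pattern; the class and parameter names are illustrative, not the repo's actual implementation.

```python
# Sketch only: full-dim low-rank experts + pure softmax mixing.
# Names (LowRankExpertBank, down/up/router) are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankExpertBank(nn.Module):
    def __init__(self, dim: int, rank: int, num_experts: int):
        super().__init__()
        # One (dim -> rank) and (rank -> dim) matrix per expert; parameter
        # count is controlled by rank instead of partitioning dim across experts.
        self.down = nn.Parameter(torch.randn(num_experts, dim, rank) * dim ** -0.5)
        self.up = nn.Parameter(torch.randn(num_experts, rank, dim) * rank ** -0.5)
        self.router = nn.Linear(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, dim]; the router sees the component INPUT.
        w = F.softmax(self.router(x), dim=-1)            # [B, S, E], sums to 1
        h = torch.einsum("bsd,edr->bser", x, self.down)  # low-rank bottleneck
        y = torch.einsum("bser,erd->bsed", h, self.up)   # back to full dim
        return torch.einsum("bse,bsed->bsd", w, y)       # convex combination

out = LowRankExpertBank(dim=256, rank=32, num_experts=4)(torch.randn(2, 8, 256))
print(out.shape)  # torch.Size([2, 8, 256])
```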


CLAUDE.md

Lines changed: 9 additions & 2 deletions
```diff
@@ -197,11 +197,18 @@ When proposing architecture improvements:
 ### 2. Soft Dense Routing (Dense MoE on ALL components)
 - Paper: Soft MoE (arxiv:2308.00951) — adapted for dense routing
 - ALL experts process ALL tokens — no top-k selection, no token dropping
-- Routing weights via softmax + input-dependent sigmoid gate
+- **Router routes on component INPUT** (pre-computation), consistent across all components
+- **Full-dim low-rank experts**: every expert operates on the FULL model hidden dimension.
+  Use low-rank matrices (dim→rank→dim) to control parameter count.
+  Do NOT partition dimensions across experts (no `expert_size = dim // num_experts`).
+- **Attn/MLP routing**: softmax + per-expert sigmoid gate (SoftDenseRouter)
+- **MoS routing (exception)**: pure softmax only (convex combination summing to 1), NO sigmoid gates.
+  Per Mixtape paper ("Breaking the Softmax Bottleneck Efficiently", NeurIPS 2019).
+  The softmax bottleneck is broken by the mixture of softmaxes itself, not by gating.
 - **Applied to ALL components**: attention output, MLP hidden, MoS output heads
 - **Regularization** (per-token sparsity + global balance + orthogonality):
   - **Per-token sparsity**: L1 on routing weights (each token concentrates on fewer experts)
-  - **Global balance**: MSE between mean expert usage and uniform target
+  - **Global balance**: MSE between mean expert usage and uniform target (per-component)
   - **Expert orthogonality**: |cos_sim| between expert weight groups → 0 (not ±1)
 - Fully differentiable, no discrete routing decisions
```
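The constraint distinguishes two routing rules. Here is a minimal sketch of the difference, assuming a router that emits per-expert logits of shape [batch, seq, num_experts]; SoftDenseRouter is the name used in the doc, and the function bodies below are illustrative rather than the repo's code.

```python
# Sketch of the two routing rules named in Constraint #2 (illustrative).
import torch
import torch.nn.functional as F

def attn_mlp_routing(route_logits, gate_logits):
    # SoftDenseRouter style: softmax mixture rescaled by an input-dependent
    # per-expert sigmoid gate; the weights need not sum to 1.
    return F.softmax(route_logits, dim=-1) * torch.sigmoid(gate_logits)

def mos_routing(route_logits):
    # MoS exception: pure softmax, a convex combination summing to 1.
    # The softmax bottleneck is broken by mixing softmaxes, not by gating.
    return F.softmax(route_logits, dim=-1)

logits = torch.randn(2, 8, 4)
print(attn_mlp_routing(logits, torch.randn(2, 8, 4)).sum(-1)[0, 0])  # != 1 in general
print(mos_routing(logits).sum(-1)[0, 0])                             # == 1
```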

experiments/plot_metrics.py

Lines changed: 113 additions & 66 deletions
```diff
@@ -4,8 +4,12 @@
 - Row 1: Train Loss, Val BPB, Step Avg (ms)
 - Row 2: NTP Loss, CTP Loss, Pre-clip Grad Norm
 - Row 3: DEQ Residual, DEQ Recon Error, DEQ Iter Convergence
-- Row 4: Load Balance (per component per expert), Expert Entropy, Expert Orthogonality
-- Row 5: Summary text with final values comparison
+- Row 4: Expert Usage (min/max/mean/median per component), Expert Entropy, Expert Orthogonality
+- Row 5: Balance Loss (per component), Conv Loss, (spare)
+- Row 6: Summary text with final values comparison
+
+Expert usage plots show min/max (shaded band), mean (solid), median (dashed) per component
+for scalability when expert count grows.
 """
 import re
 import sys
@@ -14,6 +18,8 @@
 # Consistent colors: blue for Baseline, orange for Current
 COLOR_BASELINE = "#1f77b4"  # matplotlib default blue
 COLOR_CURRENT = "#ff7f0e"  # matplotlib default orange
+# Component colors
+COMP_COLORS = {"mlp": "#2ca02c", "attn": "#d62728", "mos_ctp": "#9467bd", "mos_ntp": "#8c564b"}
 
 
 def parse_log(logpath: str) -> dict:
@@ -31,6 +37,10 @@ def parse_log(logpath: str) -> dict:
         # Per-component expert usage: each is list of lists
         "mlp_usage": [], "attn_usage": [],
         "mlp_entropy": [], "attn_entropy": [],
+        # MoS per-head routing diagnostics
+        "mos_ctp_usage": [], "mos_ntp_usage": [],
+        "mos_ctp_entropy": [], "mos_ntp_entropy": [],
+        "mos_ctp_cv": [], "mos_ntp_cv": [],
         # Per-component orthogonality and balance
         "mlp_ortho": [], "attn_ortho": [], "mos_ortho": [],
         "mlp_bal": [], "attn_bal": [], "mos_bal": [],
@@ -76,7 +86,7 @@ def parse_log(logpath: str) -> dict:
             data["expert_usage"].append(usage)
         else:
             data["expert_usage"].append([])
-        # Per-component expert usage
+        # Per-component expert usage (attn, mlp)
         for comp in ("mlp", "attn"):
             m_u = re.search(rf"{comp}_usage:\[([\d.,]+)\]", line)
             if m_u:
@@ -85,6 +95,17 @@ def parse_log(logpath: str) -> dict:
                 data[f"{comp}_usage"].append([])
             m_e = re.search(rf"{comp}_entropy:([\d.]+)", line)
             data[f"{comp}_entropy"].append(float(m_e.group(1)) if m_e else 0.0)
+        # MoS per-head routing diagnostics
+        for head in ("ctp", "ntp"):
+            m_u = re.search(rf"mos_{head}_usage:\[([\d.,]+)\]", line)
+            if m_u:
+                data[f"mos_{head}_usage"].append([float(v) for v in m_u.group(1).split(",")])
+            else:
+                data[f"mos_{head}_usage"].append([])
+            m_e = re.search(rf"mos_{head}_entropy:([\d.]+)", line)
+            data[f"mos_{head}_entropy"].append(float(m_e.group(1)) if m_e else 0.0)
+            m_cv = re.search(rf"mos_{head}_cv:([\d.]+)", line)
+            data[f"mos_{head}_cv"].append(float(m_cv.group(1)) if m_cv else 0.0)
         # Per-component orthogonality and balance
         for comp in ("mlp", "attn", "mos"):
             m_o = re.search(rf"{comp}_ortho:([\d.]+)", line)
@@ -95,6 +116,25 @@ def parse_log(logpath: str) -> dict:
     return data
 
 
+def _usage_stats(usage_list):
+    """Compute min/max/mean/median per step from list of expert usage lists."""
+    import numpy as np
+    mins, maxs, means, medians = [], [], [], []
+    for u in usage_list:
+        if u:
+            arr = np.array(u)
+            mins.append(arr.min())
+            maxs.append(arr.max())
+            means.append(arr.mean())
+            medians.append(np.median(arr))
+        else:
+            mins.append(0.0)
+            maxs.append(0.0)
+            means.append(0.0)
+            medians.append(0.0)
+    return mins, maxs, means, medians
+
+
 def _plot_line(ax, b, c, b_key, c_key, b_steps, c_steps, title, ylabel=None):
     """Plot two line series on the same axis with consistent colors."""
     if b[b_key] and c[c_key]:
@@ -112,6 +152,22 @@ def _plot_line(ax, b, c, b_key, c_key, b_steps, c_steps, title, ylabel=None):
     ax.grid(True, alpha=0.3)
 
 
+def _plot_usage_stats(ax, data, steps_key, usage_key, color, label_prefix, linestyle="-"):
+    """Plot expert usage as min-max shaded band + mean solid + median dashed."""
+    usage_list = data[usage_key]
+    if not usage_list or not any(u for u in usage_list):
+        return
+    mins, maxs, means, medians = _usage_stats(usage_list)
+    steps = data[steps_key]
+    if not steps:
+        return
+    ax.fill_between(steps, mins, maxs, color=color, alpha=0.15)
+    ax.plot(steps, means, color=color, linestyle=linestyle, alpha=0.8,
+            label=f"{label_prefix} mean", linewidth=1.5)
+    ax.plot(steps, medians, color=color, linestyle="--", alpha=0.5,
+            label=f"{label_prefix} median", linewidth=1.0)
+
+
 def plot_comparison(baseline_log: str, current_log: str, outdir: str):
     """Plot baseline vs current experiment comparison with full training curves."""
     try:
@@ -147,62 +203,52 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
     _plot_line(axes[2, 2], b, c, "deq_iter_conv", "deq_iter_conv", "val_steps", "val_steps",
                "DEQ Iter Conv ||z_T - z_{T-1}||")
 
-    # Row 4: Expert diagnostics
-    # Load Balance: per-component (MLP, Attn) per-expert lines when available,
-    # falling back to combined expert_usage for older logs.
-    ax_lb = axes[3, 0]
-    _has_per_comp = any(len(u) > 0 for u in b.get("mlp_usage", []) + c.get("mlp_usage", []))
-
-    if _has_per_comp:
-        # Per-component expert usage: different colors per component, line styles per expert
-        comp_colors = {"mlp": ("#2ca02c", "#98df8a"), "attn": ("#d62728", "#ff9896")}
-        line_styles = ["-", "--", ":", "-."]
-        for comp in ("mlp", "attn"):
-            all_usage = b[f"{comp}_usage"] + c[f"{comp}_usage"]
-            max_e = max((len(u) for u in all_usage), default=0)
-            b_color, c_color = comp_colors[comp]
-            for ei in range(max_e):
-                ls = line_styles[ei % len(line_styles)]
-                b_vals = [u[ei] if ei < len(u) else 0.0 for u in b[f"{comp}_usage"]]
-                c_vals = [u[ei] if ei < len(u) else 0.0 for u in c[f"{comp}_usage"]]
-                if b_vals and b["val_steps"]:
-                    ax_lb.plot(b["val_steps"], b_vals, color=b_color, linestyle=ls,
-                               alpha=0.7, label=f"B {comp} E{ei}", linewidth=1.5)
-                if c_vals and c["val_steps"]:
-                    ax_lb.plot(c["val_steps"], c_vals, color=c_color, linestyle=ls,
-                               alpha=0.7, label=f"C {comp} E{ei}", linewidth=1.5)
-    else:
+    # Row 4: Expert diagnostics (usage, entropy, orthogonality)
+    # Usage: min/max/mean/median per component (Attn, MLP, MoS CTP, MoS NTP)
+    ax_usage = axes[3, 0]
+    usage_keys = [
+        ("mlp", "mlp_usage", COMP_COLORS["mlp"]),
+        ("attn", "attn_usage", COMP_COLORS["attn"]),
+        ("mos_ctp", "mos_ctp_usage", COMP_COLORS["mos_ctp"]),
+        ("mos_ntp", "mos_ntp_usage", COMP_COLORS["mos_ntp"]),
+    ]
+    has_any_usage = False
+    for comp_label, ukey, color in usage_keys:
+        for dataset, prefix, ls in [(b, f"B {comp_label}", "-"), (c, f"C {comp_label}", "-")]:
+            if any(u for u in dataset.get(ukey, [])):
+                has_any_usage = True
+                _plot_usage_stats(ax_usage, dataset, "val_steps", ukey, color, prefix, ls)
+
+    if not has_any_usage:
         # Fallback: combined expert_usage (older logs)
-        max_experts = max((len(u) for u in b["expert_usage"] + c["expert_usage"]), default=0)
-        line_styles = ["-", "--", ":", "-."]
-        for ei in range(max_experts):
-            b_vals = [u[ei] if ei < len(u) else 0.0 for u in b["expert_usage"]]
-            c_vals = [u[ei] if ei < len(u) else 0.0 for u in c["expert_usage"]]
-            ls = line_styles[ei % len(line_styles)]
-            if b_vals and b["val_steps"]:
-                ax_lb.plot(b["val_steps"], b_vals, color=COLOR_BASELINE, linestyle=ls,
-                           alpha=0.7, label=f"B Expert {ei}", linewidth=1.5)
-            if c_vals and c["val_steps"]:
-                ax_lb.plot(c["val_steps"], c_vals, color=COLOR_CURRENT, linestyle=ls,
-                           alpha=0.7, label=f"C Expert {ei}", linewidth=1.5)
-    ax_lb.set_title("Expert Usage (per Component)", fontsize=11)
-    ax_lb.set_xlabel("Step")
-    ax_lb.legend(fontsize=6, ncol=2)
-    ax_lb.grid(True, alpha=0.3)
-
-    # Expert Entropy: per-component lines
+        _plot_usage_stats(ax_usage, b, "val_steps", "expert_usage", COLOR_BASELINE, "B")
+        _plot_usage_stats(ax_usage, c, "val_steps", "expert_usage", COLOR_CURRENT, "C")
+
+    ax_usage.set_title("Expert Usage (min/max/mean/median per Component)", fontsize=11)
+    ax_usage.set_xlabel("Step")
+    ax_usage.set_ylabel("Usage fraction")
+    ax_usage.legend(fontsize=6, ncol=2)
+    ax_usage.grid(True, alpha=0.3)
+
+    # Expert Entropy: per-component lines (Attn, MLP, MoS CTP, MoS NTP)
     ax_ent = axes[3, 1]
-    comp_colors_ent = {"mlp": "#2ca02c", "attn": "#d62728"}
-    for comp, color in comp_colors_ent.items():
-        key = f"{comp}_entropy"
-        if b[key] and any(v > 0 for v in b[key]):
-            ax_ent.plot(b["val_steps"], b[key], color=color, linestyle="-", alpha=0.7,
-                        label=f"B {comp}", linewidth=1.5)
-        if c[key] and any(v > 0 for v in c[key]):
-            ax_ent.plot(c["val_steps"], c[key], color=color, linestyle="--", alpha=0.7,
-                        label=f"C {comp}", linewidth=1.5)
-    # Fallback to combined entropy
-    if not any(v > 0 for v in b.get("mlp_entropy", []) + c.get("mlp_entropy", [])):
+    ent_keys = [
+        ("mlp", "mlp_entropy", COMP_COLORS["mlp"]),
+        ("attn", "attn_entropy", COMP_COLORS["attn"]),
+        ("mos_ctp", "mos_ctp_entropy", COMP_COLORS["mos_ctp"]),
+        ("mos_ntp", "mos_ntp_entropy", COMP_COLORS["mos_ntp"]),
+    ]
+    has_any_ent = False
+    for comp_label, ekey, color in ent_keys:
+        if b.get(ekey) and any(v > 0 for v in b[ekey]):
+            ax_ent.plot(b["val_steps"], b[ekey], color=color, linestyle="-", alpha=0.7,
+                        label=f"B {comp_label}", linewidth=1.5)
+            has_any_ent = True
+        if c.get(ekey) and any(v > 0 for v in c[ekey]):
+            ax_ent.plot(c["val_steps"], c[ekey], color=color, linestyle="--", alpha=0.7,
+                        label=f"C {comp_label}", linewidth=1.5)
+            has_any_ent = True
+    if not has_any_ent:
         _plot_line(ax_ent, b, c, "expert_entropy", "expert_entropy", "val_steps", "val_steps", "")
     ax_ent.set_title("Expert Entropy (per Component)", fontsize=11)
     ax_ent.set_xlabel("Step")
@@ -211,12 +257,13 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
 
     # Expert Orthogonality: per-component lines
     ax_ort = axes[3, 2]
-    for comp, color in {"mlp": "#2ca02c", "attn": "#d62728", "mos": "#9467bd"}.items():
+    for comp, color in {"mlp": COMP_COLORS["mlp"], "attn": COMP_COLORS["attn"],
+                        "mos": COMP_COLORS["mos_ctp"]}.items():
         key = f"{comp}_ortho"
-        if b[key] and any(v > 0 for v in b[key]):
+        if b.get(key) and any(v > 0 for v in b[key]):
            ax_ort.plot(b["val_steps"], b[key], color=color, linestyle="-", alpha=0.7,
                        label=f"B {comp}", linewidth=1.5)
-        if c[key] and any(v > 0 for v in c[key]):
+        if c.get(key) and any(v > 0 for v in c[key]):
            ax_ort.plot(c["val_steps"], c[key], color=color, linestyle="--", alpha=0.7,
                        label=f"C {comp}", linewidth=1.5)
     if not any(v > 0 for v in b.get("mlp_ortho", []) + c.get("mlp_ortho", [])):
@@ -226,23 +273,23 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
     ax_ort.legend(fontsize=7)
     ax_ort.grid(True, alpha=0.3)
 
-    # Row 5: Regularization losses (balance, sparsity, conv_loss)
-    # Balance loss per component
+    # Row 5: Regularization losses (balance, conv_loss, spare)
     ax_bal = axes[4, 0]
-    for comp, color in {"mlp": "#2ca02c", "attn": "#d62728", "mos": "#9467bd"}.items():
+    for comp, color in {"mlp": COMP_COLORS["mlp"], "attn": COMP_COLORS["attn"],
+                        "mos": COMP_COLORS["mos_ctp"]}.items():
         key = f"{comp}_bal"
-        if b[key] and any(v > 0 for v in b[key]):
+        if b.get(key) and any(v > 0 for v in b[key]):
            ax_bal.plot(b["val_steps"], b[key], color=color, linestyle="-", alpha=0.7,
                        label=f"B {comp}", linewidth=1.5)
-        if c[key] and any(v > 0 for v in c[key]):
+        if c.get(key) and any(v > 0 for v in c[key]):
            ax_bal.plot(c["val_steps"], c[key], color=color, linestyle="--", alpha=0.7,
                        label=f"C {comp}", linewidth=1.5)
     ax_bal.set_title("Balance Loss (per Component)", fontsize=11)
     ax_bal.set_xlabel("Step")
     ax_bal.legend(fontsize=7)
     ax_bal.grid(True, alpha=0.3)
 
-    # Conv loss (from training lines)
+    # Conv loss
     _plot_line(axes[4, 1], b, c, "ctp_loss", "ctp_loss", "train_steps", "train_steps",
                "Convergence Loss (from train)")
     axes[4, 2].axis("off")  # spare slot
```
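For reference, the regexes above imply that eval logs carry the per-head MoS diagnostics as `key:value` tokens. A hypothetical log line (the values are made up) and how `parse_log`'s pattern picks it apart:

```python
# Assumed log format, inferred from the regexes in parse_log above.
import re

line = "mos_ctp_usage:[0.27,0.24,0.26,0.23] mos_ctp_entropy:1.3743 mos_ctp_cv:0.0721"
m_u = re.search(r"mos_ctp_usage:\[([\d.,]+)\]", line)
print([float(v) for v in m_u.group(1).split(",")])  # [0.27, 0.24, 0.26, 0.23]
```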

experiments/smoke_test.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -54,6 +54,14 @@ def _get_expert_diagnostics(model):
     diag["mlp_usage"] = mr._expert_usage
     diag["mlp_entropy"] = mr._expert_entropy
     diag["mlp_balance_cv"] = mr._expert_balance_cv
+    # MoS routing diagnostics
+    mos = model.mos_head
+    for head in ("ctp", "ntp"):
+        usage = getattr(mos, f'_{head}_expert_usage', None)
+        if usage is not None:
+            diag[f"mos_{head}_usage"] = usage
+            diag[f"mos_{head}_entropy"] = getattr(mos, f'_{head}_expert_entropy', 0)
+            diag[f"mos_{head}_balance_cv"] = getattr(mos, f'_{head}_expert_balance_cv', 0)
     # Orthogonality (from 3D expert weight tensors [num_experts, rows, cols])
     with torch.no_grad():
         for name, w in [
@@ -81,6 +89,7 @@ def smoke_test(num_steps: int = 300, eval_every: int = 50):
         logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
         bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
         kv_latent_dim=args.kv_latent_dim, num_refinements=args.num_refinements,
+        attn_expert_rank=args.attn_expert_rank, mlp_expert_rank=args.mlp_expert_rank,
     ).cuda()
 
     opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
@@ -130,6 +139,11 @@ def smoke_test(num_steps: int = 300, eval_every: int = 50):
         if "attn_usage" in diag:
             print(f" attn: usage={diag['attn_usage']} entropy={diag['attn_entropy']:.4f} "
                   f"balance_cv={diag['attn_balance_cv']:.4f} ortho={diag.get('attn_ortho', 0):.4f}")
+        for head in ("ctp", "ntp"):
+            if f"mos_{head}_usage" in diag:
+                print(f" mos_{head}: usage={diag[f'mos_{head}_usage']} "
+                      f"entropy={diag[f'mos_{head}_entropy']:.4f} "
+                      f"balance_cv={diag[f'mos_{head}_balance_cv']:.4f}")
 
     # --- Results ---
     print(f"\n--- Smoke Test Results ---")
```

experiments/test_arch.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -43,8 +43,9 @@ def test_all_constraints():
 
     # Check constraint #4: FSQ in MoS Head
     assert hasattr(model, 'mos_head'), "Must have MoS output head"
-    assert hasattr(model.mos_head, 'expert_gate_ctp_logits'), "Must have CTP expert gates"
-    assert hasattr(model.mos_head, 'expert_gate_ntp_logits'), "Must have NTP expert gates"
+    assert hasattr(model.mos_head, 'gate_ctp'), "Must have CTP gate (pure softmax routing)"
+    assert hasattr(model.mos_head, 'gate_ntp'), "Must have NTP gate (pure softmax routing)"
+    assert not hasattr(model.mos_head, 'expert_gate_ctp_logits'), "Sigmoid gates removed (Mixtape)"
 
     # Check constraint #5: Diffusion-AR (refinement)
     assert model.num_refinements >= 1, "Must have at least 1 refinement step"
```
