44- Row 1: Train Loss, Val BPB, Step Avg (ms)
55- Row 2: NTP Loss, CTP Loss, Pre-clip Grad Norm
66- Row 3: DEQ Residual, DEQ Recon Error, DEQ Iter Convergence
7- - Row 4: Load Balance (per component per expert), Expert Entropy, Expert Orthogonality
8- - Row 5: Summary text with final values comparison
7+ - Row 4: Expert Usage (min/max/mean/median per component), Expert Entropy, Expert Orthogonality
8+ - Row 5: Balance Loss (per component), Conv Loss, (spare)
9+ - Row 6: Summary text with final values comparison
10+
11+ Expert usage plots show min/max (shaded band), mean (solid), median (dashed) per component
12+ for scalability when expert count grows.
913"""
1014import re
1115import sys
1418# Consistent colors: blue for Baseline, orange for Current
1519COLOR_BASELINE = "#1f77b4" # matplotlib default blue
1620COLOR_CURRENT = "#ff7f0e" # matplotlib default orange
21+ # Component colors
22+ COMP_COLORS = {"mlp" : "#2ca02c" , "attn" : "#d62728" , "mos_ctp" : "#9467bd" , "mos_ntp" : "#8c564b" }
1723
1824
1925def parse_log (logpath : str ) -> dict :
@@ -31,6 +37,10 @@ def parse_log(logpath: str) -> dict:
3137 # Per-component expert usage: each is list of lists
3238 "mlp_usage" : [], "attn_usage" : [],
3339 "mlp_entropy" : [], "attn_entropy" : [],
40+ # MoS per-head routing diagnostics
41+ "mos_ctp_usage" : [], "mos_ntp_usage" : [],
42+ "mos_ctp_entropy" : [], "mos_ntp_entropy" : [],
43+ "mos_ctp_cv" : [], "mos_ntp_cv" : [],
3444 # Per-component orthogonality and balance
3545 "mlp_ortho" : [], "attn_ortho" : [], "mos_ortho" : [],
3646 "mlp_bal" : [], "attn_bal" : [], "mos_bal" : [],
@@ -76,7 +86,7 @@ def parse_log(logpath: str) -> dict:
7686 data ["expert_usage" ].append (usage )
7787 else :
7888 data ["expert_usage" ].append ([])
79- # Per-component expert usage
89+ # Per-component expert usage (attn, mlp)
8090 for comp in ("mlp" , "attn" ):
8191 m_u = re .search (rf"{ comp } _usage:\[([\d.,]+)\]" , line )
8292 if m_u :
@@ -85,6 +95,17 @@ def parse_log(logpath: str) -> dict:
8595 data [f"{ comp } _usage" ].append ([])
8696 m_e = re .search (rf"{ comp } _entropy:([\d.]+)" , line )
8797 data [f"{ comp } _entropy" ].append (float (m_e .group (1 )) if m_e else 0.0 )
98+ # MoS per-head routing diagnostics
99+ for head in ("ctp" , "ntp" ):
100+ m_u = re .search (rf"mos_{ head } _usage:\[([\d.,]+)\]" , line )
101+ if m_u :
102+ data [f"mos_{ head } _usage" ].append ([float (v ) for v in m_u .group (1 ).split ("," )])
103+ else :
104+ data [f"mos_{ head } _usage" ].append ([])
105+ m_e = re .search (rf"mos_{ head } _entropy:([\d.]+)" , line )
106+ data [f"mos_{ head } _entropy" ].append (float (m_e .group (1 )) if m_e else 0.0 )
107+ m_cv = re .search (rf"mos_{ head } _cv:([\d.]+)" , line )
108+ data [f"mos_{ head } _cv" ].append (float (m_cv .group (1 )) if m_cv else 0.0 )
88109 # Per-component orthogonality and balance
89110 for comp in ("mlp" , "attn" , "mos" ):
90111 m_o = re .search (rf"{ comp } _ortho:([\d.]+)" , line )
@@ -95,6 +116,25 @@ def parse_log(logpath: str) -> dict:
95116 return data
96117
97118
119+ def _usage_stats (usage_list ):
120+ """Compute min/max/mean/median per step from list of expert usage lists."""
121+ import numpy as np
122+ mins , maxs , means , medians = [], [], [], []
123+ for u in usage_list :
124+ if u :
125+ arr = np .array (u )
126+ mins .append (arr .min ())
127+ maxs .append (arr .max ())
128+ means .append (arr .mean ())
129+ medians .append (np .median (arr ))
130+ else :
131+ mins .append (0.0 )
132+ maxs .append (0.0 )
133+ means .append (0.0 )
134+ medians .append (0.0 )
135+ return mins , maxs , means , medians
136+
137+
98138def _plot_line (ax , b , c , b_key , c_key , b_steps , c_steps , title , ylabel = None ):
99139 """Plot two line series on the same axis with consistent colors."""
100140 if b [b_key ] and c [c_key ]:
@@ -112,6 +152,22 @@ def _plot_line(ax, b, c, b_key, c_key, b_steps, c_steps, title, ylabel=None):
112152 ax .grid (True , alpha = 0.3 )
113153
114154
155+ def _plot_usage_stats (ax , data , steps_key , usage_key , color , label_prefix , linestyle = "-" ):
156+ """Plot expert usage as min-max shaded band + mean solid + median dashed."""
157+ usage_list = data [usage_key ]
158+ if not usage_list or not any (u for u in usage_list ):
159+ return
160+ mins , maxs , means , medians = _usage_stats (usage_list )
161+ steps = data [steps_key ]
162+ if not steps :
163+ return
164+ ax .fill_between (steps , mins , maxs , color = color , alpha = 0.15 )
165+ ax .plot (steps , means , color = color , linestyle = linestyle , alpha = 0.8 ,
166+ label = f"{ label_prefix } mean" , linewidth = 1.5 )
167+ ax .plot (steps , medians , color = color , linestyle = "--" , alpha = 0.5 ,
168+ label = f"{ label_prefix } median" , linewidth = 1.0 )
169+
170+
115171def plot_comparison (baseline_log : str , current_log : str , outdir : str ):
116172 """Plot baseline vs current experiment comparison with full training curves."""
117173 try :
@@ -147,62 +203,52 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
147203 _plot_line (axes [2 , 2 ], b , c , "deq_iter_conv" , "deq_iter_conv" , "val_steps" , "val_steps" ,
148204 "DEQ Iter Conv ||z_T - z_{T-1}||" )
149205
150- # Row 4: Expert diagnostics
151- # Load Balance: per-component (MLP, Attn) per-expert lines when available,
152- # falling back to combined expert_usage for older logs.
153- ax_lb = axes [3 , 0 ]
154- _has_per_comp = any (len (u ) > 0 for u in b .get ("mlp_usage" , []) + c .get ("mlp_usage" , []))
155-
156- if _has_per_comp :
157- # Per-component expert usage: different colors per component, line styles per expert
158- comp_colors = {"mlp" : ("#2ca02c" , "#98df8a" ), "attn" : ("#d62728" , "#ff9896" )}
159- line_styles = ["-" , "--" , ":" , "-." ]
160- for comp in ("mlp" , "attn" ):
161- all_usage = b [f"{ comp } _usage" ] + c [f"{ comp } _usage" ]
162- max_e = max ((len (u ) for u in all_usage ), default = 0 )
163- b_color , c_color = comp_colors [comp ]
164- for ei in range (max_e ):
165- ls = line_styles [ei % len (line_styles )]
166- b_vals = [u [ei ] if ei < len (u ) else 0.0 for u in b [f"{ comp } _usage" ]]
167- c_vals = [u [ei ] if ei < len (u ) else 0.0 for u in c [f"{ comp } _usage" ]]
168- if b_vals and b ["val_steps" ]:
169- ax_lb .plot (b ["val_steps" ], b_vals , color = b_color , linestyle = ls ,
170- alpha = 0.7 , label = f"B { comp } E{ ei } " , linewidth = 1.5 )
171- if c_vals and c ["val_steps" ]:
172- ax_lb .plot (c ["val_steps" ], c_vals , color = c_color , linestyle = ls ,
173- alpha = 0.7 , label = f"C { comp } E{ ei } " , linewidth = 1.5 )
174- else :
206+ # Row 4: Expert diagnostics (usage, entropy, orthogonality)
207+ # Usage: min/max/mean/median per component (Attn, MLP, MoS CTP, MoS NTP)
208+ ax_usage = axes [3 , 0 ]
209+ usage_keys = [
210+ ("mlp" , "mlp_usage" , COMP_COLORS ["mlp" ]),
211+ ("attn" , "attn_usage" , COMP_COLORS ["attn" ]),
212+ ("mos_ctp" , "mos_ctp_usage" , COMP_COLORS ["mos_ctp" ]),
213+ ("mos_ntp" , "mos_ntp_usage" , COMP_COLORS ["mos_ntp" ]),
214+ ]
215+ has_any_usage = False
216+ for comp_label , ukey , color in usage_keys :
217+ for dataset , prefix , ls in [(b , f"B { comp_label } " , "-" ), (c , f"C { comp_label } " , "-" )]:
218+ if any (u for u in dataset .get (ukey , [])):
219+ has_any_usage = True
220+ _plot_usage_stats (ax_usage , dataset , "val_steps" , ukey , color , prefix , ls )
221+
222+ if not has_any_usage :
175223 # Fallback: combined expert_usage (older logs)
176- max_experts = max ((len (u ) for u in b ["expert_usage" ] + c ["expert_usage" ]), default = 0 )
177- line_styles = ["-" , "--" , ":" , "-." ]
178- for ei in range (max_experts ):
179- b_vals = [u [ei ] if ei < len (u ) else 0.0 for u in b ["expert_usage" ]]
180- c_vals = [u [ei ] if ei < len (u ) else 0.0 for u in c ["expert_usage" ]]
181- ls = line_styles [ei % len (line_styles )]
182- if b_vals and b ["val_steps" ]:
183- ax_lb .plot (b ["val_steps" ], b_vals , color = COLOR_BASELINE , linestyle = ls ,
184- alpha = 0.7 , label = f"B Expert { ei } " , linewidth = 1.5 )
185- if c_vals and c ["val_steps" ]:
186- ax_lb .plot (c ["val_steps" ], c_vals , color = COLOR_CURRENT , linestyle = ls ,
187- alpha = 0.7 , label = f"C Expert { ei } " , linewidth = 1.5 )
188- ax_lb .set_title ("Expert Usage (per Component)" , fontsize = 11 )
189- ax_lb .set_xlabel ("Step" )
190- ax_lb .legend (fontsize = 6 , ncol = 2 )
191- ax_lb .grid (True , alpha = 0.3 )
192-
193- # Expert Entropy: per-component lines
224+ _plot_usage_stats (ax_usage , b , "val_steps" , "expert_usage" , COLOR_BASELINE , "B" )
225+ _plot_usage_stats (ax_usage , c , "val_steps" , "expert_usage" , COLOR_CURRENT , "C" )
226+
227+ ax_usage .set_title ("Expert Usage (min/max/mean/median per Component)" , fontsize = 11 )
228+ ax_usage .set_xlabel ("Step" )
229+ ax_usage .set_ylabel ("Usage fraction" )
230+ ax_usage .legend (fontsize = 6 , ncol = 2 )
231+ ax_usage .grid (True , alpha = 0.3 )
232+
233+ # Expert Entropy: per-component lines (Attn, MLP, MoS CTP, MoS NTP)
194234 ax_ent = axes [3 , 1 ]
195- comp_colors_ent = {"mlp" : "#2ca02c" , "attn" : "#d62728" }
196- for comp , color in comp_colors_ent .items ():
197- key = f"{ comp } _entropy"
198- if b [key ] and any (v > 0 for v in b [key ]):
199- ax_ent .plot (b ["val_steps" ], b [key ], color = color , linestyle = "-" , alpha = 0.7 ,
200- label = f"B { comp } " , linewidth = 1.5 )
201- if c [key ] and any (v > 0 for v in c [key ]):
202- ax_ent .plot (c ["val_steps" ], c [key ], color = color , linestyle = "--" , alpha = 0.7 ,
203- label = f"C { comp } " , linewidth = 1.5 )
204- # Fallback to combined entropy
205- if not any (v > 0 for v in b .get ("mlp_entropy" , []) + c .get ("mlp_entropy" , [])):
235+ ent_keys = [
236+ ("mlp" , "mlp_entropy" , COMP_COLORS ["mlp" ]),
237+ ("attn" , "attn_entropy" , COMP_COLORS ["attn" ]),
238+ ("mos_ctp" , "mos_ctp_entropy" , COMP_COLORS ["mos_ctp" ]),
239+ ("mos_ntp" , "mos_ntp_entropy" , COMP_COLORS ["mos_ntp" ]),
240+ ]
241+ has_any_ent = False
242+ for comp_label , ekey , color in ent_keys :
243+ if b .get (ekey ) and any (v > 0 for v in b [ekey ]):
244+ ax_ent .plot (b ["val_steps" ], b [ekey ], color = color , linestyle = "-" , alpha = 0.7 ,
245+ label = f"B { comp_label } " , linewidth = 1.5 )
246+ has_any_ent = True
247+ if c .get (ekey ) and any (v > 0 for v in c [ekey ]):
248+ ax_ent .plot (c ["val_steps" ], c [ekey ], color = color , linestyle = "--" , alpha = 0.7 ,
249+ label = f"C { comp_label } " , linewidth = 1.5 )
250+ has_any_ent = True
251+ if not has_any_ent :
206252 _plot_line (ax_ent , b , c , "expert_entropy" , "expert_entropy" , "val_steps" , "val_steps" , "" )
207253 ax_ent .set_title ("Expert Entropy (per Component)" , fontsize = 11 )
208254 ax_ent .set_xlabel ("Step" )
@@ -211,12 +257,13 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
211257
212258 # Expert Orthogonality: per-component lines
213259 ax_ort = axes [3 , 2 ]
214- for comp , color in {"mlp" : "#2ca02c" , "attn" : "#d62728" , "mos" : "#9467bd" }.items ():
260+ for comp , color in {"mlp" : COMP_COLORS ["mlp" ], "attn" : COMP_COLORS ["attn" ],
261+ "mos" : COMP_COLORS ["mos_ctp" ]}.items ():
215262 key = f"{ comp } _ortho"
216- if b [ key ] and any (v > 0 for v in b [key ]):
263+ if b . get ( key ) and any (v > 0 for v in b [key ]):
217264 ax_ort .plot (b ["val_steps" ], b [key ], color = color , linestyle = "-" , alpha = 0.7 ,
218265 label = f"B { comp } " , linewidth = 1.5 )
219- if c [ key ] and any (v > 0 for v in c [key ]):
266+ if c . get ( key ) and any (v > 0 for v in c [key ]):
220267 ax_ort .plot (c ["val_steps" ], c [key ], color = color , linestyle = "--" , alpha = 0.7 ,
221268 label = f"C { comp } " , linewidth = 1.5 )
222269 if not any (v > 0 for v in b .get ("mlp_ortho" , []) + c .get ("mlp_ortho" , [])):
@@ -226,23 +273,23 @@ def plot_comparison(baseline_log: str, current_log: str, outdir: str):
226273 ax_ort .legend (fontsize = 7 )
227274 ax_ort .grid (True , alpha = 0.3 )
228275
229- # Row 5: Regularization losses (balance, sparsity, conv_loss)
230- # Balance loss per component
276+ # Row 5: Regularization losses (balance, conv_loss, spare)
231277 ax_bal = axes [4 , 0 ]
232- for comp , color in {"mlp" : "#2ca02c" , "attn" : "#d62728" , "mos" : "#9467bd" }.items ():
278+ for comp , color in {"mlp" : COMP_COLORS ["mlp" ], "attn" : COMP_COLORS ["attn" ],
279+ "mos" : COMP_COLORS ["mos_ctp" ]}.items ():
233280 key = f"{ comp } _bal"
234- if b [ key ] and any (v > 0 for v in b [key ]):
281+ if b . get ( key ) and any (v > 0 for v in b [key ]):
235282 ax_bal .plot (b ["val_steps" ], b [key ], color = color , linestyle = "-" , alpha = 0.7 ,
236283 label = f"B { comp } " , linewidth = 1.5 )
237- if c [ key ] and any (v > 0 for v in c [key ]):
284+ if c . get ( key ) and any (v > 0 for v in c [key ]):
238285 ax_bal .plot (c ["val_steps" ], c [key ], color = color , linestyle = "--" , alpha = 0.7 ,
239286 label = f"C { comp } " , linewidth = 1.5 )
240287 ax_bal .set_title ("Balance Loss (per Component)" , fontsize = 11 )
241288 ax_bal .set_xlabel ("Step" )
242289 ax_bal .legend (fontsize = 7 )
243290 ax_bal .grid (True , alpha = 0.3 )
244291
245- # Conv loss (from training lines)
292+ # Conv loss
246293 _plot_line (axes [4 , 1 ], b , c , "ctp_loss" , "ctp_loss" , "train_steps" , "train_steps" ,
247294 "Convergence Loss (from train)" )
248295 axes [4 , 2 ].axis ("off" ) # spare slot
0 commit comments