@@ -62,25 +62,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6262 has_tokens = token_indices .numel () > 0
6363
6464 if self .calibrate_all_experts :
65- # Send all tokens to this expert for calibration
66- expert_input = hidden_states
67- expert_output = expert (expert_input )
68-
69- if has_tokens :
70- # Still use routing weights for final output combination
71- expert_weights = topk_weights [token_indices , weight_indices ]
72- weighted_output = expert_output [
73- token_indices
74- ] * expert_weights .unsqueeze (- 1 )
75- final_hidden_states .index_add_ (0 , token_indices , weighted_output )
65+ # When calibrating, run all tokens through the expert to gather stats.
66+ # The output is still calculated using only the routed tokens.
67+ expert_output_full = expert (hidden_states )
68+ if not has_tokens :
69+ continue # No tokens routed to this expert, but stats were gathered.
70+ expert_output = expert_output_full [token_indices ]
7671 else :
77- # Normal MoE: only process tokens routed to this expert
78- if has_tokens :
79- expert_input = hidden_states [token_indices ]
80- expert_output = expert (expert_input )
81- expert_weights = topk_weights [token_indices , weight_indices ]
82- weighted_output = expert_output * expert_weights .unsqueeze (- 1 )
83- final_hidden_states .index_add_ (0 , token_indices , weighted_output )
72+ # Standard MoE behavior: only process tokens routed to this expert.
73+ if not has_tokens :
74+ continue
75+ expert_output = expert (hidden_states [token_indices ])
76+
77+ # Common logic for combining expert outputs
78+ expert_weights = topk_weights [token_indices , weight_indices ]
79+ weighted_output = expert_output * expert_weights .unsqueeze (- 1 )
80+ final_hidden_states .index_add_ (0 , token_indices , weighted_output )
8481 # End MoE
8582
8683 hidden_states = final_hidden_states .type (hidden_states .dtype ).view (* orig_shape )
0 commit comments