Skip to content

Commit f02b1be

Browse files
committed
Fixing the Loop Gemini identified.
Signed-off-by: phaelon74 <[email protected]>
1 parent a7b0e8a commit f02b1be

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/llmcompressor/modeling/glm4_moe.py

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -62,25 +62,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6262
has_tokens = token_indices.numel() > 0
6363

6464
if self.calibrate_all_experts:
65-
# Send all tokens to this expert for calibration
66-
expert_input = hidden_states
67-
expert_output = expert(expert_input)
68-
69-
if has_tokens:
70-
# Still use routing weights for final output combination
71-
expert_weights = topk_weights[token_indices, weight_indices]
72-
weighted_output = expert_output[
73-
token_indices
74-
] * expert_weights.unsqueeze(-1)
75-
final_hidden_states.index_add_(0, token_indices, weighted_output)
65+
# When calibrating, run all tokens through the expert to gather stats.
66+
# The output is still calculated using only the routed tokens.
67+
expert_output_full = expert(hidden_states)
68+
if not has_tokens:
69+
continue # No tokens routed to this expert, but stats were gathered.
70+
expert_output = expert_output_full[token_indices]
7671
else:
77-
# Normal MoE: only process tokens routed to this expert
78-
if has_tokens:
79-
expert_input = hidden_states[token_indices]
80-
expert_output = expert(expert_input)
81-
expert_weights = topk_weights[token_indices, weight_indices]
82-
weighted_output = expert_output * expert_weights.unsqueeze(-1)
83-
final_hidden_states.index_add_(0, token_indices, weighted_output)
72+
# Standard MoE behavior: only process tokens routed to this expert.
73+
if not has_tokens:
74+
continue
75+
expert_output = expert(hidden_states[token_indices])
76+
77+
# Common logic for combining expert outputs
78+
expert_weights = topk_weights[token_indices, weight_indices]
79+
weighted_output = expert_output * expert_weights.unsqueeze(-1)
80+
final_hidden_states.index_add_(0, token_indices, weighted_output)
8481
# End MoE
8582

8683
hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)

0 commit comments

Comments
 (0)