@@ -88,8 +88,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
 
         Args:
             hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
-            selected_experts (torch.Tensor): (batch_size * token_num, top_k)
-            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+            selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+            routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
         Returns:
             torch.Tensor
         """
@@ -159,8 +159,8 @@ def __init__(self, config):
 
     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
         return router_logits, router_scores, router_indices
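For reference, a minimal, self-contained sketch of a top-k softmax router that reproduces the shapes the corrected comments describe. This is not the repository's module: the `weight` and `bias` tensors below are stand-in parameters, and the dimensions are arbitrary example values. It only illustrates that the router works on a flattened token axis, so `num_tokens = batch_size * seq_len` and `routing_weights` ends up with shape `(num_tokens, top_k)`.

```python
import torch
import torch.nn.functional as F

# Arbitrary example sizes (assumptions for illustration only).
batch_size, seq_len, hidden_dim, num_experts, top_k = 2, 3, 8, 4, 2

hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
weight = torch.randn(num_experts, hidden_dim)  # stand-in router weight
bias = torch.zeros(num_experts)                # stand-in router bias

# Flatten batch and sequence dims into a single token dim.
tokens = hidden_states.reshape(-1, hidden_dim)                                # (num_tokens, hidden_dim)
router_logits = F.linear(tokens, weight, bias)                                # (num_tokens, num_experts)
router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)  # (num_tokens, top_k)
routing_weights = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)

# Shapes match the corrected docstring and comments.
assert router_logits.shape == (batch_size * seq_len, num_experts)
assert routing_weights.shape == (batch_size * seq_len, top_k)
assert router_indices.shape == (batch_size * seq_len, top_k)
```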