Skip to content

Commit 68dcd13

Browse files
authored
Fix shapes in modular_gpt_oss.py (#42737)
* Fix shapes in modular_gpt_oss.py
* Run make fix-copies
1 parent accb698 commit 68dcd13

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
8888
8989
Args:
9090
hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
91-
selected_experts (torch.Tensor): (batch_size * token_num, top_k)
92-
routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
91+
selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
92+
routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
9393
Returns:
9494
torch.Tensor
9595
"""
@@ -159,8 +159,8 @@ def __init__(self, config):
159159

160160
def forward(self, hidden_states):
161161
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
162-
router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
163-
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
162+
router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
163+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
164164
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
165165
router_scores = router_top_value
166166
return router_logits, router_scores, router_indices

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
8686
8787
Args:
8888
hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
89-
selected_experts (torch.Tensor): (batch_size * token_num, top_k)
90-
routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
89+
selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
90+
routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
9191
Returns:
9292
torch.Tensor
9393
"""
@@ -157,8 +157,8 @@ def __init__(self, config):
157157

158158
def forward(self, hidden_states):
159159
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
160-
router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
161-
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
160+
router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
161+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
162162
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
163163
router_scores = router_top_value
164164
return router_logits, router_scores, router_indices

0 commit comments

Comments (0)