
Commit e11a00a

JetMoe: Fix jetmoe after #40132 (#41324)
* update
* up
1 parent 1bc75db commit e11a00a

2 files changed: +5 -4 lines changed

src/transformers/models/jetmoe/modeling_jetmoe.py

Lines changed: 2 additions & 2 deletions
@@ -490,10 +490,10 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = JetMoeAttention(config, layer_idx)
         self.mlp = JetMoeMoE(config)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
+        self.self_attention = JetMoeAttention(config, layer_idx)
 
     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -510,7 +510,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,

src/transformers/models/jetmoe/modular_jetmoe.py

Lines changed: 3 additions & 2 deletions
@@ -374,9 +374,10 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
-        self.self_attn = JetMoeAttention(config, layer_idx)
+        self.self_attention = JetMoeAttention(config, layer_idx)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.mlp = JetMoeMoE(config)
+        del self.self_attn
 
     def forward(
         self,
@@ -392,7 +393,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
