2 files changed in src/transformers/models/jetmoe: +5 -4 lines.

The first changed file renames the decoder layer's attention submodule from `self_attn` to `self_attention` and updates its call site in `forward`:

```diff
@@ -490,10 +490,10 @@ class JetMoeDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = JetMoeAttention(config, layer_idx)
         self.mlp = JetMoeMoE(config)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
+        self.self_attention = JetMoeAttention(config, layer_idx)

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -510,7 +510,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
```
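Note that the rename also moves the attention assignment below the layernorms. A minimal sketch of why assignment order matters (an illustrative module, not the actual JetMoe classes): `nn.Module` registers submodules in assignment order, so the renamed submodule now appears last when children are listed, though the keys themselves are unaffected.

```python
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules are recorded in _modules in assignment order.
        self.mlp = nn.Linear(4, 4)
        self.input_layernorm = nn.LayerNorm(4)
        self.post_attention_layernorm = nn.LayerNorm(4)
        self.self_attention = nn.Linear(4, 4)  # registered last

print([name for name, _ in TinyLayer().named_children()])
# -> ['mlp', 'input_layernorm', 'post_attention_layernorm', 'self_attention']
```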
The second changed file applies the same rename in the modular definition, which subclasses LlamaDecoderLayer:

```diff
@@ -374,9 +374,10 @@ class JetMoeDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
         self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
-        self.self_attn = JetMoeAttention(config, layer_idx)
+        self.self_attention = JetMoeAttention(config, layer_idx)
         self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
         self.mlp = JetMoeMoE(config)
+        del self.self_attn

     def forward(
         self,
@@ -392,7 +393,7 @@ def forward(
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
-        hidden_states, _, _ = self.self_attn(
+        hidden_states, _, _ = self.self_attention(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
```
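The `del self.self_attn` line is needed here because the parent's `__init__` has already registered a module under the old name; the subclass removes it after registering the renamed one. A minimal, runnable sketch of that pattern, using stand-in classes rather than the real LlamaDecoderLayer and JetMoeAttention:

```python
import torch.nn as nn

class ParentLayer(nn.Module):
    # Stand-in for the inherited layer: registers `self_attn` in __init__.
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Linear(8, 8)

class ChildLayer(ParentLayer):
    # Stand-in for the subclass: registers the module under the new name,
    # then deletes the inherited attribute. nn.Module.__delattr__ removes
    # the entry from _modules, so `self_attn` drops out of the state_dict
    # and out of named_children().
    def __init__(self):
        super().__init__()
        self.self_attention = nn.Linear(8, 8)
        del self.self_attn

layer = ChildLayer()
print([name for name, _ in layer.named_children()])  # -> ['self_attention']
# Accessing the old name now raises AttributeError:
# layer.self_attn
```

Without the `del`, the subclass would carry both submodules, and the stale `self_attn` weights would still be created and saved even though `forward` never uses them.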