@@ -259,13 +259,6 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    # print()
-    # ic(module.layer_idx)
-    # show_tensor(query, False, True)
-    # show_tensor(key_states, False, True)
-    # show_tensor(value_states, False, True)
-    # show_tensor(attn_weights, False, True)
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
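
Note on the hunk above: with the debug calls removed, the eager path is plain scaled-dot-product attention with an additive causal mask, a float32 softmax, and dropout. A minimal self-contained sketch of that math (the standalone function and tensor shapes are illustrative assumptions, not code from this file):

```python
import torch
import torch.nn as nn

def eager_attention_sketch(query, key, value, attention_mask, scaling, dropout=0.0, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim); any GQA key/value repetition is assumed done upstream
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]  # crop the mask to the (possibly cached) key length
        attn_weights = attn_weights + causal_mask               # additive mask: 0 to keep, large negative to drop
    # softmax in float32 for numerical stability, then cast back to the query dtype
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```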
@@ -310,23 +303,11 @@ def forward(
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # print(self.layer_idx)
-        # show_tensor(query_states, end=False, only_shapes=False)
-        # show_tensor(key_states, end=False, only_shapes=True)
-        # show_tensor(value_states, end=True, only_shapes=True)
-
-        # print()
-        # print()
-        # ic(self.layer_idx)
-        # show_tensor(key_states, False, True)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # show_tensor(key_states, False, True)
-
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -351,10 +332,6 @@ def forward(
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-
-        # ic(self.layer_idx)
-        # show_tensor(attn_output, False, True)
-
         return attn_output, attn_weights
 
 
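
The `apply_rotary_pos_emb` call near the top of this forward follows the usual rotate-half formulation of RoPE; a hedged sketch of what that helper computes (this mirrors the common transformers convention and is not copied from this PR):

```python
import torch

def rotate_half(x):
    # split the head dimension in half and swap the halves with a sign flip
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: (batch, seq_len, head_dim); unsqueeze so they broadcast over the head axis
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```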
@@ -592,7 +569,7 @@ def _dynamic_frequency_update(self, position_ids, device):
         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
         """
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth_dynamic_frequency_update
+        if seq_len > self.max_seq_len_cached:  # growth
             inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len
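
The `# growth` branch kept here covers condition 1 of the docstring: when a dynamic RoPE variant is configured, `rope_init_fn` recomputes `inv_freq` for the longer context. One common recipe is dynamic NTK scaling; an illustrative, self-contained sketch of what such a `rope_init_fn` may compute (the formula and parameter names are assumptions, not code from this PR):

```python
import torch

def dynamic_ntk_inv_freq(rope_theta, head_dim, seq_len, original_max_pos, factor):
    # scale the RoPE base so the longest seen position still maps inside the trained frequency range
    base = rope_theta * ((factor * seq_len / original_max_pos) - (factor - 1)) ** (head_dim / (head_dim - 2))
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# at 2x the original context, the frequencies are stretched relative to the plain 1/theta^(2i/d) values
inv_freq = dynamic_ntk_inv_freq(rope_theta=10000.0, head_dim=64, seq_len=8192, original_max_pos=4096, factor=2.0)
```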
@@ -628,7 +605,7 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-MINI_MAX_TEXT01_START_DOCSTRING = r"""
+MINIMAX_TEXT_01_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -647,7 +624,7 @@ def forward(self, x, position_ids):
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01PreTrainedModel(PreTrainedModel):
     config_class = MiniMaxText01Config
@@ -674,7 +651,7 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()
 
 
-MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""
+MINIMAX_TEXT_01_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -751,7 +728,7 @@ def _init_weights(self, module):
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     """
@@ -783,7 +760,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -820,7 +797,6 @@ def forward(
             )
             use_cache = False
 
-        # TODO: raise exception here?
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache()
 
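
When `use_cache=True` and the caller passes no cache, the forward now silently creates a `DynamicCache` instead of raising (hence the removed TODO). A small sketch of what that cache does on each `update` call from the attention layers (shapes are illustrative):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 8, 4, 64)  # (batch, num_kv_heads, seq_len, head_dim) -- illustrative shapes
v = torch.randn(1, 8, 4, 64)
k_all, v_all = cache.update(k, v, layer_idx=0)       # prefill: cache now holds 4 positions
k_all, v_all = cache.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64), layer_idx=0)
print(k_all.shape)             # torch.Size([1, 8, 5, 64]) -- new token concatenated along the seq dim
print(cache.get_seq_length())  # 5
```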
@@ -1173,7 +1149,7 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1222,7 +1198,6 @@ def forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        # ic(input_ids.shape, input_ids)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1299,7 +1274,7 @@ def forward(
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
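
The docstring above describes last-token pooling with padding awareness; a hedged sketch of how the last non-padding position can be selected per row (toy tensors, not the exact variable names used in the forward below):

```python
import torch

input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [5, 6, 7, 8, 9]])        # pad_token_id = 0 in this toy batch
non_pad = (input_ids != 0).int()
last_token_idx = non_pad.cumsum(-1).argmax(-1)     # index of the last real token per row -> tensor([2, 4])

hidden_states = torch.randn(2, 5, 16)              # (batch, seq_len, hidden) from the base model
pooled = hidden_states[torch.arange(2), last_token_idx]  # (batch, hidden) fed to the classification head
```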
@@ -1317,7 +1292,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1395,7 +1370,7 @@ def forward(
     The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
     output) e.g. for Named-Entity-Recognition (NER) tasks.
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1420,7 +1395,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
@@ -1483,7 +1458,7 @@ def forward(
 The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like
 SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     base_model_prefix = "model"
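
As the docstring says, the QA head turns each hidden state into a start logit and an end logit. A hedged sketch of that head (the `qa_outputs` name and `Linear(hidden_size, 2)` layout follow the usual transformers convention and are assumptions here):

```python
import torch
import torch.nn as nn

hidden_size, batch, seq_len = 16, 2, 7
qa_outputs = nn.Linear(hidden_size, 2)
sequence_output = torch.randn(batch, seq_len, hidden_size)  # hidden states from the base model

logits = qa_outputs(sequence_output)                        # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()        # (batch, seq_len): score for each start position
end_logits = end_logits.squeeze(-1).contiguous()            # (batch, seq_len): score for each end position
```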
@@ -1502,7 +1477,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,