
Commit d8d3c40 (1 parent: c54f804)

use latest __init__ standards and auto-generate modular

File tree: 2 files changed (+19 −83 lines)

src/transformers/models/minimax_text_01/__init__.py

Lines changed: 6 additions & 45 deletions

```diff
@@ -15,54 +15,15 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_torch_available,
-)
-
-
-_import_structure = {
-    "configuration_minimax_text_01": ["MiniMaxText01Config"],
-}
-
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["modeling_minimax_text_01"] = [
-        "MiniMaxText01ForCausalLM",
-        "MiniMaxText01ForQuestionAnswering",
-        "MiniMaxText01Model",
-        "MiniMaxText01PreTrainedModel",
-        "MiniMaxText01ForSequenceClassification",
-        "MiniMaxText01ForTokenClassification",
-    ]
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
 
 
 if TYPE_CHECKING:
-    from .configuration_minimax_text_01 import MiniMaxText01Config
-
-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_minimax_text_01 import (
-            MiniMaxText01ForCausalLM,
-            MiniMaxText01ForQuestionAnswering,
-            MiniMaxText01ForSequenceClassification,
-            MiniMaxText01ForTokenClassification,
-            MiniMaxText01Model,
-            MiniMaxText01PreTrainedModel,
-        )
-
-
+    from .configuration_minimax_text_01 import *
+    from .modeling_minimax_text_01 import *
 else:
     import sys
 
-    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
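For reference, the post-commit file is short enough to read in full; the sketch below is simply the new state reconstructed from the hunk above (license header elided). The point of the new standard is that `define_import_structure` derives the export map from the submodules at runtime, so model classes no longer need to be listed by hand in `__init__.py`.

```python
# Post-commit src/transformers/models/minimax_text_01/__init__.py,
# reconstructed from the diff above (license header elided).
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_minimax_text_01 import *
    from .modeling_minimax_text_01 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```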

src/transformers/models/minimax_text_01/modeling_minimax_text_01.py

Lines changed: 13 additions & 38 deletions

````diff
@@ -259,13 +259,6 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    # print()
-    # ic(module.layer_idx)
-    # show_tensor(query, False, True)
-    # show_tensor(key_states, False, True)
-    # show_tensor(value_states, False, True)
-    # show_tensor(attn_weights, False, True)
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -310,23 +303,11 @@ def forward(
         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # print(self.layer_idx)
-        # show_tensor(query_states, end=False, only_shapes=False)
-        # show_tensor(key_states, end=False, only_shapes=True)
-        # show_tensor(value_states, end=True, only_shapes=True)
-
-        # print()
-        # print()
-        # ic(self.layer_idx)
-        # show_tensor(key_states, False, True)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # show_tensor(key_states, False, True)
-
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -351,10 +332,6 @@ def forward(
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-
-        # ic(self.layer_idx)
-        # show_tensor(attn_output, False, True)
-
         return attn_output, attn_weights
 
 
@@ -592,7 +569,7 @@ def _dynamic_frequency_update(self, position_ids, device):
         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
         """
         seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth_dynamic_frequency_update
+        if seq_len > self.max_seq_len_cached:  # growth
             inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
             self.max_seq_len_cached = seq_len
@@ -628,7 +605,7 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-MINI_MAX_TEXT01_START_DOCSTRING = r"""
+MINIMAX_TEXT_01_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
@@ -647,7 +624,7 @@ def forward(self, x, position_ids):
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01PreTrainedModel(PreTrainedModel):
     config_class = MiniMaxText01Config
@@ -674,7 +651,7 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()
 
 
-MINI_MAX_TEXT01_INPUTS_DOCSTRING = r"""
+MINIMAX_TEXT_01_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -751,7 +728,7 @@ def _init_weights(self, module):
 
 @add_start_docstrings(
     "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
     """
@@ -783,7 +760,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -820,7 +797,6 @@ def forward(
             )
             use_cache = False
 
-        # TODO: raise exception here?
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache()
 
@@ -1173,7 +1149,7 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
@@ -1222,7 +1198,6 @@ def forward(
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        # ic(input_ids.shape, input_ids)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1299,7 +1274,7 @@ def forward(
     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
     each row of the batch).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1317,7 +1292,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1395,7 +1370,7 @@ def forward(
     The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
     output) e.g. for Named-Entity-Recognition (NER) tasks.
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel):
     def __init__(self, config):
@@ -1420,7 +1395,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
@@ -1483,7 +1458,7 @@ def forward(
     The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like
     SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
     """,
-    MINI_MAX_TEXT01_START_DOCSTRING,
+    MINIMAX_TEXT_01_START_DOCSTRING,
 )
 class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel):
     base_model_prefix = "model"
@@ -1502,7 +1477,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
````
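The hunks in this file only rename the module-internal docstring constants (MINI_MAX_TEXT01_* → MINIMAX_TEXT_01_*) and strip leftover debug statements and a stale TODO, so the public classes are untouched. As a rough smoke test of the new lazy exports (a sketch, not part of the commit: it assumes this branch of transformers is installed, that torch is available, and that the submodules list these names in their `__all__`):

```python
# Hypothetical check that the exported names still resolve after the switch to
# define_import_structure. Attribute access triggers the actual lazy import.
import importlib

minimax = importlib.import_module("transformers.models.minimax_text_01")

config_cls = minimax.MiniMaxText01Config        # resolved from configuration_minimax_text_01
model_cls = minimax.MiniMaxText01ForCausalLM    # resolved from modeling_minimax_text_01
print(config_cls.__name__, model_cls.__name__)
```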
