Commit 0d69980

Updates:
* split encoder-only model out
* make T5GemmaModel encoder-decoder only
* update token and sequence classification
* update tests

1 parent 3780d72 commit 0d69980
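
For orientation, a minimal sketch of the API split this commit enforces. Only the class names and the new ValueError guards come from the diff below; the tiny config values, the `encoder=`/`decoder=` constructor kwargs, and the `T5GemmaModuleConfig` import path are illustrative assumptions.

from transformers import T5GemmaConfig, T5GemmaEncoderModel, T5GemmaModel
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaModuleConfig

def tiny_module():
    # Arbitrary small dimensions so the sketch is cheap to instantiate.
    return T5GemmaModuleConfig(
        vocab_size=256, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=2,
        num_key_value_heads=1, head_dim=16,
    )

seq2seq_cfg = T5GemmaConfig(encoder=tiny_module(), decoder=tiny_module(), is_encoder_decoder=True)
encoder_cfg = T5GemmaConfig(encoder=tiny_module(), decoder=tiny_module(), is_encoder_decoder=False)

model = T5GemmaModel(seq2seq_cfg)           # full encoder-decoder stack
encoder = T5GemmaEncoderModel(encoder_cfg)  # encoder stack only

# After this commit, each class rejects the other's configuration:
try:
    T5GemmaModel(encoder_cfg)
except ValueError as err:
    print(err)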

File tree

3 files changed: +227 −217 lines

src/transformers/models/t5gemma/modeling_t5gemma.py

Lines changed: 112 additions & 107 deletions
@@ -35,7 +35,7 @@
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
-    Seq2SeqSequenceClassifierOutput,
+    SequenceClassifierOutput,
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
@@ -959,20 +959,19 @@
 class T5GemmaModel(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig):
         super().__init__(config)
-        self.encoder = T5GemmaEncoder(config.encoder)

-        # In encoder-only mode, only encoder is adopted.
-        if self.config.is_encoder_decoder:
-            self.decoder = T5GemmaDecoder(config.decoder)
+        if not config.is_encoder_decoder:
+            raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")
+
+        self.encoder = T5GemmaEncoder(config.encoder)
+        self.decoder = T5GemmaDecoder(config.decoder)

         self.post_init()

     def get_encoder(self):
         return self.encoder

     def get_decoder(self):
-        if not self.config.is_encoder_decoder:
-            return None
         return self.decoder

     def get_input_embeddings(self):
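
With the constructor guard in place, `get_decoder()` can return unconditionally. A hypothetical downstream helper (not part of the diff) no longer needs a None check:

def count_decoder_params(m: T5GemmaModel) -> int:
    # get_decoder() is now guaranteed to return the decoder module,
    # since encoder-only configs are rejected in __init__.
    return sum(p.numel() for p in m.get_decoder().parameters())

print(count_decoder_params(model))  # reuses `model` from the first sketch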
@@ -1027,50 +1026,42 @@ def forward(

         encoder_hidden_states = encoder_outputs.last_hidden_state

-        if self.config.is_encoder_decoder:
-            # Decode
-            decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
-                attention_mask=decoder_attention_mask,
-                position_ids=decoder_position_ids,
-                inputs_embeds=decoder_inputs_embeds,
-                past_key_values=past_key_values,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=attention_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                cache_position=cache_position,
-                **flash_attn_kwargs,
-            )
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **flash_attn_kwargs,
+        )

-            return Seq2SeqModelOutput(
-                last_hidden_state=decoder_outputs.last_hidden_state,
-                past_key_values=decoder_outputs.past_key_values,
-                decoder_hidden_states=decoder_outputs.hidden_states,
-                decoder_attentions=decoder_outputs.attentions,
-                cross_attentions=decoder_outputs.cross_attentions,
-                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-                encoder_hidden_states=encoder_outputs.hidden_states,
-                encoder_attentions=encoder_outputs.attentions,
-            )
-        else:
-            return Seq2SeqModelOutput(
-                last_hidden_state=None,
-                past_key_values=None,
-                decoder_hidden_states=None,
-                decoder_attentions=None,
-                cross_attentions=None,
-                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-                encoder_hidden_states=encoder_outputs.hidden_states,
-                encoder_attentions=encoder_outputs.attentions,
-            )
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )


 @auto_docstring
 class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig):
         super().__init__(config)
+
+        if config.is_encoder_decoder:
+            raise ValueError("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead.")
+
         self.encoder = T5GemmaEncoder(config.encoder)
         self.post_init()
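
Continuing the sketch: `T5GemmaModel.forward` now always runs the decoder and returns a fully populated `Seq2SeqModelOutput`, while encoder-only inference goes through `T5GemmaEncoderModel`, which returns a `BaseModelOutput`-style result. Shapes assume the tiny config above.

import torch

ids = torch.randint(0, 256, (1, 8))  # (batch, seq_len) token ids

out = model(input_ids=ids, decoder_input_ids=ids)
print(out.last_hidden_state.shape)          # decoder states, e.g. torch.Size([1, 8, 32])
print(out.encoder_last_hidden_state.shape)  # encoder states, torch.Size([1, 8, 32])

enc_out = encoder(input_ids=ids)
print(enc_out.last_hidden_state.shape)      # encoder states, torch.Size([1, 8, 32])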

@@ -1175,13 +1166,6 @@ def forward(
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
         """
         if self.training and self.config._attn_implementation != "eager":
             logger.warning_once(
@@ -1253,10 +1237,14 @@ def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
         config.is_encoder_decoder = is_encoder_decoder
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = T5GemmaModel(config)
+
+        if config.is_encoder_decoder:
+            self.model = T5GemmaModel(config)
+        else:
+            self.model = T5GemmaEncoderModel(config)

         hidden_size = config.encoder.hidden_size
-        if is_encoder_decoder:
+        if config.is_encoder_decoder:
             hidden_size = config.decoder.hidden_size

         classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
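
Both classification wrappers now choose their backbone from `config.is_encoder_decoder` and widen the head to the decoder's hidden size only in encoder-decoder mode. Continuing the sketch (passing `is_encoder_decoder` explicitly, and treating `num_labels` as a standard `PretrainedConfig` field):

from transformers import T5GemmaForSequenceClassification

seq2seq_cfg.num_labels = 3
encoder_cfg.num_labels = 3

clf_ed = T5GemmaForSequenceClassification(seq2seq_cfg, is_encoder_decoder=True)
clf_enc = T5GemmaForSequenceClassification(encoder_cfg, is_encoder_decoder=False)
print(type(clf_ed.model).__name__)   # T5GemmaModel
print(type(clf_enc.model).__name__)  # T5GemmaEncoderModel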
@@ -1288,7 +1276,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-    ) -> Seq2SeqSequenceClassifierOutput:
+    ) -> SequenceClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
@@ -1314,27 +1302,38 @@ def forward(
             )
             decoder_input_ids = self._shift_right(input_ids)

-        outputs: Seq2SeqModelOutput = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            decoder_position_ids=decoder_position_ids,
-            encoder_outputs=encoder_outputs,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=False,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-        )
         if self.config.is_encoder_decoder:
-            hidden_states = outputs.last_hidden_state
-            if hidden_states is None:
-                raise ValueError("Hidden states shouldn't be None under encoder-decoder mode.")
+            outputs: Seq2SeqModelOutput = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                decoder_position_ids=decoder_position_ids,
+                encoder_outputs=encoder_outputs,
+                inputs_embeds=inputs_embeds,
+                decoder_inputs_embeds=decoder_inputs_embeds,
+                use_cache=False,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            last_hidden_state = outputs.last_hidden_state
+            hidden_states = outputs.decoder_hidden_states
+            attentions = outputs.decoder_attentions
         else:
-            hidden_states = outputs.encoder_last_hidden_state
-        logits = self.score(hidden_states)
+            outputs: BaseModelOutput = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            last_hidden_state = outputs.last_hidden_state
+            hidden_states = outputs.hidden_states
+            attentions = outputs.attentions
+
+        logits = self.score(last_hidden_state)

         if input_ids is not None:
             batch_size = input_ids.shape[0]
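
The rewritten forward keeps a single scoring path: each branch normalizes its backbone output into `last_hidden_state`, `hidden_states`, and `attentions` before the head runs. Continuing the sketch, requesting hidden states shows which stack populates them (decoder layers in encoder-decoder mode, encoder layers otherwise):

out_ed = clf_ed(input_ids=ids, output_hidden_states=True)
out_enc = clf_enc(input_ids=ids, output_hidden_states=True)

# Tuples of per-layer states: decoder layers for clf_ed, encoder layers for clf_enc.
print(len(out_ed.hidden_states), len(out_enc.hidden_states))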
@@ -1367,16 +1366,11 @@ def forward(
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

-        return Seq2SeqSequenceClassifierOutput(
+        return SequenceClassifierOutput(
             loss=loss,
             logits=pooled_logits,
-            past_key_values=outputs.past_key_values,
-            decoder_hidden_states=outputs.decoder_hidden_states,
-            decoder_attentions=outputs.decoder_attentions,
-            cross_attentions=outputs.cross_attentions,
-            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
-            encoder_hidden_states=outputs.encoder_hidden_states,
-            encoder_attentions=outputs.encoder_attentions,
+            hidden_states=hidden_states,
+            attentions=attentions,
         )
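
Since both modes now return a plain `SequenceClassifierOutput`, callers that read seq2seq-specific fields (`encoder_last_hidden_state`, `cross_attentions`, ...) need a small migration to the generic `hidden_states`/`attentions` fields. Continuing the sketch:

out = clf_ed(input_ids=ids, output_hidden_states=True)
print(type(out).__name__)  # SequenceClassifierOutput (was Seq2SeqSequenceClassifierOutput)
print(out.logits.shape)    # torch.Size([1, 3]): one row of num_labels scores per sequence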

@@ -1391,10 +1385,14 @@ def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
         config.is_encoder_decoder = is_encoder_decoder
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = T5GemmaModel(config)
+
+        if config.is_encoder_decoder:
+            self.model = T5GemmaModel(config)
+        else:
+            self.model = T5GemmaEncoderModel(config)

         hidden_size = config.encoder.hidden_size
-        if is_encoder_decoder:
+        if config.is_encoder_decoder:
             hidden_size = config.decoder.hidden_size

         classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
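
Token classification gets the same backbone selection; the encoder-only path is the natural fit for tagging tasks. Continuing the sketch:

from transformers import T5GemmaForTokenClassification

tok_clf = T5GemmaForTokenClassification(encoder_cfg, is_encoder_decoder=False)
print(type(tok_clf.model).__name__)  # T5GemmaEncoderModel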
@@ -1453,29 +1451,38 @@ def forward(
             )
             decoder_input_ids = self._shift_right(input_ids)

-        outputs: Seq2SeqModelOutput = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            decoder_position_ids=decoder_position_ids,
-            encoder_outputs=encoder_outputs,
-            inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
-            use_cache=False,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-        )
         if self.config.is_encoder_decoder:
-            hidden_states = outputs.last_hidden_state
-            if hidden_states is None:
-                raise ValueError("Hidden states shouldn't be None under encoder-decoder mode.")
+            outputs: Seq2SeqModelOutput = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                decoder_position_ids=decoder_position_ids,
+                encoder_outputs=encoder_outputs,
+                inputs_embeds=inputs_embeds,
+                decoder_inputs_embeds=decoder_inputs_embeds,
+                use_cache=False,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            last_hidden_state = outputs.last_hidden_state
+            hidden_states = outputs.decoder_hidden_states
+            attentions = outputs.decoder_attentions
         else:
-            hidden_states = outputs.encoder_last_hidden_state
+            outputs: BaseModelOutput = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            last_hidden_state = outputs.last_hidden_state
+            hidden_states = outputs.hidden_states
+            attentions = outputs.attentions

-        sequence_output = hidden_states
-        logits = self.score(sequence_output)
+        logits = self.score(last_hidden_state)

         loss = None
         if labels is not None:
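
The token-classification forward mirrors the sequence-classification branching, then scores every position. Continuing the sketch (per-token logits, with `num_labels=3` from above):

tok_out = tok_clf(input_ids=ids)
print(tok_out.logits.shape)  # torch.Size([1, 8, 3]): one score vector per token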
@@ -1484,10 +1491,8 @@
         return TokenClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=outputs.decoder_hidden_states
-            if self.config.is_encoder_decoder
-            else outputs.encoder_hidden_states,
-            attentions=outputs.decoder_attentions if self.config.is_encoder_decoder else outputs.encoder_attentions,
+            hidden_states=hidden_states,
+            attentions=attentions,
         )