35
35
BaseModelOutputWithPastAndCrossAttentions ,
36
36
Seq2SeqLMOutput ,
37
37
Seq2SeqModelOutput ,
38
- Seq2SeqSequenceClassifierOutput ,
38
+ SequenceClassifierOutput ,
39
39
TokenClassifierOutput ,
40
40
)
41
41
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS , dynamic_rope_update
@@ -959,20 +959,19 @@ def forward(
959
959
class T5GemmaModel (T5GemmaPreTrainedModel ):
960
960
def __init__ (self , config : T5GemmaConfig ):
961
961
super ().__init__ (config )
962
- self .encoder = T5GemmaEncoder (config .encoder )
963
962
964
- # In encoder-only mode, only encoder is adopted.
965
- if self .config .is_encoder_decoder :
966
- self .decoder = T5GemmaDecoder (config .decoder )
963
+ if not config .is_encoder_decoder :
964
+ raise ValueError ("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead." )
965
+
966
+ self .encoder = T5GemmaEncoder (config .encoder )
967
+ self .decoder = T5GemmaDecoder (config .decoder )
967
968
968
969
self .post_init ()
969
970
970
971
def get_encoder (self ):
971
972
return self .encoder
972
973
973
974
def get_decoder (self ):
974
- if not self .config .is_encoder_decoder :
975
- return None
976
975
return self .decoder
977
976
978
977
def get_input_embeddings (self ):
@@ -1027,50 +1026,42 @@ def forward(
1027
1026
1028
1027
encoder_hidden_states = encoder_outputs .last_hidden_state
1029
1028
1030
- if self .config .is_encoder_decoder :
1031
- # Decode
1032
- decoder_outputs = self .decoder (
1033
- input_ids = decoder_input_ids ,
1034
- attention_mask = decoder_attention_mask ,
1035
- position_ids = decoder_position_ids ,
1036
- inputs_embeds = decoder_inputs_embeds ,
1037
- past_key_values = past_key_values ,
1038
- encoder_hidden_states = encoder_hidden_states ,
1039
- encoder_attention_mask = attention_mask ,
1040
- use_cache = use_cache ,
1041
- output_attentions = output_attentions ,
1042
- output_hidden_states = output_hidden_states ,
1043
- cache_position = cache_position ,
1044
- ** flash_attn_kwargs ,
1045
- )
1029
+ # Decode
1030
+ decoder_outputs = self .decoder (
1031
+ input_ids = decoder_input_ids ,
1032
+ attention_mask = decoder_attention_mask ,
1033
+ position_ids = decoder_position_ids ,
1034
+ inputs_embeds = decoder_inputs_embeds ,
1035
+ past_key_values = past_key_values ,
1036
+ encoder_hidden_states = encoder_hidden_states ,
1037
+ encoder_attention_mask = attention_mask ,
1038
+ use_cache = use_cache ,
1039
+ output_attentions = output_attentions ,
1040
+ output_hidden_states = output_hidden_states ,
1041
+ cache_position = cache_position ,
1042
+ ** flash_attn_kwargs ,
1043
+ )
1046
1044
1047
- return Seq2SeqModelOutput (
1048
- last_hidden_state = decoder_outputs .last_hidden_state ,
1049
- past_key_values = decoder_outputs .past_key_values ,
1050
- decoder_hidden_states = decoder_outputs .hidden_states ,
1051
- decoder_attentions = decoder_outputs .attentions ,
1052
- cross_attentions = decoder_outputs .cross_attentions ,
1053
- encoder_last_hidden_state = encoder_outputs .last_hidden_state ,
1054
- encoder_hidden_states = encoder_outputs .hidden_states ,
1055
- encoder_attentions = encoder_outputs .attentions ,
1056
- )
1057
- else :
1058
- return Seq2SeqModelOutput (
1059
- last_hidden_state = None ,
1060
- past_key_values = None ,
1061
- decoder_hidden_states = None ,
1062
- decoder_attentions = None ,
1063
- cross_attentions = None ,
1064
- encoder_last_hidden_state = encoder_outputs .last_hidden_state ,
1065
- encoder_hidden_states = encoder_outputs .hidden_states ,
1066
- encoder_attentions = encoder_outputs .attentions ,
1067
- )
1045
+ return Seq2SeqModelOutput (
1046
+ last_hidden_state = decoder_outputs .last_hidden_state ,
1047
+ past_key_values = decoder_outputs .past_key_values ,
1048
+ decoder_hidden_states = decoder_outputs .hidden_states ,
1049
+ decoder_attentions = decoder_outputs .attentions ,
1050
+ cross_attentions = decoder_outputs .cross_attentions ,
1051
+ encoder_last_hidden_state = encoder_outputs .last_hidden_state ,
1052
+ encoder_hidden_states = encoder_outputs .hidden_states ,
1053
+ encoder_attentions = encoder_outputs .attentions ,
1054
+ )
1068
1055
1069
1056
1070
1057
@auto_docstring
1071
1058
class T5GemmaEncoderModel (T5GemmaPreTrainedModel ):
1072
1059
def __init__ (self , config : T5GemmaConfig ):
1073
1060
super ().__init__ (config )
1061
+
1062
+ if config .is_encoder_decoder :
1063
+ raise ValueError ("T5GemmaEncoderModel only supports encoder-only model. Use `T5GemmaModel` instead." )
1064
+
1074
1065
self .encoder = T5GemmaEncoder (config .encoder )
1075
1066
self .post_init ()
1076
1067
@@ -1175,13 +1166,6 @@ def forward(
1175
1166
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1176
1167
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1177
1168
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1178
-
1179
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
1180
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
1181
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1182
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1183
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
1184
- This is useful when using packed tensor format (single dimension for batch and sequence length).
1185
1169
"""
1186
1170
if self .training and self .config ._attn_implementation != "eager" :
1187
1171
logger .warning_once (
@@ -1253,10 +1237,14 @@ def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = N
1253
1237
config .is_encoder_decoder = is_encoder_decoder
1254
1238
super ().__init__ (config )
1255
1239
self .num_labels = config .num_labels
1256
- self .model = T5GemmaModel (config )
1240
+
1241
+ if config .is_encoder_decoder :
1242
+ self .model = T5GemmaModel (config )
1243
+ else :
1244
+ self .model = T5GemmaEncoderModel (config )
1257
1245
1258
1246
hidden_size = config .encoder .hidden_size
1259
- if is_encoder_decoder :
1247
+ if config . is_encoder_decoder :
1260
1248
hidden_size = config .decoder .hidden_size
1261
1249
1262
1250
classifier_dropout = getattr (config , "classifier_dropout_rate" , 0.1 )
@@ -1288,7 +1276,7 @@ def forward(
1288
1276
labels : Optional [torch .LongTensor ] = None ,
1289
1277
output_attentions : Optional [bool ] = None ,
1290
1278
output_hidden_states : Optional [bool ] = None ,
1291
- ) -> Seq2SeqSequenceClassifierOutput :
1279
+ ) -> SequenceClassifierOutput :
1292
1280
r"""
1293
1281
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
1294
1282
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
@@ -1314,27 +1302,38 @@ def forward(
1314
1302
)
1315
1303
decoder_input_ids = self ._shift_right (input_ids )
1316
1304
1317
- outputs : Seq2SeqModelOutput = self .model (
1318
- input_ids ,
1319
- attention_mask = attention_mask ,
1320
- position_ids = position_ids ,
1321
- decoder_input_ids = decoder_input_ids ,
1322
- decoder_attention_mask = decoder_attention_mask ,
1323
- decoder_position_ids = decoder_position_ids ,
1324
- encoder_outputs = encoder_outputs ,
1325
- inputs_embeds = inputs_embeds ,
1326
- decoder_inputs_embeds = decoder_inputs_embeds ,
1327
- use_cache = False ,
1328
- output_attentions = output_attentions ,
1329
- output_hidden_states = output_hidden_states ,
1330
- )
1331
1305
if self .config .is_encoder_decoder :
1332
- hidden_states = outputs .last_hidden_state
1333
- if hidden_states is None :
1334
- raise ValueError ("Hidden states shouldn't be None under encoder-decoder mode." )
1306
+ outputs : Seq2SeqModelOutput = self .model (
1307
+ input_ids ,
1308
+ attention_mask = attention_mask ,
1309
+ position_ids = position_ids ,
1310
+ decoder_input_ids = decoder_input_ids ,
1311
+ decoder_attention_mask = decoder_attention_mask ,
1312
+ decoder_position_ids = decoder_position_ids ,
1313
+ encoder_outputs = encoder_outputs ,
1314
+ inputs_embeds = inputs_embeds ,
1315
+ decoder_inputs_embeds = decoder_inputs_embeds ,
1316
+ use_cache = False ,
1317
+ output_attentions = output_attentions ,
1318
+ output_hidden_states = output_hidden_states ,
1319
+ )
1320
+ last_hidden_state = outputs .last_hidden_state
1321
+ hidden_states = outputs .decoder_hidden_states
1322
+ attentions = outputs .decoder_attentions
1335
1323
else :
1336
- hidden_states = outputs .encoder_last_hidden_state
1337
- logits = self .score (hidden_states )
1324
+ outputs : BaseModelOutput = self .model (
1325
+ input_ids ,
1326
+ attention_mask = attention_mask ,
1327
+ position_ids = position_ids ,
1328
+ inputs_embeds = inputs_embeds ,
1329
+ output_attentions = output_attentions ,
1330
+ output_hidden_states = output_hidden_states ,
1331
+ )
1332
+ last_hidden_state = outputs .last_hidden_state
1333
+ hidden_states = outputs .hidden_states
1334
+ attentions = outputs .attentions
1335
+
1336
+ logits = self .score (last_hidden_state )
1338
1337
1339
1338
if input_ids is not None :
1340
1339
batch_size = input_ids .shape [0 ]
@@ -1367,16 +1366,11 @@ def forward(
1367
1366
if labels is not None :
1368
1367
loss = self .loss_function (logits = logits , labels = labels , pooled_logits = pooled_logits , config = self .config )
1369
1368
1370
- return Seq2SeqSequenceClassifierOutput (
1369
+ return SequenceClassifierOutput (
1371
1370
loss = loss ,
1372
1371
logits = pooled_logits ,
1373
- past_key_values = outputs .past_key_values ,
1374
- decoder_hidden_states = outputs .decoder_hidden_states ,
1375
- decoder_attentions = outputs .decoder_attentions ,
1376
- cross_attentions = outputs .cross_attentions ,
1377
- encoder_last_hidden_state = outputs .encoder_last_hidden_state ,
1378
- encoder_hidden_states = outputs .encoder_hidden_states ,
1379
- encoder_attentions = outputs .encoder_attentions ,
1372
+ hidden_states = hidden_states ,
1373
+ attentions = attentions ,
1380
1374
)
1381
1375
1382
1376
@@ -1391,10 +1385,14 @@ def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = N
1391
1385
config .is_encoder_decoder = is_encoder_decoder
1392
1386
super ().__init__ (config )
1393
1387
self .num_labels = config .num_labels
1394
- self .model = T5GemmaModel (config )
1388
+
1389
+ if config .is_encoder_decoder :
1390
+ self .model = T5GemmaModel (config )
1391
+ else :
1392
+ self .model = T5GemmaEncoderModel (config )
1395
1393
1396
1394
hidden_size = config .encoder .hidden_size
1397
- if is_encoder_decoder :
1395
+ if config . is_encoder_decoder :
1398
1396
hidden_size = config .decoder .hidden_size
1399
1397
1400
1398
classifier_dropout = getattr (config , "classifier_dropout_rate" , 0.1 )
@@ -1453,29 +1451,38 @@ def forward(
1453
1451
)
1454
1452
decoder_input_ids = self ._shift_right (input_ids )
1455
1453
1456
- outputs : Seq2SeqModelOutput = self .model (
1457
- input_ids ,
1458
- attention_mask = attention_mask ,
1459
- position_ids = position_ids ,
1460
- decoder_input_ids = decoder_input_ids ,
1461
- decoder_attention_mask = decoder_attention_mask ,
1462
- decoder_position_ids = decoder_position_ids ,
1463
- encoder_outputs = encoder_outputs ,
1464
- inputs_embeds = inputs_embeds ,
1465
- decoder_inputs_embeds = decoder_inputs_embeds ,
1466
- use_cache = False ,
1467
- output_attentions = output_attentions ,
1468
- output_hidden_states = output_hidden_states ,
1469
- )
1470
1454
if self .config .is_encoder_decoder :
1471
- hidden_states = outputs .last_hidden_state
1472
- if hidden_states is None :
1473
- raise ValueError ("Hidden states shouldn't be None under encoder-decoder mode." )
1455
+ outputs : Seq2SeqModelOutput = self .model (
1456
+ input_ids ,
1457
+ attention_mask = attention_mask ,
1458
+ position_ids = position_ids ,
1459
+ decoder_input_ids = decoder_input_ids ,
1460
+ decoder_attention_mask = decoder_attention_mask ,
1461
+ decoder_position_ids = decoder_position_ids ,
1462
+ encoder_outputs = encoder_outputs ,
1463
+ inputs_embeds = inputs_embeds ,
1464
+ decoder_inputs_embeds = decoder_inputs_embeds ,
1465
+ use_cache = False ,
1466
+ output_attentions = output_attentions ,
1467
+ output_hidden_states = output_hidden_states ,
1468
+ )
1469
+ last_hidden_state = outputs .last_hidden_state
1470
+ hidden_states = outputs .decoder_hidden_states
1471
+ attentions = outputs .decoder_attentions
1474
1472
else :
1475
- hidden_states = outputs .encoder_last_hidden_state
1473
+ outputs : BaseModelOutput = self .model (
1474
+ input_ids ,
1475
+ attention_mask = attention_mask ,
1476
+ position_ids = position_ids ,
1477
+ inputs_embeds = inputs_embeds ,
1478
+ output_attentions = output_attentions ,
1479
+ output_hidden_states = output_hidden_states ,
1480
+ )
1481
+ last_hidden_state = outputs .last_hidden_state
1482
+ hidden_states = outputs .hidden_states
1483
+ attentions = outputs .attentions
1476
1484
1477
- sequence_output = hidden_states
1478
- logits = self .score (sequence_output )
1485
+ logits = self .score (last_hidden_state )
1479
1486
1480
1487
loss = None
1481
1488
if labels is not None :
@@ -1484,10 +1491,8 @@ def forward(
1484
1491
return TokenClassifierOutput (
1485
1492
loss = loss ,
1486
1493
logits = logits ,
1487
- hidden_states = outputs .decoder_hidden_states
1488
- if self .config .is_encoder_decoder
1489
- else outputs .encoder_hidden_states ,
1490
- attentions = outputs .decoder_attentions if self .config .is_encoder_decoder else outputs .encoder_attentions ,
1494
+ hidden_states = hidden_states ,
1495
+ attentions = attentions ,
1491
1496
)
1492
1497
1493
1498
0 commit comments