
Commit e7337ee

Pass num_items_in_batch directly to loss computation (#36753)
* Pass num_items_in_batch directly to loss computation
* use self loss instead
* fix loss kwargs
* fix vocab size
1 parent: 8b479e3
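
Background (not part of the commit page itself): `num_items_in_batch` matters when gradients are accumulated over several micro-batches. Taking the mean cross-entropy inside each micro-batch and then averaging those means over-weights micro-batches with few label tokens; summing the per-token losses and dividing once by the total number of label tokens in the accumulated batch reproduces the loss of a single large batch. A minimal sketch of that idea, with hypothetical names (`model`, `micro_batches`):

import torch.nn.functional as F

def accumulated_loss(model, micro_batches, num_items_in_batch):
    # Hypothetical helper: `model` returns an object with a .logits tensor,
    # `micro_batches` yields (inputs, labels) pairs, and `num_items_in_batch`
    # is the number of label tokens in the whole accumulated batch.
    total = 0.0
    for inputs, labels in micro_batches:
        logits = model(**inputs).logits
        # reduction="sum": a per-micro-batch mean would over-weight
        # micro-batches that contain few label tokens
        total = total + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
    return total / num_items_in_batch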

File tree: 1 file changed (+10, -3 lines)

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py

Lines changed: 10 additions & 3 deletions
@@ -21,7 +21,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
@@ -582,6 +581,9 @@ def forward(
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # num_items_in_batch is only needed for loss computation
+        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
+
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
 
         kwargs_decoder = {
@@ -638,8 +640,13 @@ def forward(
         loss = None
         if labels is not None:
             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.decoder.config.vocab_size,
+                num_items_in_batch=num_items_in_batch,
+            )
 
         if not return_dict:
             if loss is not None:
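
The `self.loss_function` call above is the model-level loss hook that replaces the hard-coded `CrossEntropyLoss`. As a rough sketch of what a hook with this signature is expected to do (an assumption for illustration, not the library's actual implementation): sum the token-level cross-entropy and divide by `num_items_in_batch` when the count is supplied, otherwise fall back to a plain mean.

import torch.nn.functional as F

def sketch_loss_function(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # Illustrative stand-in for the loss hook; not the transformers implementation.
    logits = logits.float().reshape(-1, vocab_size)
    labels = labels.reshape(-1).to(logits.device)
    if num_items_in_batch is None:
        # No global token count supplied: behave like the old mean-reduced loss.
        return F.cross_entropy(logits, labels, ignore_index=ignore_index)
    # Sum per-token losses and normalize by the caller-supplied count so that
    # gradient accumulation matches training on one large batch.
    summed = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch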

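A hedged usage sketch of the new keyword path (checkpoint name and tensor shapes are only illustrative): `num_items_in_batch` can be passed directly to `forward`, and because it is popped before `kwargs_encoder`/`kwargs_decoder` are built, it is never forwarded to the underlying encoder or decoder.

import torch
from transformers import VisionEncoderDecoderModel

# Any ViT + GPT-2 style checkpoint works here; this one is just an example.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

pixel_values = torch.randn(2, 3, 224, 224)  # dummy image batch
labels = torch.randint(0, model.decoder.config.vocab_size, (2, 16))

# The Trainer injects this kwarg during gradient accumulation; here it is
# passed by hand. forward() pops it before splitting kwargs, so it never
# reaches the encoder or decoder sub-models.
outputs = model(
    pixel_values=pixel_values,
    labels=labels,
    num_items_in_batch=(labels != -100).sum(),
)
print(outputs.loss)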