1 file changed: +10 -3 lines changed

src/transformers/models/vision_encoder_decoder
@@ -22,6 +22,5 @@
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
@@ -582,6 +581,9 @@ def forward(
582581 ```"""
583582 return_dict = return_dict if return_dict is not None else self .config .use_return_dict
584583
584+ # num_items_in_batch is only needed for loss computation
585+ num_items_in_batch = kwargs .pop ("num_items_in_batch" , None )
586+
585587 kwargs_encoder = {argument : value for argument , value in kwargs .items () if not argument .startswith ("decoder_" )}
586588
587589 kwargs_decoder = {
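
Note on the placement of the `pop` (my reading of the surrounding code, only partially visible in the context lines): `kwargs_encoder` and `kwargs_decoder` are built from whatever remains in `kwargs` and are forwarded to the two sub-models, so `num_items_in_batch` has to be removed before that split or it would leak into a sub-model call that does not expect it. A minimal sketch of the intended effect, with made-up values:

```python
# Hypothetical kwargs passed to forward(); the values are made up for illustration.
kwargs = {"num_items_in_batch": 17, "output_attentions": True, "decoder_use_cache": False}

# Pop it first: it is only consumed by the loss computation later in forward().
num_items_in_batch = kwargs.pop("num_items_in_batch", None)  # 17

kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

print(kwargs_encoder)  # {'output_attentions': True} -- no stray num_items_in_batch
print(kwargs_decoder)  # {'use_cache': False}
```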
@@ -638,8 +640,13 @@ def forward(
         loss = None
         if labels is not None:
             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.decoder.config.vocab_size,
+                num_items_in_batch=num_items_in_batch,
+            )
 
         if not return_dict:
             if loss is not None: