
Commit 29d439a

[DPO] average_log_prob when loss is IPO (huggingface#1265)
* average_log_prob when loss is IPO
* updated docs with the fix
1 parent 5760e5d commit 29d439a

File tree

* docs/source/dpo_trainer.mdx
* trl/trainer/dpo_trainer.py

2 files changed: +2 -2 lines changed


docs/source/dpo_trainer.mdx

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ Given the preference data, we can fit a binary classifier according to the Bradl
 
 The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
 
-The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.
+The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen and rejected completion pair, so the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, where they are summed).
 
 The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via the `label_smoothing` argument (between 0 and 0.5), after which a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
 
```
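For context, enabling this loss from user code only takes the `loss_type` argument; nothing else about the training setup changes. A minimal, hypothetical sketch (the `model`, `ref_model`, `training_args`, `train_dataset`, and `tokenizer` objects are placeholders, and the `DPOTrainer` constructor has evolved across TRL releases):

```python
from trl import DPOTrainer

# Sketch only: every object below is assumed to be set up elsewhere.
trainer = DPOTrainer(
    model,                        # policy model to optimize
    ref_model,                    # frozen reference model
    args=training_args,
    beta=0.1,                     # inverse of the target log-likelihood-ratio gap
    loss_type="ipo",              # after this commit, also enables log-prob averaging
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```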

trl/trainer/dpo_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -966,7 +966,7 @@ def concatenated_forward(
         all_logps = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
-            average_log_prob=False,
+            average_log_prob=self.loss_type == "ipo",
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
```
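The toggled flag controls whether per-token log-probabilities of the labeled completion are summed (DPO) or averaged over completion tokens (IPO). A simplified, self-contained sketch of the computation `get_batch_logps` performs (the real method also handles encoder-decoder models; names and shapes here are illustrative):

```python
import torch

def batch_logps(logits: torch.Tensor, labels: torch.Tensor,
                label_pad_token_id: int = -100,
                average_log_prob: bool = False) -> torch.Tensor:
    """Return per-sequence log-probs: summed (DPO) or averaged (IPO)."""
    # Shift so the token at position t is predicted from the logits at t - 1.
    labels = labels[:, 1:].clone()              # (batch, seq_len - 1)
    logits = logits[:, :-1, :]                  # (batch, seq_len - 1, vocab)
    loss_mask = labels != label_pad_token_id

    # Replace padded positions with a dummy id so gather() stays in range.
    labels = labels.masked_fill(~loss_mask, 0)
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        # IPO: length-normalized mean over completion tokens.
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    # DPO: plain sum over completion tokens.
    return (per_token_logps * loss_mask).sum(-1)
```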

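These length-averaged log-probs then feed the IPO objective, which regresses the gap between the policy and reference log-likelihood ratios toward 1 / (2 * beta), making `beta` inversely related to the target gap, as the docs change above notes. A sketch of that loss with simplified names (per-pair losses; the trainer averages them over the batch):

```python
import torch

def ipo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Squared-error IPO loss on (length-averaged) log-prob gaps, a sketch."""
    # Gap between policy and reference log-likelihood ratios.
    h = (policy_chosen_logps - policy_rejected_logps) - (
        ref_chosen_logps - ref_rejected_logps
    )
    # IPO regresses h toward 1 / (2 * beta): smaller beta, larger target gap.
    return (h - 1 / (2 * beta)) ** 2
```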