update loss computation in modeling code #37993
Conversation
```diff
  loss = None
- if labels is not None:
+ if labels is not None or kwargs.get("shift_labels", None) is not None:
```
why not leave it to the loss function?
We can, but it means the user has to compute the loss outside the model's forward:

```python
from transformers import ForCausalLMLoss

shift_labels = ...
loss = ForCausalLMLoss(logits=logits, shift_labels=shift_labels, vocab_size=model.config.vocab_size)
```

It's not too much work for users, but it would be nice to make it easier for them (they only need to take care of preparing `shift_labels`). Enabling this API means people could perform the same workflow with `labels` and with `shift_labels` (i.e. pass them to `model.forward` and get the loss from the outputs).
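For context, the `shift_labels` preparation the user stays responsible for is the usual shift-by-one with ignore-index padding. A plain-Python sketch (in the real code this is done on a `torch.Tensor` with pad/roll; `prepare_shift_labels` is an illustrative helper name, not a transformers API):

```python
IGNORE_INDEX = -100  # positions with this label are ignored by the loss


def prepare_shift_labels(labels):
    """Shift labels left by one so the logits at position t are scored
    against the token at position t + 1; the last position has no next
    token, so it receives IGNORE_INDEX."""
    return labels[1:] + [IGNORE_INDEX]


labels = [5, 17, 3, 42]
shift_labels = prepare_shift_labels(labels)
print(shift_labels)  # [17, 3, 42, -100]
```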
Also, we can't just pass `shift_labels` without passing `labels`:

```python
loss = ForCausalLMLoss(logits=logits, shift_labels=shift_labels, vocab_size=model.config.vocab_size)
```

fails because `labels` is a required positional argument. We have to do

```python
loss = ForCausalLMLoss(logits=logits, labels=None, shift_labels=shift_labels, vocab_size=model.config.vocab_size)
```

This kind of detail could be hidden from the user if we do it in the modeling code. Having to pass `labels` (no matter what value it contains) when we mean to use `shift_labels` is kind of confusing. It could work with `model.forward` if we pass

```python
model(input_ids=x, labels=labels, shift_labels=shift_labels)  # labels = anything except None
```

but that is also confusing.
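The dispatch being proposed can be sketched with a small torch-free mock of a model forward (the function below is illustrative, not transformers code; it only records which path the condition takes):

```python
def forward_loss(logits, labels=None, **kwargs):
    """Sketch of the PR's condition: compute a loss when either `labels`
    or `shift_labels` is provided (previously only `labels` triggered it)."""
    loss = None
    if labels is not None or kwargs.get("shift_labels", None) is not None:
        # The real loss_function receives both arguments and uses
        # shift_labels when given; here we just report the path taken.
        if kwargs.get("shift_labels", None) is not None:
            loss = "from_shift_labels"
        else:
            loss = "from_labels"
    return loss


print(forward_loss([0.1], labels=[1]))        # from_labels
print(forward_loss([0.1], shift_labels=[1]))  # from_shift_labels
print(forward_loss([0.1]))                    # None
```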
yeah, someone reported using a side-effect hack: passing `labels=non_None_garbage, shift_labels=real_data` and getting `model()` to compute the loss ;)
so you know people want this feature ;)
honestly the only reason I didn't propose it is because I didn't want to do it for 200+ files ;) So I'm grateful to @ydshieh for taking the lead on this.
what I meant @ydshieh is that if we do this we need to update all models!
src/transformers/utils/generic.py (outdated)

```diff
      """
      num_items_in_batch: Optional[int]
+     shift_labels: Optional[torch.Tensor]
```
This is somewhat not great, as `shift_labels` is specific to `ForCausalLMLoss`. I'm not sure if we want to expose it here or within

```python
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
```

which appears in many modeling files.
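For readers unfamiliar with these classes: they are `TypedDict`-based typing containers, so the question is which dict class should declare the `shift_labels` key. A minimal torch-free sketch of the two placements (the `cu_seq_lens_q` field and the list-typed `shift_labels` are stand-ins for the real torch-typed fields):

```python
from typing import Optional, TypedDict


class FlashAttentionKwargs(TypedDict, total=False):
    cu_seq_lens_q: Optional[list]  # stand-in for the real FA kwargs


class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]
    # Option A: declare shift_labels here, so every model accepting
    # LossKwargs advertises it (even non-causal-LM ones).
    shift_labels: Optional[list]  # torch.Tensor in the real code


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    # Option B: declare shift_labels only on this combined class, so
    # only causal-LM models advertise it.
    ...


kwargs: KwargsForCausalLM = {"num_items_in_batch": 4, "shift_labels": [1, 2]}
print(sorted(kwargs))  # ['num_items_in_batch', 'shift_labels']
```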
FWIW, Liger-Kernel has just implemented this feature: linkedin/Liger-Kernel#683

@bot /style

Style fixes have been applied. View the workflow run here.
ArthurZucker left a comment

Happy to have this merged, it just needs:
- either updating all models with `make fix-copies`
- or putting this in the `loss_function` instead (not sure if possible)
```diff
  loss = None
- if labels is not None:
+ if labels is not None or kwargs.get("shift_labels", None) is not None:
```
> what I meant @ydshieh is that if we do this we need to update all models!

This is unfortunately unavoidable: the current modeling code only computes the loss when `labels` is provided, so if we want to use `shift_labels` I will go update all the models 🤞
Ready for a review 🙏 Mostly, changing `if labels is not None:` to `if labels is not None or kwargs.get("shift_labels", None) is not None:` for all the models.
I would love to change …
It looks like I introduced it here 2 months ago, replicating the long-time pre-existing internal variable name. I agree that …
Yeah, I understand. No big deal; I guess people focused on making AI go brrrr would never care if there is …
ArthurZucker left a comment

not a big fan of this, IDK if we have a better solution. WDYT about a flag saying the labels are already shifted? (it's just that kwargs are not supposed to be used here, but I mean... yeah)
IIRC, @ArthurZucker you mean we allow …, right? Technically it's doable, but the new flag (say, …) … One possibility is we allow …: it could work and it could avoid the changes in the many modeling files. But you see the downside (the same argument having 2 possible types and behaving differently). @stas00 any comment here?
This is probably a better and cleaner solution: https://github.com/huggingface/transformers/pull/38533/files. Let me know your opinions.
@ydshieh, the last one is a smooth solution!
What does this PR do?

cc @stas00

Follow-up of #36607, see this comment.

The following code snippet gives … on `main` and … with this PR.

This is particularly necessary for context parallelism to run correctly.

Once the changes are approved, I will update all the other places and add documentation.
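To illustrate why user-supplied `shift_labels` matter for context parallelism, here is a hedged torch-free sketch (the `shift`/`chunk` helpers are illustrative, not transformers code): when the sequence is sharded across ranks *before* the shift, each rank's local shift silently drops the target that belongs to the first token of the next rank's chunk.

```python
IGNORE_INDEX = -100


def shift(labels):
    """Shift labels left by one; the final position has no target."""
    return labels[1:] + [IGNORE_INDEX]


def chunk(seq, n_ranks):
    """Split a sequence into contiguous per-rank chunks (context parallel)."""
    size = len(seq) // n_ranks
    return [seq[i * size:(i + 1) * size] for i in range(n_ranks)]


labels = [10, 11, 12, 13, 14, 15, 16, 17]

# Correct: shift globally, then shard -> chunk-boundary targets survive.
good = chunk(shift(labels), 2)

# Wrong: shard first, then shift locally on each rank -> rank 0 loses
# the target (14) for its last token.
bad = [shift(c) for c in chunk(labels, 2)]

print(good)  # [[11, 12, 13, 14], [15, 16, 17, -100]]
print(bad)   # [[11, 12, 13, -100], [15, 16, 17, -100]]
```

This is why letting `model.forward` accept pre-shifted labels (rather than shifting internally per shard) keeps the loss correct when each rank only sees its own slice of the sequence.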