[ForCausalLMLoss] allow users to pass shifted labels #36607
Signed-off-by: Stas Bekman <[email protected]>
|
Related to #36609. This PR seems simple enough that we could probably accept it as-is, but cc @ArthurZucker @Cyrilvallez if you think it's okay with our core philosophy |
|
Also please note at the end of the OP I have an alternative proposal that isn't in the PR, which might be neater? Not sure. |
muellerzr
left a comment
Taking in shift labels is one I also agree with, solves a few headaches @ArthurZucker and I ran into when it came to what other models take.
|
Thank you, Zach! So we just need to decide which API is neater: |
|
could someone please hit the merge button? unless we are waiting for someone else to review? |
ArthurZucker
left a comment
LGTM thanks for that! eager to see the full feature merged!
|
sorry @stas00 for the delay |
|
Thank you, so much, Arthur! |
|
This issue is not resolved: |
|
I think you're implying that the head of the model is also shifted_labels aware, but that's not the case. This PR only changed Before: After: Of course, it's possible to make the Before support shift_labels, but that would be a much bigger change. |
|
Actually, with v4.50.3, I found that passing |
|
Heh, yes, nice! I can see how this would work, because If accepted the PR could be: then |
* [ForCausalLMLoss] allow users to pass shifted labels

Signed-off-by: Stas Bekman <[email protected]>

* style

Signed-off-by: Stas Bekman <[email protected]>

---------

Signed-off-by: Stas Bekman <[email protected]>
I'm porting DeepSpeed Ulysses sequence parallelism (SP) from Megatron-Deepspeed to the HF transformers ecosystem so that everybody can use this SP implementation. I'm integrating it into ArcticTraining, and once everything is neat we can integrate it into HF Accelerate and make it available to many frameworks/users.
One of the nuances of the SP implementation is that each rank computes a shard of the loss, and then the losses/grads are merged together - this allows for sequence lengths of 1M and more.
The problem emerges when loss is computed:
In unsharded seqlen logits we end up with (shift left):
when sharded seqlen logits (each gpu processes half seqlen in this example) we lose label 5 once shifted:
so we either need the `ForCausalLMLoss` API to allow the user to provide the padding token, to replace `-100` with `5` in this case, or a much simpler solution is to just let the user do the shifting. This PR proposes the latter.

An alternative API change would be to pass `labels=shift_labels` - so no API change here, but to add a flag `are_labels_shifted=False`.

So 2 ways:
whatever works for you is good for me. The PR is currently made for (1).
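The label loss described above can be reproduced with a small plain-Python sketch. The label values `1..8`, the two-rank split, and the helper names `shift_left` and `shard` are assumptions chosen for illustration; only the `-100` ignore value, the half-sequence sharding, and the lost label `5` come from the description. This is not the library's implementation, just a model of the indexing problem.

```python
IGNORE_INDEX = -100  # value the loss ignores, as mentioned in the description


def shift_left(labels, pad=IGNORE_INDEX):
    """Drop the first label and pad the end, so position i holds the label of token i+1."""
    return labels[1:] + [pad]


def shard(seq, num_ranks):
    """Split a sequence into equal contiguous chunks, one per rank (hypothetical helper)."""
    size = len(seq) // num_ranks
    return [seq[i * size:(i + 1) * size] for i in range(num_ranks)]


labels = [1, 2, 3, 4, 5, 6, 7, 8]  # illustrative label ids, not from the PR

# Unsharded: a single shift over the whole sequence keeps every label but the first.
unsharded = shift_left(labels)
# -> [2, 3, 4, 5, 6, 7, 8, -100]

# Shard first, then let each rank shift its own chunk: rank 0 pads where 5 should be.
shifted_per_rank = [shift_left(chunk) for chunk in shard(labels, 2)]
# -> [[2, 3, 4, -100], [6, 7, 8, -100]]

# Shift once on the full sequence, then shard: label 5 survives on rank 0.
preshifted = shard(shift_left(labels), 2)
# -> [[2, 3, 4, 5], [6, 7, 8, -100]]
```

The last variant is what letting the user do the shifting amounts to: the caller shifts the full-length labels before sharding and hands each rank its already-shifted slice, so the loss function must not shift again.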