
Conversation

@stas00
Contributor

@stas00 stas00 commented Mar 7, 2025

I'm porting DeepSpeed Ulysses sequence parallelism from Megatron-DeepSpeed to the HF transformers ecosystem so that everybody can use this SP implementation. I'm integrating it into ArcticTraining first, and once everything is in good shape we can integrate it into HF Accelerate and make it available to many frameworks/users.

One of the nuances of the SP implementation is that each rank computes the loss over its shard of the sequence and then the losses/grads are merged together; this allows for sequence lengths of 1M tokens and more.

The problem emerges when the loss is computed.

With unsharded-seqlen logits we end up with (labels shifted left):

input_ids: [1 2 3 4 5 6 7    8   ]
labels   : [1 2 3 4 5 6 7    8   ]
shiftedl : [2 3 4 5 6 7 8 -100]

With sharded-seqlen logits (each GPU processes half the seqlen in this example) we lose label 5 once shifted:

input_ids: [1 2 3    4] [5 6 7    8]
labels   : [1 2 3    4] [5 6 7    8]
shiftedl : [2 3 4 -100] [6 7 8 -100]

So we either need the ForCausalLMLoss API to let the user provide the token to use in place of -100 (5 in this case), or, a much simpler solution, just let the user do the shifting. This PR proposes the latter.
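To make that concrete, here is a minimal sketch (not part of this PR) of how a sequence-parallel caller could do the shifting itself before sharding, so that no label is lost at shard boundaries; shard_pre_shifted_labels, sp_rank and sp_world_size are hypothetical names:

import torch

def shard_pre_shifted_labels(labels: torch.Tensor, sp_rank: int, sp_world_size: int) -> torch.Tensor:
    # labels: [batch, seqlen] over the full, unsharded sequence
    shift_labels = torch.full_like(labels, -100)   # -100 == ignore_index
    shift_labels[:, :-1] = labels[:, 1:]           # shift left; the last position is ignored
    shard_len = labels.shape[1] // sp_world_size   # assumes seqlen divides evenly across ranks
    start = sp_rank * shard_len
    return shift_labels[:, start : start + shard_len]

With the [1 .. 8] example above and 2 ranks, rank 0 gets [2 3 4 5] and rank 1 gets [6 7 8 -100], so label 5 is preserved.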

An alternative API change would be to pass the already-shifted labels via labels=shift_labels (so no new argument for the labels themselves) and add a flag are_labels_shifted=False.

So, the two options:

# option (1)
def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100,
    shift_labels=None, **kwargs
): ...

# option (2)
def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100,
    are_labels_shifted=False, **kwargs
): ...

Whatever works for you is good for me. The PR currently implements option (1).
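For reference, a minimal sketch of how option (1) would be called, assuming shift_labels takes precedence when both it and labels are given:

loss = ForCausalLMLoss(
    logits=logits,
    labels=None,                           # skip the internal shifting
    vocab_size=model.config.vocab_size,
    num_items_in_batch=num_items_in_batch,
    shift_labels=shift_labels,             # already shifted (and sharded) by the caller
)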

@stas00 stas00 marked this pull request as ready for review March 7, 2025 19:40
@Rocketknight1
Member

Related to #36609. This PR seems simple enough that we could probably accept it as-is, but cc @ArthurZucker @Cyrilvallez if you think it's okay with our core philosophy.

@sfc-gh-sbekman
Contributor

Also, please note that at the end of the OP I have an alternative proposal that isn't in the PR, which might be neater. Not sure.

@muellerzr
Contributor

@muellerzr muellerzr left a comment

Taking in shift_labels is the option I also agree with; it solves a few headaches @ArthurZucker and I ran into when it came to what other models take.

@stas00
Contributor Author

stas00 commented Mar 12, 2025

Thank you, Zach! So we just need to decide which API is neater:

# option (1)
def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100,
    shift_labels=None, **kwargs
): ...

# option (2)
def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100,
    are_labels_shifted=False, **kwargs
): ...

@sfc-gh-sbekman
Contributor

sfc-gh-sbekman commented Mar 18, 2025

Could someone please hit the merge button? Or are we waiting for someone else to review?

@ArthurZucker
Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM, thanks for that! Eager to see the full feature merged!

@ArthurZucker ArthurZucker merged commit 8f64b17 into huggingface:main Mar 20, 2025
21 checks passed
@ArthurZucker
Collaborator

Sorry @stas00 for the delay!

@sfc-gh-sbekman
Contributor

Thank you so much, Arthur!

@Triang-jyed-driung

This issue is not resolved. Take Qwen2ForCausalLM, for example:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L842
If one only passes shifted labels, this function will simply return loss=None.
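For context, the gating in the model head looks roughly like this (paraphrased), which is why the loss is skipped when labels is None:

loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)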

@stas00
Contributor Author

stas00 commented Apr 5, 2025

I think you're implying that the model head is also shift_labels-aware, but that's not the case.

This PR only changed ForCausalLMLoss, which only works if you pass the batch without labels to the model(**batch) call and then manually calculate the loss using the shifted labels.

Before:

loss = model(**batch).loss

After:

batch, labels = remove_labels(batch)       # pop the labels so the model skips loss computation
outputs = model(**batch)
[...]
shift_labels = do_shift_labels(labels)     # user-side shifting (pseudocode helper)
loss = model.loss_function(logits=outputs.logits, ..., labels=None, ..., shift_labels=shift_labels)

Of course, it's possible to make the Before case support shift_labels, but that would be a much bigger change.

@Triang-jyed-driung

Actually, with v4.50.3, I found that passing model(input_ids=x, labels='anything except None', shift_labels=y) will return the correct loss with shifted labels.
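For example, a minimal sketch of that call pattern (relying on shift_labels flowing through **kwargs):

outputs = model(input_ids=x, labels=x, shift_labels=y)   # any non-None labels triggers the loss branch
loss = outputs.loss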

@stas00
Contributor Author

stas00 commented Apr 6, 2025

Heh, yes, nice! I can see how this would work, because shift_labels gets passed via kwargs. But this is unintended behaviour. I'd suggest bringing it up in a separate issue so it can be made intended.

If accepted, the PR could be:

if labels is not None or kwargs.get("shift_labels", None) is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

Then labels=None would work and shift_labels would become part of the official API for calling the model head (and would be tested). The tricky part is finding all the models that use the ForCausalLMLoss loss and applying the above change; we would also need to document that the behavior differs for models using this loss function, since other models will not do anything with the shift_labels arg.
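If that were accepted, the intended call would presumably become something like this (hypothetical; batch_without_labels is just the batch with the labels key removed):

outputs = model(**batch_without_labels, shift_labels=shift_labels)   # labels stays None
loss = outputs.loss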

@stas00 stas00 deleted the loss-pre-shifted-labels branch April 6, 2025 16:14
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* [ForCausalLMLoss] allow users to pass shifted labels

Signed-off-by: Stas Bekman <[email protected]>

* style

Signed-off-by: Stas Bekman <[email protected]>

---------

Signed-off-by: Stas Bekman <[email protected]>
