
Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented May 7, 2025

What does this PR do?

cc @stas00

Follow-up of #36607; see this comment.

The following code snippet gives

on main:

tensor(7.0538, grad_fn=)
None

and with this PR:

tensor(7.0538, grad_fn=)
tensor(7.0538, grad_fn=)

This is particularly necessary for context parallelism to run correctly.

Once the changes are approved, I will update all other places and add documentation.

import torch
from transformers import AutoModelForCausalLM

repo_id = "meta-llama/Llama-3.2-1B"
token = "YOUR_HF_TOKEN"

model = AutoModelForCausalLM.from_pretrained(repo_id, token=token)
input_ids = torch.ones(size=(1, 16), dtype=torch.int64)
labels = input_ids.clone()
# pre-shift the labels: position i holds the target for token i, and the last
# position is padded with -100 so it is ignored by the loss
shift_labels = torch.nn.functional.pad(labels[..., 1:], (0, 1), value=-100)
outputs_with_labels = model(input_ids, labels=labels)
outputs_with_shift_labels = model(input_ids, shift_labels=shift_labels)

print(outputs_with_labels.loss)
print(outputs_with_shift_labels.loss)
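
As a side note on the context-parallel point above: the labels must be shifted before the sequence is sharded across ranks, because the target for the last token of a shard lives on the next rank, so shifting per shard inside the loss would be wrong at the shard boundaries. A hypothetical illustration (not part of this PR):

import torch

# full, unsharded labels on one rank (batch size 1, sequence length 16)
labels = torch.arange(16).unsqueeze(0)

# shift once, globally: position i now holds the target for token i,
# and the final position is padded with -100 so the loss ignores it
shift_labels = torch.nn.functional.pad(labels[..., 1:], (0, 1), value=-100)

# shard the sequence dimension across (say) 2 context-parallel ranks; each rank can
# then pass its own label shard as shift_labels and get a correct local loss
label_shards = shift_labels.chunk(2, dim=-1)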

@github-actions github-actions bot marked this pull request as draft May 7, 2025 09:25
@github-actions
Contributor

github-actions bot commented May 7, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ydshieh ydshieh marked this pull request as ready for review May 7, 2025 09:54
@ydshieh ydshieh requested a review from ArthurZucker May 7, 2025 09:55

loss = None
-if labels is not None:
+if labels is not None or kwargs.get("shift_labels", None) is not None:
Collaborator

why not leave it to the loss function?

Collaborator Author

We can, but it means the loss has to be computed outside the model's forward; the user would have to do

from transformers import ForCausalLMLoss

shift_labels = .... 
loss = ForCausalLMLoss(logits=logits, shift_labels=shift_labels, vocab_size=model.config.vocab_size)

It's not too much work for users, but it would be nice if we could make it easier for them (they only need to take care of preparing shift_labels).

Enabling this API means people can use the same workflow with labels and with shift_labels (i.e. pass them to model.forward and get the loss from the outputs).

Collaborator Author

@ydshieh ydshieh May 7, 2025

Also, we can't just pass shift_labels without passing labels

loss = ForCausalLMLoss(logits=logits, shift_labels=shift_labels, vocab_size=model.config.vocab_size)

as labels is a required positional argument. We have to do

loss = ForCausalLMLoss(logits=logits, labels=None, shift_labels=shift_labels, vocab_size=model.config.vocab_size)

This kind of detail could be hidden from the user if we handle it in the modeling code.

Having to pass labels (no matter what values it contains) when we mean to use shift_labels is kind of confusing.

It could work with model.forward if we pass

model(input_ids=x, labels=labels (or 'anything except None'), shift_labels=shift_labels)

but that is also confusing.

Contributor

@stas00 stas00 May 7, 2025

yeah, someone reported using a side-effect hack - passing labels=non_None_garbage, shift_labels=real_data and getting the model() to compute the loss ;)

so you know people want this feature ;)

honestly the only reason I didn't propose it is because I didn't want to do it for 200+ files ;) So I'm grateful to @ydshieh for taking the lead on this.

Collaborator

what I meant @ydshieh is that if we do this we need to update all models!

@ydshieh ydshieh changed the title update update loss computation in modeling code May 7, 2025
"""

num_items_in_batch: Optional[int]
shift_labels: Optional[torch.Tensor]
Collaborator Author

This is somewhat awkward, as shift_labels is specific to ForCausalLMLoss. Not sure if we want to expose it here or within

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

which is in many modeling files.
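
For concreteness, a minimal sketch of the two placements being weighed here (illustrative only, not the actual transformers source):

from typing import Optional, TypedDict

import torch

class LossKwargs(TypedDict, total=False):
    # option 1: expose shift_labels on the shared loss kwargs, next to num_items_in_batch
    num_items_in_batch: Optional[int]
    shift_labels: Optional[torch.Tensor]

class FlashAttentionKwargs(TypedDict, total=False):
    pass  # attention-related kwargs omitted in this sketch

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    # option 2 would move the shift_labels annotation from LossKwargs into this class,
    # keeping it specific to the causal LM models that define KwargsForCausalLM
    pass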

@stas00
Contributor

stas00 commented May 7, 2025

FWIW, Liger-Kernel has just implemented this feature: linkedin/Liger-Kernel#683
It works in 0.5.9. They had to implement it, otherwise the fused cross-entropy wouldn't be possible with shift_labels, and a huge performance improvement and memory saving would be lost.

@ydshieh ydshieh marked this pull request as draft May 9, 2025 14:58
@ydshieh ydshieh marked this pull request as ready for review May 9, 2025 14:58
@ydshieh
Collaborator Author

ydshieh commented May 15, 2025

@bot /style

1 similar comment
@ydshieh
Collaborator Author

ydshieh commented May 15, 2025

@bot /style

@github-actions
Contributor

Style fixes have been applied. View the workflow run here.

Collaborator

@ArthurZucker ArthurZucker left a comment

Happy to have this merged, just needs:

  • either all models updated with make fix-copies
  • or putting this in the loss_function instead (not sure if possible)


loss = None
-if labels is not None:
+if labels is not None or kwargs.get("shift_labels", None) is not None:
Collaborator

what I meant @ydshieh is that if we do this we need to update all models!

@ydshieh
Collaborator Author

ydshieh commented May 20, 2025

put this in the loss_function instead (not sure if possible)

This is unfortunately impossible: in the current modeling code, the loss is only computed if labels is passed

if labels is not None:
    loss = ...

and if we want to use shift_labels, it will only be taken into account if we also pass labels, which is strange behavior.

I will go with

either all models updated with make fix-copies

🤞

@ydshieh
Collaborator Author

ydshieh commented May 31, 2025

Ready for a review 🙏

Mostly, change

if labels is not None:

to

if labels is not None or kwargs.get("shift_labels", None) is not None:

for CausalLM (or some ForConditionalGeneration) models.
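
To make the per-model edit concrete, here is a rough sketch (assuming a generic Llama-style ForCausalLM forward; not the exact diff) of the loss block after the change:

loss = None
if labels is not None or kwargs.get("shift_labels", None) is not None:
    # self.loss_function resolves to ForCausalLMLoss for causal LM models; when shift_labels
    # arrives via **kwargs it is used directly and the internal label shifting is skipped,
    # so labels may legitimately be None here
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)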

@ydshieh
Collaborator Author

ydshieh commented May 31, 2025

I would love to change shift_labels to shifted_labels (also in def ForCausalLMLoss), as shift_labels sounds like a bool argument. Just a nit, however.

@stas00
Contributor

stas00 commented May 31, 2025

It looks like I introduced it here 2 months ago, replicating the long-time pre-existing internal variable name:
8f64b17

I agree that shifted_labels is a better name, but shift_labels has already been adopted in the wild, at least by Liger-Kernel and DeepSpeed. Liger-Kernel has already made a public release, and DeepSpeed's is imminent. I don't know if others have started using it.

@ydshieh
Collaborator Author

ydshieh commented May 31, 2025

Yeah, I understand. No big deal; I guess people focused on making AI go brrrr would never care whether there is an 'ed' or not 😅

@ydshieh ydshieh requested a review from ArthurZucker June 2, 2025 09:03
Collaborator

@ArthurZucker ArthurZucker left a comment

Not a big fan of this, but IDK if we have a better solution. WDYT about a flag saying they are already shifted? (It's just that kwargs are not meant to be used here, but I mean... yeah.)

@ydshieh
Collaborator Author

ydshieh commented Jun 2, 2025

IIRC, @ArthurZucker, you mean we allow labels to already be shifted labels, and we introduce a flag to indicate this, in

def ForCausalLMLoss(

right? Technically it's doable, but the new flag (say, is_shifted) would overlap with the role of shift_labels, which is not ideal.

One possibility is to allow shift_labels to be Optional[Union[torch.Tensor, bool]], and then do the following inside ForCausalLMLoss

    if shift_labels is True and labels is not None:
        shift_labels = labels
    ....

It could work and it would avoid changes in the many modeling files, but you see the downside (the same argument having 2 possible types and behaving differently).
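
For reference, a hypothetical sketch of that Union[torch.Tensor, bool] alternative (illustrative only; the real ForCausalLMLoss also handles num_items_in_batch and other kwargs):

from typing import Optional, Union

import torch

def ForCausalLMLoss(
    logits: torch.Tensor,
    labels: Optional[torch.Tensor],
    vocab_size: int,
    shift_labels: Optional[Union[torch.Tensor, bool]] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    if shift_labels is True and labels is not None:
        # the caller declares that labels are already shifted: reuse them as-is
        shift_labels = labels
    elif shift_labels is None and labels is not None:
        # default path: shift left by one and pad the last position with the ignore index
        shift_labels = torch.nn.functional.pad(labels[..., 1:], (0, 1), value=ignore_index)
    return torch.nn.functional.cross_entropy(
        logits.float().view(-1, vocab_size), shift_labels.view(-1), ignore_index=ignore_index
    )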

@stas00 Any comment here?

@ydshieh
Collaborator Author

ydshieh commented Jun 2, 2025

@ArthurZucker @stas00

This is probably a better and cleaner solution:

https://github.com/huggingface/transformers/pull/38533/files

Let me know your opinions

@stas00
Contributor

stas00 commented Jun 2, 2025

@ydshieh, the last one is a smooth solution!
