Pass num_items_in_batch directly to loss computation #36753
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).
qubvel left a comment:
Thank you @eljandoubi for opening a PR and fixing the issue! Just a small nit
```python
loss = fixed_cross_entropy(
    logits.reshape(-1, self.decoder.config.vocab_size),
    labels.reshape(-1),
    num_items_in_batch=num_items_in_batch,
)
```
We can use `self.loss_function` instead, see llama for example. All reshape/view ops will happen under the hood:
```python
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch)
```
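(For readers following along: below is a simplified sketch of what `self.loss_function` typically does under the hood for a causal LM head in transformers — the shift/flatten ops plus the num_items_in_batch-aware normalization. Names follow transformers' loss utilities; exact details vary by version and model, and the shift step applies to causal LM heads specifically.)

```python
import torch.nn.functional as F

def fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100):
    # With gradient accumulation, the Trainer supplies the total number of
    # label tokens across all micro-batches; summing and dividing by that
    # count gives the correct normalization (a per-micro-batch mean would not).
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None):
    # Shift so tokens < n predict token n, then flatten -- these are the
    # reshape/view ops that "happen under the hood" in self.loss_function.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return fixed_cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        num_items_in_batch=num_items_in_batch,
    )
```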
@qubvel like this?
A bit different: there is no need to make a view while passing to the function, see my example above.
@qubvel I've updated my code; what do you think about it now?
Thanks @eljandoubi
* Pass num_items_in_batch directly to loss computation
* use self loss instead
* fix loss kwargs
* fix vocab size
What does this PR do?
Fixes #36744
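For context, num_items_in_batch exists so the loss is normalized by the total number of label tokens in the full batch rather than per micro-batch — the standard motivation for this parameter in transformers under gradient accumulation. A minimal numeric illustration (mine, not code from this PR):

```python
import torch

# Two gradient-accumulation micro-batches with different numbers of
# label tokens (e.g. after padding/masking): 3 tokens vs. 1 token.
micro_batch_losses = [torch.tensor([2.0, 2.0, 2.0]), torch.tensor([6.0])]

# Without num_items_in_batch: mean per micro-batch, then averaged
# across accumulation steps -> (2.0 + 6.0) / 2 = 4.0
naive = torch.stack([l.mean() for l in micro_batch_losses]).mean()

# With num_items_in_batch: sum all per-token losses and divide by the
# total token count -> 12.0 / 4 = 3.0, the true per-token mean.
num_items_in_batch = sum(l.numel() for l in micro_batch_losses)
correct = torch.cat(micro_batch_losses).sum() / num_items_in_batch

print(naive.item(), correct.item())  # 4.0 3.0
```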
Models: