
Conversation

Contributor

@HiDolen HiDolen commented Mar 1, 2025

What does this PR do?

PrefixConstrainedLogitsProcessor.__call__() does not account for the case where input_ids has a sequence length of zero, i.e. input_ids.shape[-1] == 0 (an empty prompt).

This leads to an error when executing input_ids.view(-1, self._num_beams, input_ids.shape[-1]):

RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 1, 0] because the unspecified dimension size -1 can be any value and is ambiguous
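For context, here is a minimal sketch of the ambiguity, using NumPy as a stand-in for torch (NumPy's reshape rejects the same pattern: with a zero-sized array, the -1 dimension cannot be inferred). The variable names here are illustrative, not the library's own:

```python
import numpy as np

num_beams = 1
# Batch of 2 with an empty prompt: shape (2, 0), zero elements total.
input_ids = np.zeros((2, 0), dtype=np.int64)

try:
    # Mirrors input_ids.view(-1, self._num_beams, input_ids.shape[-1]);
    # the target shape is (-1, 1, 0), so -1 is ambiguous and reshape raises.
    input_ids.reshape(-1, num_beams, input_ids.shape[-1])
except ValueError as e:
    print("reshape failed:", e)
```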

Fixes noamgat/lm-format-enforcer#132

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

@github-actions github-actions bot marked this pull request as draft March 1, 2025 15:08
Contributor

github-actions bot commented Mar 1, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@HiDolen HiDolen marked this pull request as ready for review March 1, 2025 15:16
Contributor

gante commented Mar 4, 2025

Hi @HiDolen 👋 Thank you for opening the PR!

Technically, we might want to apply prefix_allowed_tokens_fn() if input_ids is empty. Instead of providing an early return in that case, can we change how we iterate over the batch and the beam items so as to not hit this exception?

e.g.

for row_id in range(input_ids.shape[0]):
    batch_id = row_id // self._num_beams
    beam_id = row_id % self._num_beams
    sent = input_ids[row_id, ...]
    (...)

Contributor Author

HiDolen commented Mar 6, 2025

Hi @HiDolen 👋 Thank you for opening the PR!

Technically, we might want to apply prefix_allowed_tokens_fn() if input_ids is empty. Instead of providing an early return in that case, can we change how we iterate over the batch and the beam items so as to not hit this exception?

e.g.

for row_id in range(input_ids.shape[0]):
    batch_id = row_id // self._num_beams
    beam_id = row_id % self._num_beams
    sent = input_ids[row_id, ...]
    (...)

How about this? Sorry for the late reply.

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    mask = torch.full_like(scores, -math.inf)
-   if input_ids.shape[-1] == 0:
-       return scores + mask
+   batch_size = input_ids.shape[0] // self._num_beams

-   for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
-       for beam_id, sent in enumerate(beam_sent):
+   for batch_id in range(batch_size):
+       for beam_id in range(self._num_beams):
+           sent = input_ids[batch_id * self._num_beams + beam_id]
            prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
            if len(prefix_allowed_tokens) == 0:
                raise ValueError(
                    f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                    f"This means that the constraint is unsatisfiable. Please check your implementation"
                    f"of `prefix_allowed_tokens_fn` "
                )
            mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0

    scores_processed = scores + mask
    return scores_processed
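To see why the index-based iteration sidesteps the reshape entirely, here is a runnable sketch of the same logic outside the transformers class, using NumPy in place of torch and a toy prefix_allowed_tokens_fn that always allows token 0. All names in this snippet are illustrative assumptions, not the library's API:

```python
import math
import numpy as np

def apply_prefix_mask(input_ids, scores, num_beams, prefix_allowed_tokens_fn):
    # Same structure as the patched __call__: iterate flat row indices
    # instead of calling view(-1, num_beams, seq_len), so a sequence
    # length of zero never triggers an ambiguous reshape.
    mask = np.full_like(scores, -math.inf)
    batch_size = input_ids.shape[0] // num_beams
    for batch_id in range(batch_size):
        for beam_id in range(num_beams):
            row = batch_id * num_beams + beam_id
            sent = input_ids[row]
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            if len(allowed) == 0:
                raise ValueError(
                    f"prefix_allowed_tokens_fn returned an empty list for batch ID {batch_id}."
                )
            mask[row, allowed] = 0
    return scores + mask

# Works even with an empty prompt (sequence length 0), where the old
# view(-1, num_beams, 0) call raised.
input_ids = np.zeros((2, 0), dtype=np.int64)   # batch 2, num_beams 1, no tokens yet
scores = np.zeros((2, 5))
out = apply_prefix_mask(input_ids, scores, num_beams=1,
                        prefix_allowed_tokens_fn=lambda b, s: [0])
print(out[:, 0])  # allowed token keeps its score; all others are -inf
```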

Contributor

@gante gante left a comment


Perfect, thank you for iterating and for improving transformers 🙌

@gante gante merged commit 6f77597 into huggingface:main Mar 7, 2025
21 checks passed
