
Conversation

@HuangBugWei (Contributor) commented on Mar 17, 2025

What does this PR do?

I noticed that cand_indexes in _whole_word_mask is constructed using the following logic:

cand_indexes = []
for i, token in enumerate(input_tokens):
    if token == "[CLS]" or token == "[SEP]":
        continue

    if len(cand_indexes) >= 1 and token.startswith("##"):
        cand_indexes[-1].append(i)
    else:
        cand_indexes.append([i])

Since i is directly obtained from enumerate(input_tokens), it is guaranteed to be unique across iterations. As a result, there will be no duplicate elements in the flattened cand_indexes.
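
A quick self-contained illustration (the token list below is hypothetical, purely to make the uniqueness argument concrete):

input_tokens = ["[CLS]", "un", "##believ", "##able", "weather", "[SEP]"]

cand_indexes = []
for i, token in enumerate(input_tokens):
    if token == "[CLS]" or token == "[SEP]":
        continue

    if len(cand_indexes) >= 1 and token.startswith("##"):
        cand_indexes[-1].append(i)
    else:
        cand_indexes.append([i])

# cand_indexes == [[1, 2, 3], [4]]
flattened = [i for group in cand_indexes for i in group]
assert len(flattened) == len(set(flattened))  # no duplicates are possible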

Given this, the following check appears to be redundant, as it will never be triggered:

is_any_index_covered = False
for index in index_set:
    if index in covered_indexes:
        is_any_index_covered = True
        break
if is_any_index_covered:
    continue

I suggest removing this redundant check to simplify the logic and improve efficiency and readability.
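
With the check gone, the selection step would reduce to something like the sketch below. This is a paraphrase rather than the exact upstream loop; covered_indexes and num_to_predict are the names already used in _whole_word_mask:

covered_indexes = set()
for index_set in cand_indexes:
    if len(covered_indexes) >= num_to_predict:
        break
    # skip whole-word groups that would overshoot the masking budget
    if len(covered_indexes) + len(index_set) > num_to_predict:
        continue
    # no "is_any_index_covered" check needed: every index occurs at most once
    covered_indexes.update(index_set)
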
Let me know if there’s any edge case I might have missed!

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@github-actions github-actions bot marked this pull request as draft March 17, 2025 08:30
@github-actions (bot) commented:

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@HuangBugWei HuangBugWei marked this pull request as ready for review March 17, 2025 10:44
@Rocketknight1 (Member) left a comment

Yes, this is correct! There should be no duplicates anywhere in cand_indexes and so no way for this block to be relevant.

I suspect this code hasn't been maintained in a while - can you check the rest of the function to confirm it's bug-free as well, and then ping me whenever you're happy for me to merge?

@HuangBugWei (Contributor, Author) commented:

I've reviewed the rest of the function, and I believe most of it is correct. However, I also identified a few potential issues:

  1. Potential Inclusion of Other Special Tokens in Masking Candidates: While the code correctly skips [CLS] and [SEP], any other special tokens present in the input (although this is rare) are still included as masking candidates in cand_indexes. (A tokenizer-agnostic filter is sketched after this list.)
     cand_indexes = []
     for i, token in enumerate(input_tokens):
         if token == "[CLS]" or token == "[SEP]":
             continue
         if len(cand_indexes) >= 1 and token.startswith("##"):
             cand_indexes[-1].append(i)
         else:
             cand_indexes.append([i])
  2. Discrepancy Between Target Masking Quantity and Masked Non-Special Tokens: Because those other special tokens are counted as candidates up front but may be excluded later when the mask is actually applied, the pre-calculated num_to_predict may not accurately reflect the number of non-special tokens that are ultimately masked, so fewer relevant tokens could be masked than intended.
     special_tokens_mask = [
         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
     ]
     probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
  3. Deterministic Calculation of Masking Quantity: Using mlm_probability to directly compute a fixed number of tokens to mask (num_to_predict) yields a deterministic count rather than a probabilistic selection driven by mlm_probability. The expected number of masked tokens may be similar over many instances, but this doesn't reflect the probabilistic behaviour one might expect from the parameter. (A Bernoulli-style alternative is also sketched after the list.)
     num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
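
For point 1, one tokenizer-agnostic option would be to filter against tokenizer.all_special_tokens instead of hard-coding [CLS] and [SEP]. This is only a sketch (it assumes self.tokenizer is available at this point, as it is elsewhere in the collator):

special_tokens = set(self.tokenizer.all_special_tokens)

cand_indexes = []
for i, token in enumerate(input_tokens):
    if token in special_tokens:  # skips [CLS], [SEP], [PAD], [MASK], [UNK], ...
        continue
    if len(cand_indexes) >= 1 and token.startswith("##"):
        cand_indexes[-1].append(i)
    else:
        cand_indexes.append([i])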
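
For point 3, a probabilistic selection could look roughly like the sketch below, mirroring how DataCollatorForLanguageModeling draws a Bernoulli mask (just an illustration of the idea, not a drop-in replacement for the existing whole-word budget logic):

import torch

# Draw each whole-word group independently with probability mlm_probability
# instead of fixing num_to_predict up front.
word_probs = torch.full((len(cand_indexes),), self.mlm_probability)
word_mask = torch.bernoulli(word_probs).bool()

covered_indexes = set()
for keep, index_set in zip(word_mask.tolist(), cand_indexes):
    if keep:
        covered_indexes.update(index_set)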

Overall, the function looks mostly correct. However, I've noted a few code snippets that might deviate from the intended behavior regarding special token handling and the deterministic masking calculation.

@Rocketknight1 What do you think? Would you like me to create a separate PR to address these potential inconsistencies?

@ArthurZucker ArthurZucker removed their request for review March 20, 2025 09:56
@Rocketknight1 (Member) commented:

@HuangBugWei, those are all good points! This collator seems to be specifically designed for BERT tokenizers, so limiting to [CLS] and [SEP] makes sense. I think the other bugs are minor - at some point, we could consider a complete refactor here so this collator works for all MLM models, but for now I think we can just merge this PR. Thanks for your work and for the clean analysis!

@Rocketknight1 Rocketknight1 force-pushed the fix/data_collator_for_WWM branch from 0501079 to 9073008 on March 20, 2025 14:01
@Rocketknight1 Rocketknight1 merged commit a63e92e into huggingface:main Mar 20, 2025
21 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on May 14, 2025: remove the redundant snippet of _whole_word_mask