skip_special_tokens has different behavior between slow and fast tokenizer #23250

Closed · Fixed by #23909
BuxianChen opened this issue May 10, 2023 · 8 comments

@BuxianChen

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi, I recently noticed a subtle difference between the slow and fast tokenizers. Here is an example:

from transformers import AutoTokenizer, T5Tokenizer

path = "t5-small"
text = "this is a ஐ apple"

# Fast tokenizer: the added special token IS skipped when decoding.
fast_tokenizer = AutoTokenizer.from_pretrained(path)
num = fast_tokenizer.add_tokens(["ஐ"], special_tokens=True)
assert num == 1
ids = fast_tokenizer(text)["input_ids"]
fast_tokenizer.decode(ids, skip_special_tokens=True)  # 'this is a apple'

# Slow tokenizer: the same added special token is NOT skipped.
slow_tokenizer = T5Tokenizer.from_pretrained(path)
num = slow_tokenizer.add_tokens(["ஐ"], special_tokens=True)
assert num == 1
ids = slow_tokenizer(text)["input_ids"]
slow_tokenizer.decode(ids, skip_special_tokens=True)  # 'this is a ஐ apple'

Here is more information about the issue (I'm not a native English speaker, so I hope this is understandable):

  • I know that in the first case, the fast tokenizer is backed by 🤗 Tokenizers and invokes tokenizers.Tokenizer.add_special_tokens(tokens), so the token is added to the vocabulary, treated as a "special token", and never processed by tokenizer.model.
  • In the second case, when decoding, the slow tokenizer treats the added token as a "normal token", so it is not skipped. Reading the related source code: when skip_special_tokens=True, the slow tokenizer only skips ids in self.all_special_ids, but the added token is not stored there; it lives in self.added_tokens_encoder (see the sketch below).
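You can see this directly by inspecting the slow tokenizer's state after add_tokens (a minimal sketch, assuming the transformers 4.26 attributes named above):

from transformers import T5Tokenizer

slow_tokenizer = T5Tokenizer.from_pretrained("t5-small")
slow_tokenizer.add_tokens(["ஐ"], special_tokens=True)

new_id = slow_tokenizer.added_tokens_encoder["ஐ"]  # the freshly allocated id
print(new_id in slow_tokenizer.all_special_ids)    # False, so decode() never skips it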

I read some 🤗 official documentation and struggled to figure out the meaning of the so-called "special token"; it turns out to be a subtle concept. Here is my understanding: tokens can be divided into these categories:

  • normal tokens: these tokens can be split
  • control tokens (the name is inspired by SentencePiece): bos_token, eos_token, ..., additional_special_tokens. The main purpose of these tokens is the encode post-processing pipeline. When these tokens appear in the input text, in the slow tokenizer case they are, in most cases, also included in self.unique_no_split_tokens, so they will not be split; I don't know how the fast tokenizer handles this case.
  • user-added tokens (see the sketch after this list):
    • If the token is already in the vocab, it can still be marked as a "special token", and from then on it will never be split (though it is not treated exactly like a control token in some subtle situations).
    • If the token is not in the vocab, it is added (a new token_id is allocated to it), and it will also never be split.
      So, in both cases, user-added tokens are never split.
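A minimal sketch of the two cases (the in-vocab token is pulled from the vocab rather than hard-coded; behavior as observed in transformers 4.26):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")

# Case 1: the string is already in the vocab -> no new id is allocated,
# but the token is still marked so it will never be split again.
in_vocab = tok.convert_ids_to_tokens(2000)
print(tok.add_tokens([in_vocab], special_tokens=True))  # 0 tokens newly added

# Case 2: the string is not in the vocab -> a new id is allocated.
print(tok.add_tokens(["ஐ"], special_tokens=True))       # 1 token added
print(tok.added_tokens_encoder["ஐ"])                    # id >= the original vocab size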

Please let me know if there are any misunderstandings.

Several weeks ago I submitted issue #23001 about return_overflowing_tokens behavior, which was considered a feature specific to the fast tokenizer, so it's a feature, not a bug. In general, I'd like to know whether differences between the slow and fast tokenizers should be viewed as features or as bugs.

Expected behavior

The slow tokenizer should behave the same as the fast tokenizer.

@BuxianChen (Author)

I'd like to confirm my understanding of the concepts, since PR #23312 is in progress:

In 🤗 Transformers, for both slow and fast tokenizers, there are only two types of tokens:

  • normal tokens: these tokens can be split. They cannot be added again, but when add_tokens(tokens, special_tokens=True) is called and a token to be added is already a normal token, it is marked as a special token and will no longer be split.
  • special tokens: these tokens cannot be split. They include:
    • eos_token, bos_token, ..., additional_special_tokens, which are defined in SpecialTokensMixin
    • tokens added by the user via add_tokens(tokens): (1) with special_tokens=False, if a token in tokens is already a normal token, nothing is done to it; (2) with special_tokens=True, if a token in tokens is already a normal token, it is marked as a special token and will not be split.

For both the slow and fast tokenizers, tokenizer.decode(ids, skip_special_tokens=True) should skip all special tokens.
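A small sketch of that contract against the fast tokenizer, using two hypothetical tokens <plain> and <marker> (the slow tokenizer currently violates the special-token half, which is this issue):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")    # fast by default

tok.add_tokens(["<plain>"], special_tokens=False)  # added, but normal
tok.add_tokens(["<marker>"], special_tokens=True)  # added and special

ids = tok("a <plain> b <marker> c")["input_ids"]
print(tok.decode(ids, skip_special_tokens=True))
# '<plain>' survives (a normal added token); '<marker>' is skipped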

Please let me know if there are any misunderstandings.

@amyeroberts (Collaborator)

cc @younesbelkada

@ArthurZucker (Collaborator)

Hey! Thanks for reporting this!

  • Differences between fast and slow are sometimes bugs, sometimes features, which is what makes it a bit complicated.

Now about the core of the issue: you have a good grasp of what is going on, good job! 🤗 And thanks for taking the time to dig in. T5 is a bit of a special case because it uses a hack in the _convert_token_to_id method.

The core issue is that the additional_special_tokens list and the added_tokens encoder and decoder are not perfectly linked: updating one does not update the other, which is a bug. Documentation is also rather scarce on how we use additional_special_tokens; I am trying to regroup the issues linked to that to create a proper fix. Will have a look at the PR!
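For example, the two update paths touch different structures on the slow side (a sketch against the 4.26-era API):

from transformers import T5Tokenizer

a = T5Tokenizer.from_pretrained("t5-small")
a.add_tokens(["ஐ"], special_tokens=True)
print(a.additional_special_tokens)      # []   -- the list was never updated
print("ஐ" in a.added_tokens_encoder)    # True -- but the encoder was

b = T5Tokenizer.from_pretrained("t5-small")
b.add_special_tokens({"additional_special_tokens": ["ஐ"]})
print(b.additional_special_tokens)      # ['ஐ'] -- both sides updated
print("ஐ" in b.added_tokens_encoder)    # True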

@ArthurZucker (Collaborator)

One thing to note is that some added tokens can be non-special tokens, which is why you have:

  • normal tokens (from the original vocab file of the SPModel, for example)
  • special tokens (which can be added to additional_special_tokens, or control tokens, which are class attributes); these behave the same
  • added normal tokens, which should not be split and have their own index. These are useful when a token is missing from the spmodel, which you can never touch (see the sketch below).
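A sketch of the third category, where <fix> is a hypothetical token assumed missing from the spmodel:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
tok.add_tokens(["<fix>"])                   # special_tokens=False: an added *normal* token

ids = tok("a <fix> b")["input_ids"]
print(tok.convert_ids_to_tokens(ids))       # '<fix>' is kept whole, with its own new id
print(tok.decode(ids, skip_special_tokens=True))  # '<fix>' is NOT skipped: it isn't special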

@ArthurZucker ArthurZucker self-assigned this May 25, 2023
@ArthurZucker ArthurZucker added the Core: Tokenization label May 25, 2023
@BuxianChen (Author)

Thanks for your reply! For the example above, which behavior is expected, the slow tokenizer's or the fast tokenizer's?


@ArthurZucker (Collaborator)

In this case, the fast tokenizer is correct: when we ask to skip special tokens while decoding, we expect all the special tokens to be skipped.

@ArthurZucker (Collaborator)

ArthurZucker commented Jul 20, 2023

It will be addressed in the linked PR. This is mostly because, in the slow tokenizer, a token added via add_tokens was not properly added to the list of additional_special_tokens. The refactoring will prevent this from happening!
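Until the refactor lands, a workaround on the slow side is to register the token through add_special_tokens, which does keep additional_special_tokens in sync (a hedged sketch against the 4.26-era API):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
# Registering through add_special_tokens updates additional_special_tokens as well,
# so skip_special_tokens=True now matches the fast tokenizer's behavior.
tok.add_special_tokens({"additional_special_tokens": ["ஐ"]})

ids = tok("this is a ஐ apple")["input_ids"]
print(tok.decode(ids, skip_special_tokens=True))  # 'this is a apple'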

@ArthurZucker (Collaborator)

PR will be merged this week!
