Commit 3f03c37

fix tiktoken convert to pass AddedToken to Tokenizer (#36566)
* pass AddedToken to Tokenizer
* ruff
* handle dict for special tokens
* option: test tokenizer from tiktoken same as fast
* ruff
* ruff
1 parent 8f64b17 commit 3f03c37

2 files changed: +25, -3 lines changed


src/transformers/convert_slow_tokenizer.py

Lines changed: 7 additions & 2 deletions

@@ -1580,7 +1580,9 @@ def __init__(
         self.vocab_file = vocab_file
         self.pattern = pattern
         self.add_prefix_space = add_prefix_space
-        self.additional_special_tokens = additional_special_tokens
+        self.additional_special_tokens = (
+            additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
+        )

     def extract_vocab_merges_from_model(self, tiktoken_url: str):
         try:
@@ -1629,7 +1631,10 @@ def converted(self) -> Tokenizer:
             ]
         )
         tokenizer.decoder = decoders.ByteLevel()
-        tokenizer.add_special_tokens(self.additional_special_tokens)
+
+        tokenizer.add_special_tokens(
+            [AddedToken(token, normalized=False, special=True) for token in self.additional_special_tokens]
+        )

         tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
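The converter change above depends on two things: additional_special_tokens may now be a dict, in which case only its keys are used, and Tokenizer.add_special_tokens in the tokenizers library accepts AddedToken objects, so each token is registered with special=True and normalized=False instead of as a bare string. A minimal standalone sketch of that behaviour, not the converter itself, using an empty BPE model and the sample token <|endoftext|> purely for illustration:

from tokenizers import AddedToken, Tokenizer
from tokenizers.models import BPE

# Stand-in for the tokenizer the converter builds from the tiktoken vocab.
tokenizer = Tokenizer(BPE())

# Special tokens may arrive as a list or, after this fix, as a dict; only the keys matter.
additional_special_tokens = {"<|endoftext|>": 50256}
tokens = (
    additional_special_tokens.keys()
    if isinstance(additional_special_tokens, dict)
    else additional_special_tokens
)

# Register each token as an AddedToken so it is marked special and left unnormalized.
tokenizer.add_special_tokens(
    [AddedToken(token, normalized=False, special=True) for token in tokens]
)

Registering the tokens as AddedToken objects keeps them from being split or normalized away when the converted tokenizer encodes text.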

tests/models/gpt2/test_tokenization_gpt2.py

Lines changed: 18 additions & 1 deletion

@@ -20,7 +20,7 @@

 from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
 from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import require_jinja, require_tokenizers
+from transformers.testing_utils import require_jinja, require_tiktoken, require_tokenizers

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -299,6 +299,23 @@ def test_tokenization_for_chat(self):
         for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
             self.assertListEqual(tokenized_chat, expected_tokens)

+    @require_tiktoken
+    def test_tokenization_tiktoken(self):
+        from tiktoken import encoding_name_for_model
+
+        from transformers.integrations.tiktoken import convert_tiktoken_to_fast
+
+        encoding = encoding_name_for_model("gpt2")
+        convert_tiktoken_to_fast(encoding, self.tmpdirname)
+
+        tiktoken_fast_tokenizer = GPT2TokenizerFast.from_pretrained(self.tmpdirname)
+        rust_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
+        sequence = "lower newer"
+        self.assertEqual(
+            rust_tokenizer.decode(rust_tokenizer.encode(sequence)),
+            tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)),
+        )
+

 @require_tokenizers
 class OPTTokenizationTest(unittest.TestCase):
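The new test also documents the end-to-end workflow: convert a tiktoken encoding to fast-tokenizer files on disk, then load them like any other fast tokenizer. A small sketch of the same steps outside the test harness, assuming tiktoken is installed; the output directory name converted_gpt2 is arbitrary and chosen for this example:

from tiktoken import encoding_name_for_model

from transformers import GPT2TokenizerFast
from transformers.integrations.tiktoken import convert_tiktoken_to_fast

# Resolve the tiktoken encoding used by GPT-2 and write it out as a fast tokenizer.
encoding = encoding_name_for_model("gpt2")
convert_tiktoken_to_fast(encoding, "converted_gpt2")

# The converted files load like any other fast tokenizer directory.
tokenizer = GPT2TokenizerFast.from_pretrained("converted_gpt2")
print(tokenizer.decode(tokenizer.encode("lower newer")))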
