Skip to content

Commit d7f52c5

Browse files
committed
Add an option to test that a tokenizer converted from tiktoken behaves the same as the fast tokenizer
1 parent 445cd2d commit d7f52c5

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

tests/models/gpt2/test_tokenization_gpt2.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
2222
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
23-
from transformers.testing_utils import require_jinja, require_tokenizers
23+
from transformers.testing_utils import require_jinja, require_tokenizers, require_tiktoken
2424

2525
from ...test_tokenization_common import TokenizerTesterMixin
2626

@@ -299,6 +299,19 @@ def test_tokenization_for_chat(self):
299299
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
300300
self.assertListEqual(tokenized_chat, expected_tokens)
301301

302+
@require_tiktoken
def test_tokenization_tiktoken(self):
    """Check that a fast tokenizer converted from tiktoken's "gpt2" encoding
    round-trips text the same way as the reference GPT-2 fast tokenizer.

    Skipped (via ``@require_tiktoken``) when tiktoken is not installed.
    """
    from tiktoken import encoding_name_for_model

    from transformers.integrations.tiktoken import convert_tiktoken_to_fast

    # Convert the tiktoken "gpt2" encoding into fast-tokenizer files on disk.
    encoding = encoding_name_for_model("gpt2")
    convert_tiktoken_to_fast(encoding, self.tmpdirname)

    tiktoken_fast_tokenizer = GPT2TokenizerFast.from_pretrained(self.tmpdirname)
    rust_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

    sequence = "lower newer"
    # Fix: encode with each tokenizer's own encode() so the tiktoken-derived
    # tokenizer's encoding path is actually exercised. The original asserted
    # tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(...)), which only
    # tested the converted tokenizer's decoder, never its encoder.
    self.assertEqual(
        rust_tokenizer.decode(rust_tokenizer.encode(sequence)),
        tiktoken_fast_tokenizer.decode(tiktoken_fast_tokenizer.encode(sequence)),
    )
314+
302315

303316
@require_tokenizers
304317
class OPTTokenizationTest(unittest.TestCase):

0 commit comments

Comments
 (0)