|
20 | 20 |
|
21 | 21 | from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast |
22 | 22 | from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES |
23 | | -from transformers.testing_utils import require_jinja, require_tokenizers |
| 23 | +from transformers.testing_utils import require_jinja, require_tiktoken, require_tokenizers |
24 | 24 |
|
25 | 25 | from ...test_tokenization_common import TokenizerTesterMixin |
26 | 26 |
|
@@ -299,6 +299,23 @@ def test_tokenization_for_chat(self): |
299 | 299 | for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): |
300 | 300 | self.assertListEqual(tokenized_chat, expected_tokens) |
301 | 301 |
|
| 302 | + @require_tiktoken |
| 303 | + def test_tokenization_tiktoken(self): |
| 304 | + from tiktoken import encoding_name_for_model |
| 305 | + |
| 306 | + from transformers.integrations.tiktoken import convert_tiktoken_to_fast |
| 307 | + |
| 308 | + encoding = encoding_name_for_model("gpt2") |
| 309 | + convert_tiktoken_to_fast(encoding, self.tmpdirname) |
| 310 | + |
| 311 | + tiktoken_fast_tokenizer = GPT2TokenizerFast.from_pretrained(self.tmpdirname) |
| 312 | + rust_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2") |
| 313 | + sequence = "lower newer" |
| 314 | + self.assertEqual( |
| 315 | + rust_tokenizer.decode(rust_tokenizer.encode(sequence)), |
| 316 | + tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)), |
| 317 | + ) |
| 318 | + |
302 | 319 |
|
303 | 320 | @require_tokenizers |
304 | 321 | class OPTTokenizationTest(unittest.TestCase): |
|
0 commit comments