@@ -57,7 +57,6 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
-from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -75,7 +74,7 @@
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.extension.llm.custom_ops import model_sharding
 from executorch.extension.llm.export.builder import DType
-from pytorch_tokenizers import get_tokenizer
+from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 
 from torch.ao.quantization.observer import MinMaxObserver
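
The net effect of this hunk is that the Tiktoken wrapper local to executorch is dropped in favor of TiktokenTokenizer from the pytorch_tokenizers package, which already supplied get_tokenizer and the SentencePiece wrapper. A minimal sketch of that import surface, assuming get_tokenizer infers the concrete tokenizer class from the artifact at the given path (the path below is hypothetical):

```python
from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

# Hypothetical artifact path; get_tokenizer is assumed to pick the
# concrete class (SentencePiece vs. Tiktoken) from the file itself.
tokenizer = get_tokenizer("/path/to/tokenizer.model")

if isinstance(tokenizer, TiktokenTokenizer):
    print("llama3-style tokenizer: pass allowed_special='all' when encoding")
elif isinstance(tokenizer, SentencePieceTokenizer):
    print("llama2-style tokenizer: no special tokens")
```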
@@ -141,7 +140,7 @@ def _kv_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -213,7 +212,7 @@ def _prefill_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
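
The same isinstance dispatch is patched in both _kv_calibrate and _prefill_calibrate, so the rename lands twice. As a sketch, the shared step could read as below; _encode_prompt is a hypothetical helper name, and the encode signatures are exactly the ones visible in the hunks:

```python
from pytorch_tokenizers import TiktokenTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer


def _encode_prompt(tokenizer, user_prompts: str) -> list[int]:
    # Hypothetical helper mirroring the logic inlined in both calibrate paths.
    # Llama2's SentencePiece tokenizer has no special tokens.
    if isinstance(tokenizer, SentencePieceTokenizer):
        return tokenizer.encode(user_prompts, bos=True, eos=False)
    # Tiktoken-based tokenizers (llama3_2) must explicitly allow special tokens.
    if isinstance(tokenizer, TiktokenTokenizer):
        return tokenizer.encode(
            user_prompts, bos=True, eos=False, allowed_special="all"
        )
    raise TypeError(f"Unsupported tokenizer type: {type(tokenizer).__name__}")
```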
@@ -1111,7 +1110,7 @@ def export_llama(args) -> None:
         runtime_tokenizer_path = args.tokenizer_bin
     elif args.llama_model == "llama3_2":
         assert isinstance(
-            tokenizer, Tiktoken
+            tokenizer, TiktokenTokenizer
         ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
     else:
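
One nit in the unchanged context: the assertion message keeps an f prefix although it interpolates nothing, so a plain string behaves identically. A sketch of the branch after this change, with the assertion tidied; _runtime_tokenizer_path is a hypothetical wrapper, and the branches surrounding the hunk are elided here rather than reconstructed:

```python
from pytorch_tokenizers import TiktokenTokenizer


def _runtime_tokenizer_path(args, tokenizer) -> str:
    # Hypothetical wrapper around the branch shown in the hunk.
    if args.llama_model == "llama3_2":
        assert isinstance(
            tokenizer, TiktokenTokenizer
        ), "Wrong tokenizer provided for llama3_2."  # no placeholders, no f needed
        return args.tokenizer_model
    # Preceding branch in the hunk: other models use the binary tokenizer file.
    return args.tokenizer_bin
```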