@@ -57,7 +57,6 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
-from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -75,7 +74,7 @@
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.extension.llm.custom_ops import model_sharding
 from executorch.extension.llm.export.builder import DType
-from pytorch_tokenizers import get_tokenizer
+from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

 from torch.ao.quantization.observer import MinMaxObserver
@@ -141,7 +140,7 @@ def _kv_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -213,7 +212,7 @@ def _prefill_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -1111,7 +1110,7 @@ def export_llama(args) -> None:
         runtime_tokenizer_path = args.tokenizer_bin
     elif args.llama_model == "llama3_2":
         assert isinstance(
-            tokenizer, Tiktoken
+            tokenizer, TiktokenTokenizer
        ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
     else:
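For reference, a minimal usage sketch of the migrated pytorch_tokenizers API (not part of this diff). The tokenizer path and prompt are placeholders, and get_tokenizer auto-detecting the tokenizer format from the file is an assumption on my part; the encode() signatures mirror the calibrate helpers changed above.

# Hedged sketch: dispatch on the concrete tokenizer class, as the
# calibrate helpers do after the Tiktoken -> TiktokenTokenizer rename.
from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

tokenizer = get_tokenizer("tokenizer.model")  # placeholder path

if isinstance(tokenizer, SentencePieceTokenizer):
    # Llama2-style tokenizer: no special tokens to allow
    token_list = tokenizer.encode("hello world", bos=True, eos=False)
elif isinstance(tokenizer, TiktokenTokenizer):
    # Llama3-style tokenizer: permit all special tokens during encoding
    token_list = tokenizer.encode(
        "hello world", bos=True, eos=False, allowed_special="all"
    )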