From 47227d680663551af553eb9de1e77a8df0ef2127 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 1 Apr 2025 14:35:04 -0700
Subject: [PATCH] fix qnn export (#9808)

Summary:
X-link: https://github.com/pytorch-labs/tokenizers/pull/41

One missing item for the new tokenizer lib: the QNN llama export script
still referenced the old Tiktoken class, so switch it to TiktokenTokenizer
from pytorch_tokenizers.

Reviewed By: kirklandsign

Differential Revision: D72263224
---
 examples/qualcomm/oss_scripts/llama/llama.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index e00df50e755..458807b8a64 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -57,7 +57,6 @@ from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
 
-from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -75,7 +74,7 @@ from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 
 from executorch.extension.llm.custom_ops import model_sharding
 from executorch.extension.llm.export.builder import DType
-from pytorch_tokenizers import get_tokenizer
+from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from torch.ao.quantization.observer import MinMaxObserver
 
@@ -141,7 +140,7 @@ def _kv_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -213,7 +212,7 @@ def _prefill_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -1111,7 +1110,7 @@ def export_llama(args) -> None:
         runtime_tokenizer_path = args.tokenizer_bin
     elif args.llama_model == "llama3_2":
         assert isinstance(
-            tokenizer, Tiktoken
+            tokenizer, TiktokenTokenizer
        ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
     else:
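
Note: for readers migrating similar scripts, a minimal sketch of the
tokenizer dispatch this patch updates, using only the pytorch_tokenizers
API exercised in the diff above. The tokenizer path and the final else
branch are illustrative assumptions, not code from this change.

    from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
    from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

    # Hypothetical artifact path, for illustration only.
    tokenizer = get_tokenizer("/path/to/tokenizer.model")

    prompt = "Once upon a time"
    if isinstance(tokenizer, SentencePieceTokenizer):
        # Llama2 tokenizer has no special tokens.
        token_list = tokenizer.encode(prompt, bos=True, eos=False)
    elif isinstance(tokenizer, TiktokenTokenizer):
        # Llama3-style tokenizer: explicitly allow special tokens in the prompt.
        token_list = tokenizer.encode(
            prompt, bos=True, eos=False, allowed_special="all"
        )
    else:
        # Assumed fallback for this sketch.
        raise RuntimeError("unsupported tokenizer type")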