
Commit a787cbd

fix qnn export
Differential Revision: D72263224
Pull Request resolved: #9808

1 parent 150cbe1

File tree

1 file changed: +4, -5 lines
  • examples/qualcomm/oss_scripts/llama


examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 4 additions & 5 deletions
@@ -57,7 +57,6 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
-from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -75,7 +74,7 @@
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 from executorch.extension.llm.custom_ops import model_sharding
 from executorch.extension.llm.export.builder import DType
-from pytorch_tokenizers import get_tokenizer
+from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 
 from torch.ao.quantization.observer import MinMaxObserver
@@ -141,7 +140,7 @@ def _kv_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -213,7 +212,7 @@ def _prefill_calibrate(
     # Llama2 tokenizer has no special tokens
     if isinstance(tokenizer, SentencePieceTokenizer):
         token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
-    elif isinstance(tokenizer, Tiktoken):
+    elif isinstance(tokenizer, TiktokenTokenizer):
         token_list = tokenizer.encode(
             user_prompts, bos=True, eos=False, allowed_special="all"
         )
@@ -1111,7 +1110,7 @@ def export_llama(args) -> None:
         runtime_tokenizer_path = args.tokenizer_bin
     elif args.llama_model == "llama3_2":
         assert isinstance(
-            tokenizer, Tiktoken
+            tokenizer, TiktokenTokenizer
         ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
     else:
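
For reference, the diff replaces executorch's local Tiktoken wrapper with TiktokenTokenizer from pytorch_tokenizers, and the calibration helpers dispatch on the tokenizer type. A minimal sketch of that dispatch after the change, mirroring the diff above (the tokenizer path and prompt are hypothetical placeholders):

    # Sketch of the tokenizer dispatch used in _kv_calibrate/_prefill_calibrate.
    from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
    from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

    tokenizer = get_tokenizer("tokenizer.model")  # hypothetical path
    user_prompts = "What is the capital of France?"  # hypothetical prompt

    if isinstance(tokenizer, SentencePieceTokenizer):
        # Llama2 tokenizer has no special tokens
        token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
    elif isinstance(tokenizer, TiktokenTokenizer):
        # The tiktoken-based Llama3 tokenizer must be told which special tokens to allow
        token_list = tokenizer.encode(
            user_prompts, bos=True, eos=False, allowed_special="all"
        )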
