diff --git a/llama/tokenizer.py b/llama/tokenizer.py
index 3eda89a06..4a25de6a8 100755
--- a/llama/tokenizer.py
+++ b/llama/tokenizer.py
@@ -3,7 +3,7 @@
 
 import os
 from logging import getLogger
-from typing import List
+from typing import List, Optional
 
 from sentencepiece import SentencePieceProcessor
 
@@ -13,29 +13,32 @@ class Tokenizer:
     """tokenizing and encoding/decoding text using SentencePiece."""
 
-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Optional[str] = None):
         """
         Initializes the Tokenizer with a SentencePiece model.
 
         Args:
             model_path (str): The path to the SentencePiece model file.
 
         """
-        # reload tokenizer
-        assert os.path.isfile(model_path), model_path
-        self.sp_model = SentencePieceProcessor(model_file=model_path)
-        logger.info(f"Reloaded SentencePiece model from {model_path}")
+        if model_path is not None:
+            # reload tokenizer if possible
+            if not os.path.isfile(model_path):
+                raise FileNotFoundError(f"Model file not found: {model_path}")
+            self.sp_model = SentencePieceProcessor(model_file=model_path)
+            logger.info(f"Reloaded SentencePiece model from {model_path}")
 
-        # BOS / EOS token IDs
-        self.n_words: int = self.sp_model.vocab_size()
-        self.bos_id: int = self.sp_model.bos_id()
-        self.eos_id: int = self.sp_model.eos_id()
-        self.pad_id: int = self.sp_model.pad_id()
-        logger.info(
-            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
-        )
-        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+            # BOS / EOS / PAD / UNK token IDs
+            self.n_words: int = self.sp_model.vocab_size()
+            self.bos_id: int = self.sp_model.bos_id()
+            self.eos_id: int = self.sp_model.eos_id()
+            self.pad_id: int = self.sp_model.pad_id()
+            self.unk_id: int = self.sp_model.unk_id()
+            logger.info(
+                f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+            )
+            assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
-    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
         """
         Encodes a string into a list of token IDs.
 
@@ -47,8 +50,15 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
         Returns:
             List[int]: A list of token IDs.
         """
-        assert type(s) is str
-        t = self.sp_model.encode(s)
+        assert isinstance(s, str), "Input 's' must be a string"
+        try:
+            t = self.sp_model.encode(s)
+        except Exception as e:
+            raise ValueError(f"Error during tokenization: {e}")
+
+        # Handle unknown tokens
+        t = [token_id if token_id in range(self.n_words) else self.unk_id for token_id in t]
+
         if bos:
             t = [self.bos_id] + t
         if eos:
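
For reviewers, a minimal usage sketch against the changed API. The `tokenizer.model` path below is a placeholder, not a file included in this change; everything else follows directly from the diff (optional `model_path`, `FileNotFoundError` instead of an assert, and `bos`/`eos` defaulting to `False`).

```python
from llama.tokenizer import Tokenizer

# Placeholder path; point this at a real SentencePiece model file.
tok = Tokenizer(model_path="tokenizer.model")

# bos/eos now default to False, so a bare call encodes the raw text only.
plain_ids = tok.encode("Hello, world!")

# Request BOS/EOS explicitly when building model inputs.
full_ids = tok.encode("Hello, world!", bos=True, eos=True)
assert full_ids[0] == tok.bos_id and full_ids[-1] == tok.eos_id

# With no argument, Tokenizer() now constructs without loading a model;
# a bad path raises FileNotFoundError rather than failing an assert.
```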