@@ -17,7 +17,7 @@
 import os
 from pathlib import Path
 from typing import Dict, List, Optional
-
+import re
 try:
     import tiktoken
 except ImportError:
@@ -102,7 +102,7 @@ def __init__(
         self._eos_id = special_tokens.index("</s>")

         self._vocab_size = vocab_size
-        print(f'{self._vocab_size = }')
+
         self.num_special_tokens = num_special_tokens
         special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
         if special_filler:
@@ -128,48 +128,81 @@ def __init__(
         )

     def text_to_tokens(self, text: str):
-        token_ids = self.tokenizer.encode(text)
-        return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]
+        tokens = []
+        special_token_pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
+        parts = re.split(f"({special_token_pattern})", text)
+        for part in parts:
+            if re.match(special_token_pattern, part):
+                tokens.append(part.encode('utf-8'))
+            else:
+                token_ids = self.tokenizer.encode(part)
+                tokens.extend([self.tokenizer.decode_single_token_bytes(token) for token in token_ids])
+        return tokens
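Wrapping the pattern in a capture group makes re.split return the special-token literals alongside the ordinary text, so they survive the split intact and can be emitted as single pieces. A minimal standalone sketch of the mechanism (assuming SPECIAL_TOKEN_TEMPLATE is "<SPECIAL_{id}>", which is defined elsewhere in this file):

    import re

    SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"  # assumed value, not shown in this diff
    pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')  # -> "<SPECIAL_\d+>"
    print(re.split(f"({pattern})", "foo <SPECIAL_3> bar"))
    # ['foo ', '<SPECIAL_3>', ' bar']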

     def tokens_to_text(self, tokens: List[int]):
-        token_ids = [self.tokenizer.encode_single_token(tokens) for tokens in tokens]
-        return self.tokenizer.decode(token_ids)
+        result = []
+        for token in tokens:
+            if isinstance(token, bytes):
+                result.append(token.decode('utf-8'))
+            else:
+                result.append(self.tokenizer.decode([token]))
+        return ''.join(result)
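Since text_to_tokens now yields bytes pieces, tokens_to_text accepts both bytes and raw tiktoken ids, and the two methods round-trip. A hedged sketch, with tok standing in for an instance of this tokenizer (the text is kept ASCII because each bytes piece is UTF-8-decoded independently):

    pieces = tok.text_to_tokens("hello <SPECIAL_3> world")  # list of bytes
    assert tok.tokens_to_text(pieces) == "hello <SPECIAL_3> world"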

     def token_to_id(self, token):
-        return self.tokenizer.encode_single_token(token)
-
+        token_str = token.decode('utf-8', errors='replace') if isinstance(token, bytes) else token
+        if token_str in self.special_tokens:
+            return self.special_tokens.index(token_str)
+        else:
+            token_ids = self.tokenizer.encode(token_str)
+            if len(token_ids) != 1:
+                raise ValueError(f"Token '{token_str}' should correspond to exactly one ID, but got {token_ids}")
+            return token_ids[0] + self.num_special_tokens
+
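token_to_id now resolves registered special tokens to their reserved low ids and offsets ordinary BPE ids by num_special_tokens, rejecting any string that encodes to more than one BPE piece. Usage sketch (tok as above):

    tok.token_to_id('</s>')    # special_tokens.index('</s>'), a reserved low id
    tok.token_to_id(b'hello')  # single BPE id + num_special_tokens; ValueError if multi-piece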
     def tokens_to_ids(self, tokens):
-        return [self.tokenizer.encode_single_token(token) for token in tokens]
+        ids = []
+        for token in tokens:
+            token_str = token.decode('utf-8', errors='replace') if isinstance(token, bytes) else token
+            if token_str in self.special_tokens:
+                ids.append(self.special_tokens.index(token_str))
+            else:
+                ids.extend([id + self.num_special_tokens for id in self.tokenizer.encode(token_str)])
+        return ids
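Unlike token_to_id, tokens_to_ids does not enforce one id per piece: a non-special string is re-encoded and may expand, so the output can be longer than the input (sketch):

    ids = tok.tokens_to_ids([b'</s>', b'supercalifragilistic'])
    # ids[0] is the reserved eos id; the second piece may contribute several ids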

     def ids_to_tokens(self, token_ids):
         tokens = []
         for token_id in token_ids:
             if token_id < self.num_special_tokens:
-                tokens.append(self.special_tokens[token_id])
+                tokens.append(self.special_tokens[token_id].encode('utf-8'))
             else:
-                token_id -= self.num_special_tokens
-                token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
-                tokens.append(token_bytes.decode('utf-8', errors='replace'))
+                adjusted_token = token_id - self.num_special_tokens
+                token_bytes = self.tokenizer.decode_single_token_bytes(adjusted_token)
+                tokens.append(token_bytes)
         return tokens
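ids_to_tokens now returns bytes uniformly, mirroring text_to_tokens: special ids become the UTF-8 encoding of their literal and ordinary ids become raw BPE bytes (sketch):

    pieces = tok.ids_to_tokens(tok.text_to_ids("hello"))
    assert all(isinstance(p, bytes) for p in pieces)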

+
     def text_to_ids(self, text: str):
-        tokens = self.tokenizer.encode(text)
-        tokens = [t + self.num_special_tokens for t in tokens]
+        tokens = []
+        special_token_pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
+        parts = re.split(f"({special_token_pattern})", text)
+        for part in parts:
+            if re.match(special_token_pattern, part):
+                token_id = int(re.findall(r"\d+", part)[0])
+                tokens.append(token_id)
+            else:
+                token_ids = self.tokenizer.encode(part)
+                tokens.extend([t + self.num_special_tokens for t in token_ids])
         return tokens
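For a special-token literal, the digits embedded in the string are used directly as the token id, so "<SPECIAL_7>" maps to id 7; ordinary text is encoded and shifted by num_special_tokens as before (sketch):

    ids = tok.text_to_ids("<SPECIAL_7>hi")
    # ids[0] == 7; the remaining ids are BPE ids offset by num_special_tokens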

-    def ids_to_text(self, tokens: List[int]):
-        # Filter out special tokens and adjust the remaining tokens
-        adjusted_tokens = [
-            t - self.num_special_tokens
-            for t in tokens
-            if t not in {self.bos, self.eos} and t >= self.num_special_tokens
-        ]
-
-        # Decode only if there are tokens left after filtering
-        if adjusted_tokens:
-            return self.tokenizer.decode(adjusted_tokens)
-        else:
-            return ""  # Return an empty string if all tokens were filtered out
+    def ids_to_text(self, tokens: List[int], skip_special_tokens: bool = True):
+        result = []
+        for token in tokens:
+            if token < self.num_special_tokens:
+                if not skip_special_tokens:
+                    result.append(self.special_tokens[token])
+            else:
+                adjusted_token = token - self.num_special_tokens
+                result.append(self.tokenizer.decode([adjusted_token]))
+        return ''.join(result)
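The new skip_special_tokens flag keeps the old default behavior of dropping special ids while allowing lossless detokenization when set to False (sketch, using the eos id referenced in __init__):

    ids = tok.text_to_ids("hello") + [tok.eos]
    tok.ids_to_text(ids)                             # 'hello'
    tok.ids_to_text(ids, skip_special_tokens=False)  # 'hello</s>'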

     @property
     def bos_id(self):