1515import base64
1616import json
1717import os
18- import re
18+ import regex as re
1919from pathlib import Path
2020from typing import Dict , List , Optional
2121
@@ -68,7 +68,6 @@ def reload_mergeable_ranks(
6868SPECIAL_TOKENS = ["<unk>" , "<s>" , "</s>" ]
6969SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
7070
71-
7271class TiktokenTokenizer (TokenizerSpec ):
7372 """
7473 TiktokenTokenizer https://github.com/openai/tiktoken.
@@ -127,38 +126,12 @@ def __init__(
127126 mergeable_ranks = self .token2id ,
128127 special_tokens = {}, # special tokens are handled manually
129128 )
129+
130+ # Compile the tokenizer pattern for later use
131+ self .pattern = re .compile (pattern )
130132
131133 def text_to_tokens (self , text : str ) -> List [str ]:
132- """
133- Tokenizes input text into a list of token strings, handling special tokens and non-ASCII substrings.
134-
135- Args:
136- text (str): The input text to tokenize.
137-
138- Returns:
139- List[str]: A list of token strings.
140- """
141- tokens = []
142- special_token_pattern = SPECIAL_TOKEN_TEMPLATE .format (id = r"\d+" )
143- pattern = f"({ special_token_pattern } |<unk>|<s>|</s>|[^\x00 -\x7F ]+)"
144-
145- # Split the text using the defined pattern
146- parts = re .split (pattern , text )
147-
148- for part in filter (None , parts ): # Skip empty strings
149- if re .match (special_token_pattern , part ) or part in self .special_tokens :
150- tokens .append (part ) # Special token
151- elif re .match (r"[^\x00-\x7F]+" , part ):
152- tokens .append (part ) # Non-ASCII substring
153- else :
154- # Encode and decode ASCII parts
155- for token_id in self .tokenizer .encode (part ):
156- token_str = self .id2token .get (token_id , "<unk>" )
157- if isinstance (token_str , bytes ): # Handle bytes decoding
158- token_str = token_str .decode ('utf-8' , errors = 'replace' )
159- tokens .append (token_str )
160-
161- return tokens
134+ return self .ids_to_tokens (self .text_to_ids (text ))
162135
163136 def tokens_to_text (self , tokens : List [str ]) -> str :
164137 return '' .join (tokens )
@@ -185,23 +158,19 @@ def tokens_to_ids(self, tokens):
185158
186159 def ids_to_tokens (self , ids : List [int ]) -> List [str ]:
187160 tokens = []
188- chunks = []
161+ current_ids = []
189162 for id_ in ids :
190163 if id_ < self .num_special_tokens :
191- if chunks :
192- # Decode the chunk and append resulting tokens
193- decoded_chunk = self .tokenizer .decode ([t - self .num_special_tokens for t in chunks ])
194- tokens .extend (decoded_chunk .split ()) # Split into individual tokens
195- chunks = []
196- # Add the special token directly
164+ if current_ids :
165+ decoded_text = self .tokenizer .decode ([i - self .num_special_tokens for i in current_ids ])
166+ tokens .extend (self ._tokenize_text_with_pattern (decoded_text ))
167+ current_ids = []
197168 tokens .append (self .special_tokens [id_ ])
198169 else :
199- # Add to current chunk
200- chunks .append (id_ )
201- if chunks :
202- # Decode any remaining chunk
203- decoded_chunk = self .tokenizer .decode ([t - self .num_special_tokens for t in chunks ])
204- tokens .extend (decoded_chunk .split ())
170+ current_ids .append (id_ )
171+ if current_ids :
172+ decoded_text = self .tokenizer .decode ([i - self .num_special_tokens for i in current_ids ])
173+ tokens .extend (self ._tokenize_text_with_pattern (decoded_text ))
205174 return tokens
206175
207176 def text_to_ids (self , text : str ) -> List [int ]:
@@ -232,6 +201,20 @@ def ids_to_text(self, ids: List[int], skip_special_tokens: bool = False) -> str:
232201 result .append (self .tokenizer .decode ([t - self .num_special_tokens for t in chunks ]))
233202 return '' .join (result )
234203
204+ def _tokenize_text_with_pattern (self , text : str ) -> List [str ]:
205+ tokens = []
206+ last_end = 0
207+ for match in self .pattern .finditer (text ):
208+ start , end = match .span ()
209+ if start > last_end :
210+ # Capture any text between matches (including leading whitespace)
211+ tokens .append (text [last_end :start ])
212+ tokens .append (match .group (0 ))
213+ last_end = end
214+ if last_end < len (text ):
215+ tokens .append (text [last_end :])
216+ return tokens
217+
235218 @property
236219 def bos_id (self ):
237220 return self ._bos_id
0 commit comments