Skip to content

Commit 413e736

Browse files
committed
Fixed tokenization of special characters.
1 parent 30cef20 commit 413e736

File tree

1 file changed: +28 additions, −45 deletions

nemo/collections/common/tokenizers/tiktoken_tokenizer.py

Lines changed: 28 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import base64
1616
import json
1717
import os
18-
import re
18+
import regex as re
1919
from pathlib import Path
2020
from typing import Dict, List, Optional
2121

@@ -68,7 +68,6 @@ def reload_mergeable_ranks(
6868
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]
6969
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
7070

71-
7271
class TiktokenTokenizer(TokenizerSpec):
7372
"""
7473
TiktokenTokenizer https://github.com/openai/tiktoken.
@@ -127,38 +126,12 @@ def __init__(
127126
mergeable_ranks=self.token2id,
128127
special_tokens={}, # special tokens are handled manually
129128
)
129+
130+
# Compile the tokenizer pattern for later use
131+
self.pattern = re.compile(pattern)
130132

131133
def text_to_tokens(self, text: str) -> List[str]:
132-
"""
133-
Tokenizes input text into a list of token strings, handling special tokens and non-ASCII substrings.
134-
135-
Args:
136-
text (str): The input text to tokenize.
137-
138-
Returns:
139-
List[str]: A list of token strings.
140-
"""
141-
tokens = []
142-
special_token_pattern = SPECIAL_TOKEN_TEMPLATE.format(id=r"\d+")
143-
pattern = f"({special_token_pattern}|<unk>|<s>|</s>|[^\x00-\x7F]+)"
144-
145-
# Split the text using the defined pattern
146-
parts = re.split(pattern, text)
147-
148-
for part in filter(None, parts): # Skip empty strings
149-
if re.match(special_token_pattern, part) or part in self.special_tokens:
150-
tokens.append(part) # Special token
151-
elif re.match(r"[^\x00-\x7F]+", part):
152-
tokens.append(part) # Non-ASCII substring
153-
else:
154-
# Encode and decode ASCII parts
155-
for token_id in self.tokenizer.encode(part):
156-
token_str = self.id2token.get(token_id, "<unk>")
157-
if isinstance(token_str, bytes): # Handle bytes decoding
158-
token_str = token_str.decode('utf-8', errors='replace')
159-
tokens.append(token_str)
160-
161-
return tokens
134+
return self.ids_to_tokens(self.text_to_ids(text))
162135

163136
def tokens_to_text(self, tokens: List[str]) -> str:
164137
return ''.join(tokens)
@@ -185,23 +158,19 @@ def tokens_to_ids(self, tokens):
185158

186159
def ids_to_tokens(self, ids: List[int]) -> List[str]:
187160
tokens = []
188-
chunks = []
161+
current_ids = []
189162
for id_ in ids:
190163
if id_ < self.num_special_tokens:
191-
if chunks:
192-
# Decode the chunk and append resulting tokens
193-
decoded_chunk = self.tokenizer.decode([t - self.num_special_tokens for t in chunks])
194-
tokens.extend(decoded_chunk.split()) # Split into individual tokens
195-
chunks = []
196-
# Add the special token directly
164+
if current_ids:
165+
decoded_text = self.tokenizer.decode([i - self.num_special_tokens for i in current_ids])
166+
tokens.extend(self._tokenize_text_with_pattern(decoded_text))
167+
current_ids = []
197168
tokens.append(self.special_tokens[id_])
198169
else:
199-
# Add to current chunk
200-
chunks.append(id_)
201-
if chunks:
202-
# Decode any remaining chunk
203-
decoded_chunk = self.tokenizer.decode([t - self.num_special_tokens for t in chunks])
204-
tokens.extend(decoded_chunk.split())
170+
current_ids.append(id_)
171+
if current_ids:
172+
decoded_text = self.tokenizer.decode([i - self.num_special_tokens for i in current_ids])
173+
tokens.extend(self._tokenize_text_with_pattern(decoded_text))
205174
return tokens
206175

207176
def text_to_ids(self, text: str) -> List[int]:
@@ -232,6 +201,20 @@ def ids_to_text(self, ids: List[int], skip_special_tokens: bool = False) -> str:
232201
result.append(self.tokenizer.decode([t - self.num_special_tokens for t in chunks]))
233202
return ''.join(result)
234203

204+
def _tokenize_text_with_pattern(self, text: str) -> List[str]:
205+
tokens = []
206+
last_end = 0
207+
for match in self.pattern.finditer(text):
208+
start, end = match.span()
209+
if start > last_end:
210+
# Capture any text between matches (including leading whitespace)
211+
tokens.append(text[last_end:start])
212+
tokens.append(match.group(0))
213+
last_end = end
214+
if last_end < len(text):
215+
tokens.append(text[last_end:])
216+
return tokens
217+
235218
@property
236219
def bos_id(self):
237220
return self._bos_id

Comments (0)