Commit 19e5049

a long overdue tiktoken special tokens fix -- Tkonuk
Signed-off-by: adithyare <adithyare@nvidia.com>
1 parent ada4b90 commit 19e5049
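
For context, a minimal sketch (not part of the commit) of the behaviour the fix targets. It assumes SPECIAL_TOKEN_TEMPLATE is defined earlier in this file as something like "<SPECIAL_{id}>" and uses tiktoken's cl100k_base encoding as a stand-in for the NeMo vocab; both are assumptions for illustration only. Plain tiktoken encode() treats a literal special-token placeholder as ordinary text, so the new methods first split the input on the template pattern and handle those pieces separately.

    import re
    import tiktoken

    SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"   # assumed value, for illustration only
    enc = tiktoken.get_encoding("cl100k_base")  # stand-in for the NeMo vocab

    text = "hello <SPECIAL_11> world"

    # Without special handling, the placeholder is broken into ordinary subword ids.
    print(enc.encode(text))  # several ids covering "<", "SPECIAL", "_", "11", ">"

    # The fix splits on the template pattern first, so the placeholder survives intact.
    pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
    print(re.split(f"({pattern})", text))  # ['hello ', '<SPECIAL_11>', ' world']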

File tree

1 file changed (+61 additions, -28 deletions)

nemo/collections/common/tokenizers/tiktoken_tokenizer.py

Lines changed: 61 additions & 28 deletions
@@ -17,7 +17,7 @@
 import os
 from pathlib import Path
 from typing import Dict, List, Optional
-
+import re
 try:
     import tiktoken
 except ImportError:
@@ -102,7 +102,7 @@ def __init__(
         self._eos_id = special_tokens.index("</s>")

         self._vocab_size = vocab_size
-        print(f'{self._vocab_size = }')
+
         self.num_special_tokens = num_special_tokens
         special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
         if special_filler:
@@ -128,48 +128,81 @@ def __init__(
         )

     def text_to_tokens(self, text: str):
-        token_ids = self.tokenizer.encode(text)
-        return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]
+        tokens = []
+        special_token_pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
+        parts = re.split(f"({special_token_pattern})", text)
+        for part in parts:
+            if re.match(special_token_pattern, part):
+                tokens.append(part.encode('utf-8'))
+            else:
+                token_ids = self.tokenizer.encode(part)
+                tokens.extend([self.tokenizer.decode_single_token_bytes(token) for token in token_ids])
+        return tokens

     def tokens_to_text(self, tokens: List[int]):
-        token_ids = [self.tokenizer.encode_single_token(tokens) for tokens in tokens]
-        return self.tokenizer.decode(token_ids)
+        result = []
+        for token in tokens:
+            if isinstance(token, bytes):
+                result.append(token.decode('utf-8'))
+            else:
+                result.append(self.tokenizer.decode([token]))
+        return ''.join(result)

     def token_to_id(self, token):
-        return self.tokenizer.encode_single_token(token)
-
+        token_str = token.decode('utf-8', errors='replace') if isinstance(token, bytes) else token
+        if token_str in self.special_tokens:
+            return self.special_tokens.index(token_str)
+        else:
+            token_ids = self.tokenizer.encode(token_str)
+            if len(token_ids) != 1:
+                raise ValueError(f"Token '{token_str}' should correspond to exactly one ID, but got {token_ids}")
+            return token_ids[0] + self.num_special_tokens
+
     def tokens_to_ids(self, tokens):
-        return [self.tokenizer.encode_single_token(token) for token in tokens]
+        ids = []
+        for token in tokens:
+            token_str = token.decode('utf-8', errors='replace') if isinstance(token, bytes) else token
+            if token_str in self.special_tokens:
+                ids.append(self.special_tokens.index(token_str))
+            else:
+                ids.extend([id + self.num_special_tokens for id in self.tokenizer.encode(token_str)])
+        return ids

     def ids_to_tokens(self, token_ids):
         tokens = []
         for token_id in token_ids:
             if token_id < self.num_special_tokens:
-                tokens.append(self.special_tokens[token_id])
+                tokens.append(self.special_tokens[token_id].encode('utf-8'))
             else:
-                token_id -= self.num_special_tokens
-                token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
-                tokens.append(token_bytes.decode('utf-8', errors='replace'))
+                adjusted_token = token_id - self.num_special_tokens
+                token_bytes = self.tokenizer.decode_single_token_bytes(adjusted_token)
+                tokens.append(token_bytes)
         return tokens

+
     def text_to_ids(self, text: str):
-        tokens = self.tokenizer.encode(text)
-        tokens = [t + self.num_special_tokens for t in tokens]
+        tokens = []
+        special_token_pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
+        parts = re.split(f"({special_token_pattern})", text)
+        for part in parts:
+            if re.match(special_token_pattern, part):
+                token_id = int(re.findall(r"\d+", part)[0])
+                tokens.append(token_id)
+            else:
+                token_ids = self.tokenizer.encode(part)
+                tokens.extend([t + self.num_special_tokens for t in token_ids])
         return tokens

-    def ids_to_text(self, tokens: List[int]):
-        # Filter out special tokens and adjust the remaining tokens
-        adjusted_tokens = [
-            t - self.num_special_tokens
-            for t in tokens
-            if t not in {self.bos, self.eos} and t >= self.num_special_tokens
-        ]
-
-        # Decode only if there are tokens left after filtering
-        if adjusted_tokens:
-            return self.tokenizer.decode(adjusted_tokens)
-        else:
-            return ""  # Return an empty string if all tokens were filtered out
+    def ids_to_text(self, tokens: List[int], skip_special_tokens: bool = True):
+        result = []
+        for token in tokens:
+            if token < self.num_special_tokens:
+                if not skip_special_tokens:
+                    result.append(self.special_tokens[token])
+            else:
+                adjusted_token = token - self.num_special_tokens
+                result.append(self.tokenizer.decode([adjusted_token]))
+        return ''.join(result)

     @property
     def bos_id(self):
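
A rough, self-contained sketch of the round trip the new text_to_ids / ids_to_text pair aims for, again using cl100k_base and an arbitrary 1000 reserved special ids as stand-ins for the real vocab file and config (both assumptions). The actual TiktokenTokenizer is built from a NeMo vocab file and its special-token list, so the concrete ids and special-token strings differ.

    import re
    import tiktoken

    SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"   # assumed template
    NUM_SPECIAL_TOKENS = 1000                   # assumed number of reserved ids
    enc = tiktoken.get_encoding("cl100k_base")  # stand-in encoder

    def text_to_ids(text):
        # Split out literal special-token placeholders, map each to its reserved id,
        # and shift every regular tiktoken id past the reserved range.
        pattern = SPECIAL_TOKEN_TEMPLATE.format(id='\\d+')
        ids = []
        for part in re.split(f"({pattern})", text):
            if re.match(pattern, part):
                ids.append(int(re.findall(r"\d+", part)[0]))
            else:
                ids.extend(t + NUM_SPECIAL_TOKENS for t in enc.encode(part))
        return ids

    def ids_to_text(ids, skip_special_tokens=True):
        # Ids below the reserved range are special; keep or drop them, and unshift the rest.
        out = []
        for i in ids:
            if i < NUM_SPECIAL_TOKENS:
                if not skip_special_tokens:
                    out.append(SPECIAL_TOKEN_TEMPLATE.format(id=i))
            else:
                out.append(enc.decode([i - NUM_SPECIAL_TOKENS]))
        return ''.join(out)

    ids = text_to_ids("hi <SPECIAL_11> there")
    print(ids_to_text(ids, skip_special_tokens=False))  # hi <SPECIAL_11> there
    print(ids_to_text(ids))                             # special id dropped by default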
