Skip to content

Commit 0e48dbc

Browse files
akoumpa and abhinavg4
authored and committed
fix sentencepiece tokenizer special tokens (#11811)
* Allow special tokens in vocab
* Pass special tokens to SentencePieceTokenizer
* Add test
* Cleanup
* Add trim_spm_separator_after_special_token option
* Add test
* Apply isort and black reformatting
* Remove unused code
* Also handle dict in omegaconf to_container
* Also handle special tokens that are already in vocab
* Add test
* Apply isort and black reformatting

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
1 parent 6335c82 commit 0e48dbc

File tree

3 files changed

+100
-18
lines changed

3 files changed

+100
-18
lines changed

nemo/collections/common/tokenizers/sentencepiece_tokenizer.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
legacy: bool = False,
5252
ignore_extra_whitespaces: bool = True,
5353
chat_template: Optional[Dict] = None,
54+
trim_spm_separator_after_special_token=True,
55+
spm_separator='▁',
5456
):
5557
self.chat_template = chat_template
5658
if not model_path or not os.path.exists(model_path):
@@ -66,6 +68,10 @@ def __init__(
6668
self.extra_space_token = '☯'
6769
self.special_token_to_id = {}
6870
self.id_to_special_token = {}
71+
self.trim_spm_separator_after_special_token = trim_spm_separator_after_special_token
72+
self.spm_separator_id = self.tokenizer.piece_to_id(spm_separator)
73+
self.spm_separator = spm_separator
74+
6975
if special_tokens:
7076
if not self.legacy:
7177
raise ValueError(
@@ -99,7 +105,19 @@ def text_to_tokens(self, text):
99105
next_token = min(indices, key=indices.get)
100106
next_idx = idx + indices[next_token]
101107

102-
tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
108+
tok = self.tokenizer.encode_as_pieces(text[idx:next_idx])
109+
# Chat-templates insert a space between a special token and first word (e.g.
110+
# "[INST] who") which is tokenized as <inst-id> <space-id> <who-id> instead of
111+
# <inst-id> <who-id>.
112+
if (
113+
self.trim_spm_separator_after_special_token
114+
and len(tokens) > 0
115+
and tokens[-1] in self.special_token_to_id
116+
and len(tok) > 0
117+
and tok[0] == self.spm_separator
118+
):
119+
tok.pop(0)
120+
tokens.extend(tok)
103121
tokens.append(next_token)
104122
idx = next_idx + len(next_token)
105123

@@ -143,7 +161,19 @@ def _text_to_ids(self, text, sample_alpha=None):
143161
next_token = min(indices, key=indices.get)
144162
next_idx = idx + indices[next_token]
145163

146-
ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
164+
text_tokens = self.tokenizer.encode(text[idx:next_idx])
165+
# Chat-templates insert a space between a special token and first word (e.g.
166+
# "[INST] who") which is tokenized as <inst-id> <space-id> <who-id> instead of
167+
# <inst-id> <who-id>.
168+
if (
169+
self.trim_spm_separator_after_special_token
170+
and len(ids) > 0
171+
and ids[-1] in self.id_to_special_token
172+
and len(text_tokens) > 0
173+
and text_tokens[0] == self.spm_separator_id
174+
):
175+
text_tokens.pop(0)
176+
ids.extend(text_tokens)
147177
ids.append(self.special_token_to_id[next_token])
148178
idx = next_idx + len(next_token)
149179

@@ -239,6 +269,7 @@ def add_special_tokens(self, special_tokens):
239269
self.vocab_size += 1
240270
elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
241271
self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
272+
self.id_to_special_token[self.special_token_to_id[token]] = token
242273

243274
elif isinstance(special_tokens, dict):
244275
for token_name, token in special_tokens.items():
@@ -250,6 +281,9 @@ def add_special_tokens(self, special_tokens):
250281
self.special_token_to_id[token] = self.vocab_size
251282
self.id_to_special_token[self.vocab_size] = token
252283
self.vocab_size += 1
284+
elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
285+
self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
286+
self.id_to_special_token[self.special_token_to_id[token]] = token
253287
else:
254288
raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))
255289

nemo/collections/nlp/modules/common/tokenizer_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def get_nmt_tokenizer(
173173
import omegaconf
174174
from omegaconf import OmegaConf
175175

176-
if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
176+
if isinstance(special_tokens, (omegaconf.listconfig.ListConfig, omegaconf.dictconfig.DictConfig)):
177177
special_tokens = OmegaConf.to_container(special_tokens)
178178
if special_tokens is None:
179179
special_tokens_dict = {}

tests/collections/nlp/test_tokenizer_with_special_tokens.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@
1616
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
1717

1818
TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
19+
SPECIAL_TOKENS = [
20+
'<s>',
21+
'</s>',
22+
'[INST]',
23+
'[/INST]',
24+
'[TOOL_CALLS]',
25+
'[AVAILABLE_TOOLS]',
26+
'[/AVAILABLE_TOOLS]',
27+
'[TOOL_RESULTS]',
28+
'[/TOOL_RESULTS]',
29+
]
1930

2031

21-
def test_spm_with_special_tokens() -> None:
22-
special_tokens = [
23-
'<s>',
24-
'</s>',
25-
'[INST]',
26-
'[/INST]',
27-
'[TOOL_CALLS]',
28-
'[AVAILABLE_TOOLS]',
29-
'[/AVAILABLE_TOOLS]',
30-
'[TOOL_RESULTS]',
31-
'[/TOOL_RESULTS]',
32-
]
32+
def _build_tokenizer(spm_file, special_tokens):
3333
tokenizer_cfg = {
3434
"library": "sentencepiece",
3535
"type": None,
@@ -39,18 +39,66 @@ def test_spm_with_special_tokens() -> None:
3939
"sentencepiece_legacy": True,
4040
"special_tokens": special_tokens,
4141
}
42-
tokenizer = get_nmt_tokenizer(
42+
return get_nmt_tokenizer(
4343
library=tokenizer_cfg['library'],
4444
model_name=tokenizer_cfg.get("type", None),
4545
use_fast=tokenizer_cfg.get("use_fast", False),
4646
delimiter=tokenizer_cfg.get("delimiter", None),
4747
special_tokens=tokenizer_cfg.get("special_tokens", None),
4848
trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
49-
tokenizer_model=TOKENIZER_SPM_FILE,
49+
tokenizer_model=spm_file,
5050
legacy=True,
5151
)
5252

53+
54+
def test_spm_with_special_tokens() -> None:
55+
tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
5356
assert tokenizer.text_to_ids('[INST]') == [3]
54-
for i, special_token in enumerate(special_tokens):
57+
for i, special_token in enumerate(SPECIAL_TOKENS):
5558
assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
5659
assert tokenizer.special_token_to_id[special_token] == i + 1
60+
61+
62+
def test_trim_spm_separator_after_special_token():
63+
tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
64+
tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 7294]
65+
tokenizer.trim_spm_separator_after_special_token = False
66+
tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 29473, 7294]
67+
68+
69+
def test_text_to_tokens_with_trim_spm_separator_after_special_token():
70+
tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
71+
text = "<s>[INST] Who are you?[/INST] This is a response</s>[INST] I'll ask again who are you?[/INST] I'm not a who</s>"
72+
tokenized = tokenizer.text_to_tokens(text)
73+
assert tokenized == [
74+
'<s>',
75+
'[INST]',
76+
'▁Who',
77+
'▁are',
78+
'▁you',
79+
'?',
80+
'[/INST]',
81+
'▁This',
82+
'▁is',
83+
'▁a',
84+
'▁response',
85+
'</s>',
86+
'[INST]',
87+
'▁I',
88+
"'",
89+
'll',
90+
'▁ask',
91+
'▁again',
92+
'▁who',
93+
'▁are',
94+
'▁you',
95+
'?',
96+
'[/INST]',
97+
'▁I',
98+
"'",
99+
'm',
100+
'▁not',
101+
'▁a',
102+
'▁who',
103+
'</s>',
104+
]

0 commit comments

Comments
 (0)