Skip to content
Merged
38 changes: 36 additions & 2 deletions nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(
legacy: bool = False,
ignore_extra_whitespaces: bool = True,
chat_template: Optional[Dict] = None,
trim_spm_separator_after_special_token=True,
spm_separator='▁',
):
self.chat_template = chat_template
if not model_path or not os.path.exists(model_path):
Expand All @@ -66,6 +68,10 @@ def __init__(
self.extra_space_token = '☯'
self.special_token_to_id = {}
self.id_to_special_token = {}
self.trim_spm_separator_after_special_token = trim_spm_separator_after_special_token
self.spm_separator_id = self.tokenizer.piece_to_id(spm_separator)
self.spm_separator = spm_separator

if special_tokens:
if not self.legacy:
raise ValueError(
Expand Down Expand Up @@ -99,7 +105,19 @@ def text_to_tokens(self, text):
next_token = min(indices, key=indices.get)
next_idx = idx + indices[next_token]

tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
tok = self.tokenizer.encode_as_pieces(text[idx:next_idx])
# Chat-templates insert a space between a special token and first word (e.g.
# "[INST] who") which is tokenized as <inst-id> <space-id> <who-id> instead of
# <inst-id> <who-id>.
if (
self.trim_spm_separator_after_special_token
and len(tokens) > 0
and tokens[-1] in self.special_token_to_id
and len(tok) > 0
and tok[0] == self.spm_separator
):
tok.pop(0)
tokens.extend(tok)
tokens.append(next_token)
idx = next_idx + len(next_token)

Expand Down Expand Up @@ -143,7 +161,19 @@ def _text_to_ids(self, text, sample_alpha=None):
next_token = min(indices, key=indices.get)
next_idx = idx + indices[next_token]

ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
text_tokens = self.tokenizer.encode(text[idx:next_idx])
# Chat-templates insert a space between a special token and first word (e.g.
# "[INST] who") which is tokenized as <inst-id> <space-id> <who-id> instead of
# <inst-id> <who-id>.
if (
self.trim_spm_separator_after_special_token
and len(ids) > 0
and ids[-1] in self.id_to_special_token
and len(text_tokens) > 0
and text_tokens[0] == self.spm_separator_id
):
text_tokens.pop(0)
ids.extend(text_tokens)
ids.append(self.special_token_to_id[next_token])
idx = next_idx + len(next_token)

Expand Down Expand Up @@ -239,6 +269,7 @@ def add_special_tokens(self, special_tokens):
self.vocab_size += 1
elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
self.id_to_special_token[self.special_token_to_id[token]] = token

elif isinstance(special_tokens, dict):
for token_name, token in special_tokens.items():
Expand All @@ -250,6 +281,9 @@ def add_special_tokens(self, special_tokens):
self.special_token_to_id[token] = self.vocab_size
self.id_to_special_token[self.vocab_size] = token
self.vocab_size += 1
elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
self.id_to_special_token[self.special_token_to_id[token]] = token
else:
raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/nlp/modules/common/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def get_nmt_tokenizer(
import omegaconf
from omegaconf import OmegaConf

if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
if isinstance(special_tokens, (omegaconf.listconfig.ListConfig, omegaconf.dictconfig.DictConfig)):
special_tokens = OmegaConf.to_container(special_tokens)
if special_tokens is None:
special_tokens_dict = {}
Expand Down
78 changes: 63 additions & 15 deletions tests/collections/nlp/test_tokenizer_with_special_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
# Sentinel tokens (BOS/EOS) followed by the Mistral-style chat-template
# control markers; registration order determines their ids (starting at 1).
_SENTINELS = ['<s>', '</s>']
_CHAT_MARKERS = [
    '[INST]',
    '[/INST]',
    '[TOOL_CALLS]',
    '[AVAILABLE_TOOLS]',
    '[/AVAILABLE_TOOLS]',
    '[TOOL_RESULTS]',
    '[/TOOL_RESULTS]',
]
SPECIAL_TOKENS = _SENTINELS + _CHAT_MARKERS


def test_spm_with_special_tokens() -> None:
special_tokens = [
'<s>',
'</s>',
'[INST]',
'[/INST]',
'[TOOL_CALLS]',
'[AVAILABLE_TOOLS]',
'[/AVAILABLE_TOOLS]',
'[TOOL_RESULTS]',
'[/TOOL_RESULTS]',
]
def _build_tokenizer(spm_file, special_tokens):
tokenizer_cfg = {
"library": "sentencepiece",
"type": None,
Expand All @@ -39,18 +39,66 @@ def test_spm_with_special_tokens() -> None:
"sentencepiece_legacy": True,
"special_tokens": special_tokens,
}
tokenizer = get_nmt_tokenizer(
return get_nmt_tokenizer(
library=tokenizer_cfg['library'],
model_name=tokenizer_cfg.get("type", None),
use_fast=tokenizer_cfg.get("use_fast", False),
delimiter=tokenizer_cfg.get("delimiter", None),
special_tokens=tokenizer_cfg.get("special_tokens", None),
trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
tokenizer_model=TOKENIZER_SPM_FILE,
tokenizer_model=spm_file,
legacy=True,
)


def test_spm_with_special_tokens() -> None:
    """Every configured special token is registered with a sequential id starting at 1."""
    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
    # '[INST]' is the third registered special token, hence id 3.
    assert tokenizer.text_to_ids('[INST]') == [3]
    for expected_id, special_token in enumerate(SPECIAL_TOKENS, start=1):
        assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
        assert tokenizer.special_token_to_id[special_token] == expected_id


def test_trim_spm_separator_after_special_token():
    """Toggling ``trim_spm_separator_after_special_token`` controls whether the
    SentencePiece separator piece inserted after a special token is dropped.

    BUG FIX: the original comparisons lacked ``assert`` — they were bare
    expressions, so this test could never fail regardless of tokenizer output.
    """
    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
    # Trimming enabled (constructor default): the separator id is removed,
    # leaving <s>=1, [INST]=3, '▁Who'=7294.
    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 7294]
    tokenizer.trim_spm_separator_after_special_token = False
    # Trimming disabled: the separator id (29473) is preserved after '[INST]'.
    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 29473, 7294]


def test_text_to_tokens_with_trim_spm_separator_after_special_token():
    """text_to_tokens drops the leading SPM separator piece right after each special token."""
    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
    text = (
        "<s>[INST] Who are you?[/INST] This is a response</s>"
        "[INST] I'll ask again who are you?[/INST] I'm not a who</s>"
    )
    # Expected pieces, grouped per chat-template segment; note no '▁' piece
    # directly follows any special token.
    expected = ['<s>', '[INST]', '▁Who', '▁are', '▁you', '?']
    expected += ['[/INST]', '▁This', '▁is', '▁a', '▁response', '</s>']
    expected += ['[INST]', '▁I', "'", 'll', '▁ask', '▁again', '▁who', '▁are', '▁you', '?']
    expected += ['[/INST]', '▁I', "'", 'm', '▁not', '▁a', '▁who', '</s>']
    assert tokenizer.text_to_tokens(text) == expected