diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
index 45fbd4c8b328..ac22dbbbd7bc 100644
--- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
+++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -51,6 +51,8 @@ def __init__(
         legacy: bool = False,
         ignore_extra_whitespaces: bool = True,
         chat_template: Optional[Dict] = None,
+        trim_spm_separator_after_special_token=True,
+        spm_separator='▁',
     ):
         self.chat_template = chat_template
         if not model_path or not os.path.exists(model_path):
@@ -66,6 +68,10 @@ def __init__(
         self.extra_space_token = '☯'
         self.special_token_to_id = {}
         self.id_to_special_token = {}
+        self.trim_spm_separator_after_special_token = trim_spm_separator_after_special_token
+        self.spm_separator_id = self.tokenizer.piece_to_id(spm_separator)
+        self.spm_separator = spm_separator
+
         if special_tokens:
             if not self.legacy:
                 raise ValueError(
@@ -99,7 +105,19 @@ def text_to_tokens(self, text):
                 next_token = min(indices, key=indices.get)
                 next_idx = idx + indices[next_token]
 
-                tokens.extend(self.tokenizer.encode_as_pieces(text[idx:next_idx]))
+                tok = self.tokenizer.encode_as_pieces(text[idx:next_idx])
+                # Chat templates insert a space between a special token and the first
+                # word (e.g. "[INST] who"), which is tokenized as ['▁', '▁who'] instead
+                # of ['▁who'].
+                if (
+                    self.trim_spm_separator_after_special_token
+                    and len(tokens) > 0
+                    and tokens[-1] in self.special_token_to_id
+                    and len(tok) > 0
+                    and tok[0] == self.spm_separator
+                ):
+                    tok.pop(0)
+                tokens.extend(tok)
                 tokens.append(next_token)
 
                 idx = next_idx + len(next_token)
@@ -143,7 +161,19 @@ def _text_to_ids(self, text, sample_alpha=None):
                 next_token = min(indices, key=indices.get)
                 next_idx = idx + indices[next_token]
 
-                ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
+                text_tokens = self.tokenizer.encode(text[idx:next_idx])
+                # Chat templates insert a space between a special token and the first
+                # word (e.g. "[INST] who"), which is tokenized as ['▁', '▁who'] instead
+                # of ['▁who'].
+                if (
+                    self.trim_spm_separator_after_special_token
+                    and len(ids) > 0
+                    and ids[-1] in self.id_to_special_token
+                    and len(text_tokens) > 0
+                    and text_tokens[0] == self.spm_separator_id
+                ):
+                    text_tokens.pop(0)
+                ids.extend(text_tokens)
                 ids.append(self.special_token_to_id[next_token])
 
                 idx = next_idx + len(next_token)
@@ -239,6 +269,7 @@ def add_special_tokens(self, special_tokens):
                     self.vocab_size += 1
                 elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
                     self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
+                    self.id_to_special_token[self.special_token_to_id[token]] = token
 
         elif isinstance(special_tokens, dict):
             for token_name, token in special_tokens.items():
@@ -250,6 +281,9 @@
                     self.special_token_to_id[token] = self.vocab_size
                     self.id_to_special_token[self.vocab_size] = token
                     self.vocab_size += 1
+                elif self.tokenizer.piece_to_id(token) != self.tokenizer.unk_id():
+                    self.special_token_to_id[token] = self.tokenizer.piece_to_id(token)
+                    self.id_to_special_token[self.special_token_to_id[token]] = token
         else:
             raise ValueError("Expected special_tokens to be a list or a dict " + str(type(special_tokens)))
diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py
index 08c39b5a67cf..21d04887ea6b 100644
--- a/nemo/collections/nlp/modules/common/tokenizer_utils.py
+++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -173,7 +173,7 @@ def get_nmt_tokenizer(
     import omegaconf
     from omegaconf import OmegaConf
 
-    if isinstance(special_tokens, omegaconf.listconfig.ListConfig):
+    if isinstance(special_tokens, (omegaconf.listconfig.ListConfig, omegaconf.dictconfig.DictConfig)):
         special_tokens = OmegaConf.to_container(special_tokens)
     if special_tokens is None:
         special_tokens_dict = {}
diff --git a/tests/collections/nlp/test_tokenizer_with_special_tokens.py b/tests/collections/nlp/test_tokenizer_with_special_tokens.py
index d042231f6670..87460b1af316 100644
--- a/tests/collections/nlp/test_tokenizer_with_special_tokens.py
+++ b/tests/collections/nlp/test_tokenizer_with_special_tokens.py
@@ -16,20 +16,20 @@
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 
 TOKENIZER_SPM_FILE = '/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model'
+SPECIAL_TOKENS = [
+    '<s>',
+    '</s>',
+    '[INST]',
+    '[/INST]',
+    '[TOOL_CALLS]',
+    '[AVAILABLE_TOOLS]',
+    '[/AVAILABLE_TOOLS]',
+    '[TOOL_RESULTS]',
+    '[/TOOL_RESULTS]',
+]
 
 
-def test_spm_with_special_tokens() -> None:
-    special_tokens = [
-        '<s>',
-        '</s>',
-        '[INST]',
-        '[/INST]',
-        '[TOOL_CALLS]',
-        '[AVAILABLE_TOOLS]',
-        '[/AVAILABLE_TOOLS]',
-        '[TOOL_RESULTS]',
-        '[/TOOL_RESULTS]',
-    ]
+def _build_tokenizer(spm_file, special_tokens):
     tokenizer_cfg = {
         "library": "sentencepiece",
         "type": None,
@@ -39,18 +39,66 @@ def test_spm_with_special_tokens() -> None:
         "sentencepiece_legacy": True,
         "special_tokens": special_tokens,
     }
-    tokenizer = get_nmt_tokenizer(
+    return get_nmt_tokenizer(
         library=tokenizer_cfg['library'],
         model_name=tokenizer_cfg.get("type", None),
         use_fast=tokenizer_cfg.get("use_fast", False),
         delimiter=tokenizer_cfg.get("delimiter", None),
         special_tokens=tokenizer_cfg.get("special_tokens", None),
         trust_remote_code=tokenizer_cfg.get("trust_remote_code", False),
-        tokenizer_model=TOKENIZER_SPM_FILE,
+        tokenizer_model=spm_file,
         legacy=True,
     )
 
+
+def test_spm_with_special_tokens() -> None:
+    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
     assert tokenizer.text_to_ids('[INST]') == [3]
-    for i, special_token in enumerate(special_tokens):
+    for i, special_token in enumerate(SPECIAL_TOKENS):
         assert special_token in tokenizer.special_token_to_id, f'Expected {special_token} to be a special token'
         assert tokenizer.special_token_to_id[special_token] == i + 1
+
+
+def test_trim_spm_separator_after_special_token():
+    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
+    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 7294]
+    tokenizer.trim_spm_separator_after_special_token = False
+    assert tokenizer.text_to_ids('<s>[INST] Who') == [1, 3, 29473, 7294]
+
+
+def test_text_to_tokens_with_trim_spm_separator_after_special_token():
+    tokenizer = _build_tokenizer(TOKENIZER_SPM_FILE, SPECIAL_TOKENS)
+    text = "<s>[INST] Who are you?[/INST] This is a response</s>[INST] I'll ask again who are you?[/INST] I'm not a who</s>"
+    tokenized = tokenizer.text_to_tokens(text)
+    assert tokenized == [
+        '<s>',
+        '[INST]',
+        '▁Who',
+        '▁are',
+        '▁you',
+        '?',
+        '[/INST]',
+        '▁This',
+        '▁is',
+        '▁a',
+        '▁response',
+        '</s>',
+        '[INST]',
+        '▁I',
+        "'",
+        'll',
+        '▁ask',
+        '▁again',
+        '▁who',
+        '▁are',
+        '▁you',
+        '?',
+        '[/INST]',
+        '▁I',
+        "'",
+        'm',
+        '▁not',
+        '▁a',
+        '▁who',
+        '</s>',
+    ]
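
A minimal usage sketch of the new flag, for reviewers. The constructor arguments follow the diff above; the model path and the exact ids (1, 3, 7294, 29473) are taken from the test fixture and will differ for other `.model` files:

```python
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

# Special-token handling only exists on the legacy path, hence legacy=True.
tokenizer = SentencePieceTokenizer(
    model_path='/home/TestData/nlp/tokenizer_with_special_tokens/tokenizer.model',
    special_tokens=['<s>', '</s>', '[INST]', '[/INST]'],
    legacy=True,
)

# Default behaviour (trim_spm_separator_after_special_token=True): the lone '▁'
# piece that SentencePiece emits between a special token and the next word is
# dropped, so the ids match a template-free encoding of the word.
print(tokenizer.text_to_ids('<s>[INST] Who'))  # [1, 3, 7294] on the test model

# Opting out keeps the separator piece (29473 == piece_to_id('▁') here).
tokenizer.trim_spm_separator_after_special_token = False
print(tokenizer.text_to_ids('<s>[INST] Who'))  # [1, 3, 29473, 7294]
```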