Skip to content

Commit 8b4fb4f

Browse files
akoumpa and abhinavg4
authored and committed
handle mistralai/Mistral-7B-Instruct-v0.3 tokenizer correctly (#11839)
* handle mistralai/Mistral-7B-Instruct-v0.3 tokenizer correctly Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * remove manual token addition Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * Apply isort and black reformatting Signed-off-by: akoumpa <akoumpa@users.noreply.github.com> --------- Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> Signed-off-by: akoumpa <akoumpa@users.noreply.github.com> Co-authored-by: akoumpa <akoumpa@users.noreply.github.com> Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
1 parent 0e48dbc commit 8b4fb4f

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
"""
2222

2323

24+
import hashlib
2425
import json
2526
import os
27+
import re
2628
from argparse import ArgumentParser
2729
from collections import OrderedDict
2830
from pathlib import Path
@@ -31,7 +33,7 @@
3133
import torch.nn
3234
from lightning.pytorch.core.saving import _load_state as ptl_load_state
3335
from lightning.pytorch.trainer.trainer import Trainer
34-
from omegaconf import OmegaConf
36+
from omegaconf import OmegaConf, open_dict
3537
from transformers import AutoModelForCausalLM, AutoTokenizer
3638

3739
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
@@ -58,6 +60,7 @@ def get_args():
5860
parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
5961
parser.add_argument("--precision", type=str, default="bf16", help="Model precision")
6062
parser.add_argument('--low-ram', '--low-mem', action='store_true', dest='low_ram')
63+
parser.add_argument('--add-additional-tokens', action='store_true')
6164
parser.add_argument('--tmp-dir', default='/tmp/mistral_ckpt_parts/')
6265
args = parser.parse_args()
6366
return args
@@ -430,6 +433,16 @@ def merge(a: dict, b: dict, path=[]):
430433
return a
431434

432435

def md5_checksum(filepath):
    """Return the hexadecimal MD5 digest of the file at *filepath*.

    The file is read in 4 KiB chunks so arbitrarily large files can be
    hashed without loading them into memory. Returns None when
    *filepath* is None (e.g. the tokenizer has no vocab file).
    """
    if filepath is None:
        return None
    digest = hashlib.md5()
    with open(filepath, "rb") as fh:
        chunk = fh.read(4096)
        while chunk:
            digest.update(chunk)
            chunk = fh.read(4096)
    return digest.hexdigest()
444+
445+
433446
def save_to_nemo(args, checkpoint):
434447
"""saves checkpoint to nemo format"""
435448

@@ -464,15 +477,42 @@ def save_to_nemo(args, checkpoint):
464477
# disable cpu init
465478
model.cfg.use_cpu_initialization = False
466479
model.cfg.perform_initialization = True
480+
# If user has passed --add-additional-tokens or model is mistralai/Mistral-7B-Instruct-v0.3
481+
if (
482+
args.add_additional_tokens
483+
or md5_checksum(getattr(tokenizer, 'vocab_file', None)) == '2bbc01eba250283314fdbd53d05de94b'
484+
):
485+
def make_token_name(token):
    """Build a config-key name for a special token.

    Tokens whose second character is '/' (closing-tag style, e.g.
    '</s>') get an 'eos_' prefix; everything else gets 'bos_'. All
    non-word characters in the token are replaced with underscores so
    the result is a valid identifier-like key.
    """
    kind = 'eos_' if len(token) > 1 and token[1] == '/' else 'bos_'
    return kind + re.sub(r'\W', '_', token)
493+
494+
if len(tokenizer.added_tokens_decoder) > 0:
495+
with open_dict(model.cfg.tokenizer):
496+
model.cfg.tokenizer.sentencepiece_legacy = True
497+
model.cfg.tokenizer.special_tokens = {}
498+
model.cfg.tokenizer.special_tokens['bos_token'] = tokenizer.bos_token or "<s>"
499+
model.cfg.tokenizer.special_tokens['eos_token'] = tokenizer.eos_token or "</s>"
500+
model.cfg.tokenizer.special_tokens['pad_token'] = tokenizer.pad_token or "<pad>"
501+
skip_tokens = set(model.cfg.tokenizer.special_tokens.values())
502+
skip_tokens.add('<unk>')
503+
for token_id, token in tokenizer.added_tokens_decoder.items():
504+
token_name = make_token_name(token.content)
505+
if token.content in skip_tokens:
506+
continue
507+
assert not token_name in model.cfg.tokenizer.special_tokens
508+
model.cfg.tokenizer.special_tokens[token_name] = token.content
509+
467510
if getattr(tokenizer, 'chat_template', None) is not None:
468-
import hashlib
469511

470512
template_hash = hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest()
471513
if template_hash != "0b629f783db54e02509999196956ff40":
472514
logging.warning("Got unkown chat template")
473515
else:
474-
from omegaconf import OmegaConf, open_dict
475-
476516
with open_dict(model.cfg):
477517
model.cfg.tokenizer.chat_template = OmegaConf.create(
478518
{

0 commit comments

Comments
 (0)