2121"""
2222
2323
24+ import hashlib
2425import json
2526import os
27+ import re
2628from argparse import ArgumentParser
2729from collections import OrderedDict
2830from pathlib import Path
3133import torch .nn
3234from lightning .pytorch .core .saving import _load_state as ptl_load_state
3335from lightning .pytorch .trainer .trainer import Trainer
34- from omegaconf import OmegaConf
36+ from omegaconf import OmegaConf , open_dict
3537from transformers import AutoModelForCausalLM , AutoTokenizer
3638
3739from nemo .collections .nlp .models .language_modeling .megatron_gpt_model import MegatronGPTModel
@@ -58,6 +60,7 @@ def get_args():
5860 parser .add_argument ("--output_path" , type = str , default = None , required = True , help = "Path to output .nemo file." )
5961 parser .add_argument ("--precision" , type = str , default = "bf16" , help = "Model precision" )
6062 parser .add_argument ('--low-ram' , '--low-mem' , action = 'store_true' , dest = 'low_ram' )
63+ parser .add_argument ('--add-additional-tokens' , action = 'store_true' )
6164 parser .add_argument ('--tmp-dir' , default = '/tmp/mistral_ckpt_parts/' )
6265 args = parser .parse_args ()
6366 return args
@@ -430,6 +433,16 @@ def merge(a: dict, b: dict, path=[]):
430433 return a
431434
432435
def md5_checksum(filepath):
    """Return the MD5 hex digest of the file at ``filepath``.

    The file is streamed in fixed-size chunks so arbitrarily large files
    (e.g. tokenizer vocab files) can be fingerprinted without loading them
    fully into memory.  MD5 is used here purely as a content fingerprint to
    recognize known files, not for any security purpose.

    Args:
        filepath: Path to the file to hash, or ``None``.

    Returns:
        The 32-character hexadecimal MD5 digest as a ``str``, or ``None``
        when ``filepath`` is ``None`` (lets callers pass through optional
        attributes such as ``getattr(tokenizer, 'vocab_file', None)``).
    """
    if filepath is None:
        return None
    hash_md5 = hashlib.md5()
    with open(filepath, "rb") as f:
        # 64 KiB chunks produce the same digest as any other chunk size but
        # need far fewer read() calls than the previous 4 KiB loop.
        for chunk in iter(lambda: f.read(65536), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
445+
433446def save_to_nemo (args , checkpoint ):
434447 """saves checkpoint to nemo format"""
435448
@@ -464,15 +477,42 @@ def save_to_nemo(args, checkpoint):
464477 # disable cpu init
465478 model .cfg .use_cpu_initialization = False
466479 model .cfg .perform_initialization = True
480+ # If user has passed --add-additional-tokens or model is mistralai/Mistral-7B-Instruct-v0.3
481+ if (
482+ args .add_additional_tokens
483+ or md5_checksum (getattr (tokenizer , 'vocab_file' , None )) == '2bbc01eba250283314fdbd53d05de94b'
484+ ):
485+
486+ def make_token_name (token ):
487+ prefix = ''
488+ if len (token ) > 1 and token [1 ] == '/' :
489+ prefix = 'eos_'
490+ else :
491+ prefix = 'bos_'
492+ return prefix + re .sub (r'\W' , '_' , token )
493+
494+ if len (tokenizer .added_tokens_decoder ) > 0 :
495+ with open_dict (model .cfg .tokenizer ):
496+ model .cfg .tokenizer .sentencepiece_legacy = True
497+ model .cfg .tokenizer .special_tokens = {}
498+ model .cfg .tokenizer .special_tokens ['bos_token' ] = tokenizer .bos_token or "<s>"
499+ model .cfg .tokenizer .special_tokens ['eos_token' ] = tokenizer .eos_token or "</s>"
500+ model .cfg .tokenizer .special_tokens ['pad_token' ] = tokenizer .pad_token or "<pad>"
501+ skip_tokens = set (model .cfg .tokenizer .special_tokens .values ())
502+ skip_tokens .add ('<unk>' )
503+ for token_id , token in tokenizer .added_tokens_decoder .items ():
504+ token_name = make_token_name (token .content )
505+ if token .content in skip_tokens :
506+ continue
507+ assert not token_name in model .cfg .tokenizer .special_tokens
508+ model .cfg .tokenizer .special_tokens [token_name ] = token .content
509+
467510 if getattr (tokenizer , 'chat_template' , None ) is not None :
468- import hashlib
469511
470512 template_hash = hashlib .md5 (tokenizer .chat_template .encode ('utf-8' )).hexdigest ()
471513 if template_hash != "0b629f783db54e02509999196956ff40" :
472514 logging .warning ("Got unkown chat template" )
473515 else :
474- from omegaconf import OmegaConf , open_dict
475-
476516 with open_dict (model .cfg ):
477517 model .cfg .tokenizer .chat_template = OmegaConf .create (
478518 {
0 commit comments