@@ -171,15 +171,17 @@ def chat_completion(self, data):
171171 if OmegaConf .select (self .model .cfg , "data.chat_prompt_tokens" ) is not None :
172172 special_tokens = self .model .cfg .data .chat_prompt_tokens
173173 else :
174- #raise RuntimeError(
174+ # raise RuntimeError(
175175 # "You don't have a model (model_config.yaml) which has chat_prompt_tokens, are you sure this is a Chat/Instruction model?"
176- #)
176+ # )
177177 # (@adithyare) hacking in the special tokens to test non-chat models for debugging
178- special_tokens = {"system_turn_start" : "<SPECIAL_10>" ,
179- "turn_start" : "<SPECIAL_11>" ,
180- "label_start" : "<SPECIAL_12>" ,
181- "end_of_name" : "\n " ,
182- "end_of_turn" : "\n " }
178+ special_tokens = {
179+ "system_turn_start" : "<SPECIAL_10>" ,
180+ "turn_start" : "<SPECIAL_11>" ,
181+ "label_start" : "<SPECIAL_12>" ,
182+ "end_of_name" : "\n " ,
183+ "end_of_turn" : "\n " ,
184+ }
183185 nemo_source = self .convert_messages (data ['messages' ])
184186 header , conversation , data_type , mask_role = _get_header_conversation_type_mask_role (
185187 nemo_source , special_tokens
@@ -429,7 +431,7 @@ def put(self):
429431 # (@adithyare) resolves a json byte conversion issue (taken from chat_completeion)
430432 for i in range (len (output ['tokens' ])):
431433 tokens = output ['tokens' ][i ]
432- output ['tokens' ][i ] = [t .decode ('utf-8' , errors = 'replace' ) if isinstance (t , bytes ) else t for t in tokens ]
434+ output ['tokens' ][i ] = [t .decode ('utf-8' , errors = 'replace' ) if isinstance (t , bytes ) else t for t in tokens ]
433435
434436 if not all_probs :
435437 del output ['full_logprob' ]
0 commit comments