1313# limitations under the License.
1414
1515import re
16+ import numpy as np
1617
1718import torch
1819import torch .nn .functional as F
@@ -33,19 +34,21 @@ class NeMoFWLMEval(LM):
3334 Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md
3435 """
3536
36- def __init__ (self , model_name , api_url , tokenizer , temperature , top_p , top_k , add_bos ):
37+ def __init__ (
38+ self , model_name , api_url , tokenizer , batch_size , max_tokens_to_generate , temperature , top_p , top_k , add_bos
39+ ):
3740 self .model_name = model_name
3841 self .api_url = api_url
3942 self .tokenizer = tokenizer
43+ self .batch_size = batch_size
44+ self .max_tokens_to_generate = max_tokens_to_generate
4045 self .temperature = temperature
4146 self .top_p = top_p
4247 self .top_k = top_k
4348 self .add_bos = add_bos
4449 super ().__init__ ()
4550
46- def _generate_tokens_logits (
47- self , payload , single_prediction_token , return_text : bool = False , return_logits : bool = False
48- ):
51+ def _generate_tokens_logits (self , payload , single_prediction_token : bool = False , return_logits : bool = False ):
4952 """
5053 A private method that sends post request to the model on PyTriton server and returns either generated text or
5154 logits.
@@ -54,12 +57,13 @@ def _generate_tokens_logits(
5457
5558 output_context_logits = False
5659 output_generation_logits = False
57- if single_prediction_token :
58- # In case of single token prediction return the generation logits
59- output_generation_logits = True
60- else :
61- # In case of multiple token prediction return the context logits
62- output_context_logits = True
60+ if return_logits : # in case of loglikelihood type tasks
61+ if single_prediction_token :
62+ # In case of single token prediction like mmlu return only the generation logits
63+ output_generation_logits = True
64+ else :
65+ # In case of multiple token prediction return the full context logits
66+ output_context_logits = True
6367 response = nq .query_llm (
6468 prompts = payload ['prompt' ] if isinstance (payload ['prompt' ], list ) else [payload ['prompt' ]],
6569 max_output_len = payload ['max_tokens' ],
@@ -71,13 +75,13 @@ def _generate_tokens_logits(
7175 openai_format_response = True ,
7276 )
7377
74- if return_text :
75- return response ["choices" ][0 ]["text" ] # shape[batch_size, 1]
76- elif return_logits :
78+ if return_logits : # loglikelihood type tasks, return just logits and not text
7779 if output_context_logits :
7880 return response ["choices" ][0 ]["context_logits" ]
7981 else :
8082 return response ["choices" ][0 ]["generation_logits" ]
83+ else : # generate_until type tasks, return just text and not logits
84+ return str (response ["choices" ][0 ]["text" ])
8185
8286 def tokenizer_type (self , tokenizer ):
8387 """
@@ -110,59 +114,90 @@ def loglikelihood(self, requests: list[Instance]):
110114 # Assuming evaluating on only one benchmark/task at a time, hence all instances in requests are of the same
111115 # task.
112116 mmlu_regex_pattern = r"^mmlu_"
113- lambada_regex_pattern = r"^lambada_"
114- if re .match (mmlu_regex_pattern , requests [0 ].task_name ) or re .match (
115- lambada_regex_pattern , requests [0 ].task_name
116- ):
117+ if re .match (mmlu_regex_pattern , requests [0 ].task_name ):
118+ # in case of mmlu the output token is one of 'a','b','c','d'
117119 single_prediction_token = True
118120
121+ # Hard code max_tokens_to_generate to 1 to always generate just 1 token in case of loglikelihood type tasks
122+ self .max_tokens_to_generate = 1
123+
119124 results = []
120- for request in tqdm (requests ):
121- # get the input prompt from the request
122- context = request .arguments [0 ]
123- # get the output prompt from the request
124- continuation = request .arguments [1 ]
125- # get encoded tokens of continuation
126- continuation_enc = self .tokenizer .tokenizer .encode (continuation , ** special_tokens_kwargs )
127- # for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space.
128- if self .tokenizer_type (self .tokenizer ) == "SentencePieceTokenizer" :
129- continuation_enc = continuation_enc [1 :]
130- num_cont_tokens = len (continuation_enc )
131- # Hard code max_tokens_to_generate to 1 to always generate just 1 token
132- self .max_tokens_to_generate = 1
133- # Delete the last token from continuation before passing it to the ip prompt by replacing with empty string
134- prompt = context + continuation .replace (self .tokenizer .tokenizer .decode (continuation_enc [- 1 ]), "" )
135- # Create payload to query the model deployed on PyTriton server
125+ for i in tqdm (range (0 , len (requests ), self .batch_size )):
126+ # Group requests into batches
127+ batch = requests [i : i + self .batch_size ]
128+ prompts = []
129+ continuations = []
130+ continuation_encs = []
131+ num_ctx_tokens_list = []
132+ num_cont_tokens_list = []
133+ # Prepare inputs for the batch
134+ for request in batch :
135+ # get the input prompt from the request
136+ context = request .arguments [0 ]
137+ # get the output prompt from the request
138+ continuation = request .arguments [1 ]
139+ # get encoded tokens of context
140+ context_enc = self .tokenizer .tokenizer .encode (context , ** special_tokens_kwargs )
141+ # get encoded tokens of continuation
142+ continuation_enc = self .tokenizer .tokenizer .encode (continuation , ** special_tokens_kwargs )
144+            # for SentencePiece consider the encoded tokens from the 2nd token since first encoded token is space.
144+ if self .tokenizer_type (self .tokenizer ) == "SentencePieceTokenizer" :
145+ context_enc = context_enc [1 :]
146+ continuation_enc = continuation_enc [1 :]
147+ num_ctx_tokens = len (context_enc )
148+ num_cont_tokens = len (continuation_enc )
149+ # Delete the last token from continuation before passing it to the ip prompt by replacing with empty
150+ # string
151+ prompt = context + continuation .replace (self .tokenizer .tokenizer .decode (continuation_enc [- 1 ]), "" )
152+
153+ prompts .append (prompt )
154+ continuations .append (continuation )
155+ continuation_encs .append (continuation_enc )
156+ num_ctx_tokens_list .append (num_ctx_tokens )
157+ num_cont_tokens_list .append (num_cont_tokens )
158+
159+ # Create a single payload for the entire batch
136160 payload = {
137161 "model" : self .model_name ,
138- "prompt" : prompt ,
162+ "prompt" : prompts ,
139163 "max_tokens" : self .max_tokens_to_generate ,
140164 "temperature" : self .temperature ,
141165 "top_p" : self .top_p ,
142166 "top_k" : self .top_k ,
143167 }
144- # Get the logits from the model
145- logits = self ._generate_tokens_logits (payload , single_prediction_token , return_logits = True )
146- # In case of multiple token prediction where full context logits are returned, get only logits
147- # corresponding to the continuation tokens from the context logits tensor.context_logits contains logits
148- # for all tokens in the ip prompt along with the logit for the next token prediction after the final token
149- # in the prompt. Shape of context_logits: [1, #tokens_in_prompt+1, vocab_size]
150- if not single_prediction_token :
151- logits = logits [:, - num_cont_tokens :, :]
152- # Convert logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
153- logProbs = F .log_softmax (torch .tensor (logits ), dim = - 1 )
154- # Convert encoded continuation tokens to torch tensor
155- cont_toks = torch .tensor (continuation_enc , dtype = torch .long ).unsqueeze (0 )
156- # Get the greedy token from the logits (i.e token with the highest prob)
157- greedy_tokens = logProbs .argmax (dim = - 1 )
158- # Check if all greedy_tokens match the the actual continuation tokens
159- is_greedy = (greedy_tokens == cont_toks ).all ()
160- # Get the logits corresponding to the actual continuation tokens
161- logProbs_actual = torch .gather (logProbs , 2 , cont_toks .unsqueeze (- 1 )).squeeze (- 1 )
162- # result is tuple of logProb of generating the continuation token and is_greedy
163- result = (float (logProbs_actual .sum ()), bool (is_greedy ))
164-
165- results .append (result )
168+
169+ # Query the model deployed on PyTriton server with the batched payload to get the logits
170+ logits_batch = self ._generate_tokens_logits (payload , single_prediction_token , return_logits = True )
171+
172+ # Process each result in the batch
173+ for j , logits in enumerate (logits_batch ):
174+ continuation_enc = continuation_encs [j ]
175+ num_ctx_tokens = num_ctx_tokens_list [j ]
176+ num_cont_tokens = num_cont_tokens_list [j ]
177+
178+ # In case of multiple token prediction where full context logits are returned (tasks other than mmlu),
179+ # get only logits corresponding to the continuation tokens from context logits tensor. context_logits
180+ # contains logits for all tokens in the ip prompt along with the logit for the next token prediction
181+ # after the final token in the prompt. Shape of context_logits: [1, #tokens_in_prompt+1, vocab_size].
182+ if not single_prediction_token :
183+ # Discard zero padding if any
184+ logits = logits [:, np .any (logits != 0 , axis = (0 , 2 )), :]
185+ # Get only logits corresponding to cont tokens
186+ logits = logits [:, - num_cont_tokens :, :]
187+ # Convert logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
188+ logProbs = F .log_softmax (torch .tensor (logits ), dim = - 1 )
189+ # Convert encoded continuation tokens to torch tensor
190+ cont_toks = torch .tensor (continuation_enc , dtype = torch .long ).unsqueeze (0 )
191+ # Get the greedy token from the logits (i.e token with the highest prob)
192+ greedy_tokens = logProbs .argmax (dim = - 1 )
193+                # Check if all greedy_tokens match the actual continuation tokens
194+ is_greedy = (greedy_tokens == cont_toks ).all ()
195+ # Get the logits corresponding to the actual continuation tokens
196+ logProbs_actual = torch .gather (logProbs , 2 , cont_toks .unsqueeze (- 1 )).squeeze (- 1 )
197+ # result is tuple of logProb of generating the continuation token and is_greedy
198+ result = (float (logProbs_actual .sum ()), bool (is_greedy ))
199+ # Append the result of this input in the batch to results list
200+ results .append (result )
166201
167202 return results
168203
@@ -179,7 +214,7 @@ def generate_until(self, inputs: list[Instance]):
179214 type(here loglikelihood) and other relevant args like few shot samples.
180215 """
181216 results = []
182- for instance in inputs :
217+ for instance in tqdm ( inputs ) :
183218 # Access the 'arguments' attribute of the Instance which contains the input prompt string
184219 prompt = instance .arguments [0 ]
185220 # Create payload to query the model deployed on PyTriton server
@@ -192,7 +227,7 @@ def generate_until(self, inputs: list[Instance]):
192227 "top_k" : self .top_k ,
193228 }
194229 # Get the text generated by the model
195- generated_text = self ._generate_tokens_logits (payload , return_text = True )
230+ generated_text = self ._generate_tokens_logits (payload )
196231
197232 results .append (generated_text )
198233
0 commit comments