Skip to content

Commit 3cf4319

Browse files
athitten
authored, and youngeunkwon0405 committed
Add batching support for evaluation (NVIDIA-NeMo#11934)
* Add server ready check before evaluation Uses bool generation_logits_available as inputs dict does not contain it Signed-off-by: Abhishree <abhishreetm@gmail.com> * Apply isort and black reformatting Signed-off-by: athitten <athitten@users.noreply.github.com> * Add batching changes Signed-off-by: Abhishree <abhishreetm@gmail.com> * Discard 0 padding with batching and other minor edits Signed-off-by: Abhishree <abhishreetm@gmail.com> * Add func for padding and minor edits Signed-off-by: Abhishree <abhishreetm@gmail.com> * Remove commented code and Pylint fixes Signed-off-by: Abhishree <abhishreetm@gmail.com> * Apply isort and black reformatting Signed-off-by: athitten <athitten@users.noreply.github.com> --------- Signed-off-by: Abhishree <abhishreetm@gmail.com> Signed-off-by: athitten <athitten@users.noreply.github.com> Co-authored-by: athitten <athitten@users.noreply.github.com> Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
1 parent b54a689 commit 3cf4319

File tree

3 files changed

+127
-74
lines changed

3 files changed

+127
-74
lines changed

nemo/collections/llm/api.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,6 @@ def deploy(
387387

388388
unset_environment_variables()
389389

390-
if not isinstance(nemo_checkpoint, Path):
391-
nemo_checkpoint = Path(nemo_checkpoint)
392-
if not isinstance(triton_model_repository, Path):
393-
triton_model_repository = Path(triton_model_repository)
394-
395390
triton_deployable = get_trtllm_deployable(
396391
nemo_checkpoint,
397392
model_type,
@@ -446,6 +441,8 @@ def evaluate(
446441
limit: Optional[Union[int, float]] = None,
447442
bootstrap_iters: int = 100000,
448443
# inference params
444+
batch_size: Optional[int] = 1,
445+
max_tokens_to_generate: Optional[int] = 256,
449446
temperature: Optional[float] = 0.000000001,
450447
top_p: Optional[float] = 0.0,
451448
top_k: Optional[int] = 1,
@@ -495,15 +492,14 @@ def evaluate(
495492

496493
from nemo.collections.llm import evaluation
497494

498-
if not isinstance(nemo_checkpoint_path, Path):
499-
nemo_checkpoint_path = Path(nemo_checkpoint_path)
500-
501495
# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
502496
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
503497
# Wait for server to be ready before starting evaluation
504498
evaluation.wait_for_server_ready(url=url, triton_http_port=triton_http_port, model_name=model_name)
505499
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
506-
model = evaluation.NeMoFWLMEval(model_name, url, tokenizer, temperature, top_p, top_k, add_bos)
500+
model = evaluation.NeMoFWLMEval(
501+
model_name, url, tokenizer, batch_size, max_tokens_to_generate, temperature, top_p, top_k, add_bos
502+
)
507503
results = evaluator.simple_evaluate(
508504
model=model,
509505
tasks=eval_task,

nemo/collections/llm/evaluation/base.py

Lines changed: 93 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import re
16+
import numpy as np
1617

1718
import torch
1819
import torch.nn.functional as F
@@ -33,19 +34,21 @@ class NeMoFWLMEval(LM):
3334
Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md
3435
"""
3536

36-
def __init__(self, model_name, api_url, tokenizer, temperature, top_p, top_k, add_bos):
37+
def __init__(
38+
self, model_name, api_url, tokenizer, batch_size, max_tokens_to_generate, temperature, top_p, top_k, add_bos
39+
):
3740
self.model_name = model_name
3841
self.api_url = api_url
3942
self.tokenizer = tokenizer
43+
self.batch_size = batch_size
44+
self.max_tokens_to_generate = max_tokens_to_generate
4045
self.temperature = temperature
4146
self.top_p = top_p
4247
self.top_k = top_k
4348
self.add_bos = add_bos
4449
super().__init__()
4550

46-
def _generate_tokens_logits(
47-
self, payload, single_prediction_token, return_text: bool = False, return_logits: bool = False
48-
):
51+
def _generate_tokens_logits(self, payload, single_prediction_token: bool = False, return_logits: bool = False):
4952
"""
5053
A private method that sends post request to the model on PyTriton server and returns either generated text or
5154
logits.
@@ -54,12 +57,13 @@ def _generate_tokens_logits(
5457

5558
output_context_logits = False
5659
output_generation_logits = False
57-
if single_prediction_token:
58-
# In case of single token prediction return the generation logits
59-
output_generation_logits = True
60-
else:
61-
# In case of multiple token prediction return the context logits
62-
output_context_logits = True
60+
if return_logits: # in case of loglikelihood type tasks
61+
if single_prediction_token:
62+
# In case of single token prediction like mmlu return only the generation logits
63+
output_generation_logits = True
64+
else:
65+
# In case of multiple token prediction return the full context logits
66+
output_context_logits = True
6367
response = nq.query_llm(
6468
prompts=payload['prompt'] if isinstance(payload['prompt'], list) else [payload['prompt']],
6569
max_output_len=payload['max_tokens'],
@@ -71,13 +75,13 @@ def _generate_tokens_logits(
7175
openai_format_response=True,
7276
)
7377

74-
if return_text:
75-
return response["choices"][0]["text"] # shape[batch_size, 1]
76-
elif return_logits:
78+
if return_logits: # loglikelihood type tasks, return just logits and not text
7779
if output_context_logits:
7880
return response["choices"][0]["context_logits"]
7981
else:
8082
return response["choices"][0]["generation_logits"]
83+
else: # generate_until type tasks, return just text and not logits
84+
return str(response["choices"][0]["text"])
8185

8286
def tokenizer_type(self, tokenizer):
8387
"""
@@ -110,59 +114,90 @@ def loglikelihood(self, requests: list[Instance]):
110114
# Assuming evaluating on only one benchmark/task at a time, hence all instances in requests are of the same
111115
# task.
112116
mmlu_regex_pattern = r"^mmlu_"
113-
lambada_regex_pattern = r"^lambada_"
114-
if re.match(mmlu_regex_pattern, requests[0].task_name) or re.match(
115-
lambada_regex_pattern, requests[0].task_name
116-
):
117+
if re.match(mmlu_regex_pattern, requests[0].task_name):
118+
# in case of mmlu the output token is one of 'a','b','c','d'
117119
single_prediction_token = True
118120

121+
# Hard code max_tokens_to_generate to 1 to always generate just 1 token in case of loglikelihood type tasks
122+
self.max_tokens_to_generate = 1
123+
119124
results = []
120-
for request in tqdm(requests):
121-
# get the input prompt from the request
122-
context = request.arguments[0]
123-
# get the output prompt from the request
124-
continuation = request.arguments[1]
125-
# get encoded tokens of continuation
126-
continuation_enc = self.tokenizer.tokenizer.encode(continuation, **special_tokens_kwargs)
127-
# for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space.
128-
if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer":
129-
continuation_enc = continuation_enc[1:]
130-
num_cont_tokens = len(continuation_enc)
131-
# Hard code max_tokens_to_generate to 1 to always generate just 1 token
132-
self.max_tokens_to_generate = 1
133-
# Delete the last token from continuation before passing it to the ip prompt by replacing with empty string
134-
prompt = context + continuation.replace(self.tokenizer.tokenizer.decode(continuation_enc[-1]), "")
135-
# Create payload to query the model deployed on PyTriton server
125+
for i in tqdm(range(0, len(requests), self.batch_size)):
126+
# Group requests into batches
127+
batch = requests[i : i + self.batch_size]
128+
prompts = []
129+
continuations = []
130+
continuation_encs = []
131+
num_ctx_tokens_list = []
132+
num_cont_tokens_list = []
133+
# Prepare inputs for the batch
134+
for request in batch:
135+
# get the input prompt from the request
136+
context = request.arguments[0]
137+
# get the output prompt from the request
138+
continuation = request.arguments[1]
139+
# get encoded tokens of context
140+
context_enc = self.tokenizer.tokenizer.encode(context, **special_tokens_kwargs)
141+
# get encoded tokens of continuation
142+
continuation_enc = self.tokenizer.tokenizer.encode(continuation, **special_tokens_kwargs)
143+
# for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space.
144+
if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer":
145+
context_enc = context_enc[1:]
146+
continuation_enc = continuation_enc[1:]
147+
num_ctx_tokens = len(context_enc)
148+
num_cont_tokens = len(continuation_enc)
149+
# Delete the last token from continuation before passing it to the ip prompt by replacing with empty
150+
# string
151+
prompt = context + continuation.replace(self.tokenizer.tokenizer.decode(continuation_enc[-1]), "")
152+
153+
prompts.append(prompt)
154+
continuations.append(continuation)
155+
continuation_encs.append(continuation_enc)
156+
num_ctx_tokens_list.append(num_ctx_tokens)
157+
num_cont_tokens_list.append(num_cont_tokens)
158+
159+
# Create a single payload for the entire batch
136160
payload = {
137161
"model": self.model_name,
138-
"prompt": prompt,
162+
"prompt": prompts,
139163
"max_tokens": self.max_tokens_to_generate,
140164
"temperature": self.temperature,
141165
"top_p": self.top_p,
142166
"top_k": self.top_k,
143167
}
144-
# Get the logits from the model
145-
logits = self._generate_tokens_logits(payload, single_prediction_token, return_logits=True)
146-
# In case of multiple token prediction where full context logits are returned, get only logits
147-
# corresponding to the continuation tokens from the context logits tensor.context_logits contains logits
148-
# for all tokens in the ip prompt along with the logit for the next token prediction after the final token
149-
# in the prompt. Shape of context_logits: [1, #tokens_in_prompt+1, vocab_size]
150-
if not single_prediction_token:
151-
logits = logits[:, -num_cont_tokens:, :]
152-
# Convert logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
153-
logProbs = F.log_softmax(torch.tensor(logits), dim=-1)
154-
# Convert encoded continuation tokens to torch tensor
155-
cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
156-
# Get the greedy token from the logits (i.e token with the highest prob)
157-
greedy_tokens = logProbs.argmax(dim=-1)
158-
# Check if all greedy_tokens match the the actual continuation tokens
159-
is_greedy = (greedy_tokens == cont_toks).all()
160-
# Get the logits corresponding to the actual continuation tokens
161-
logProbs_actual = torch.gather(logProbs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
162-
# result is tuple of logProb of generating the continuation token and is_greedy
163-
result = (float(logProbs_actual.sum()), bool(is_greedy))
164-
165-
results.append(result)
168+
169+
# Query the model deployed on PyTriton server with the batched payload to get the logits
170+
logits_batch = self._generate_tokens_logits(payload, single_prediction_token, return_logits=True)
171+
172+
# Process each result in the batch
173+
for j, logits in enumerate(logits_batch):
174+
continuation_enc = continuation_encs[j]
175+
num_ctx_tokens = num_ctx_tokens_list[j]
176+
num_cont_tokens = num_cont_tokens_list[j]
177+
178+
# In case of multiple token prediction where full context logits are returned (tasks other than mmlu),
179+
# get only logits corresponding to the continuation tokens from context logits tensor. context_logits
180+
# contains logits for all tokens in the ip prompt along with the logit for the next token prediction
181+
# after the final token in the prompt. Shape of context_logits: [1, #tokens_in_prompt+1, vocab_size].
182+
if not single_prediction_token:
183+
# Discard zero padding if any
184+
logits = logits[:, np.any(logits != 0, axis=(0, 2)), :]
185+
# Get only logits corresponding to cont tokens
186+
logits = logits[:, -num_cont_tokens:, :]
187+
# Convert logits to torch tensor to easily get logprobs wo manual implementation of log_softmax
188+
logProbs = F.log_softmax(torch.tensor(logits), dim=-1)
189+
# Convert encoded continuation tokens to torch tensor
190+
cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0)
191+
# Get the greedy token from the logits (i.e token with the highest prob)
192+
greedy_tokens = logProbs.argmax(dim=-1)
193+
# Check if all greedy_tokens match the the actual continuation tokens
194+
is_greedy = (greedy_tokens == cont_toks).all()
195+
# Get the logits corresponding to the actual continuation tokens
196+
logProbs_actual = torch.gather(logProbs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
197+
# result is tuple of logProb of generating the continuation token and is_greedy
198+
result = (float(logProbs_actual.sum()), bool(is_greedy))
199+
# Append the result of this input in the batch to results list
200+
results.append(result)
166201

167202
return results
168203

@@ -179,7 +214,7 @@ def generate_until(self, inputs: list[Instance]):
179214
type(here loglikelihood) and other relevant args like few shot samples.
180215
"""
181216
results = []
182-
for instance in inputs:
217+
for instance in tqdm(inputs):
183218
# Access the 'arguments' attribute of the Instance which contains the input prompt string
184219
prompt = instance.arguments[0]
185220
# Create payload to query the model deployed on PyTriton server
@@ -192,7 +227,7 @@ def generate_until(self, inputs: list[Instance]):
192227
"top_k": self.top_k,
193228
}
194229
# Get the text generated by the model
195-
generated_text = self._generate_tokens_logits(payload, return_text=True)
230+
generated_text = self._generate_tokens_logits(payload)
196231

197232
results.append(generated_text)
198233

nemo/export/tensorrt_llm.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import safetensors
2828
import tensorrt_llm
2929
import torch
30+
import torch.nn.functional as F
3031
import wrapt
3132
from tensorrt_llm._utils import numpy_to_torch
3233

@@ -1117,6 +1118,19 @@ def remove_prompt_table(self, task_name: str):
11171118
return
11181119
self._prep_ptuning_table()
11191120

1121+
def _pad_logits(self, logits_tensor):
1122+
"""
1123+
Pads the logits tensor with 0's on the right
1124+
"""
1125+
padding_len = max([logit_tensor.shape[0] for logit_tensor in logits_tensor])
1126+
for i, tensor in enumerate(logits_tensor):
1127+
tensor_len = tensor.shape[0]
1128+
if tensor_len < padding_len:
1129+
padding_diff = padding_len - tensor_len
1130+
# padding_diff num of rows of zeros are added at the bottom
1131+
logits_tensor[i] = F.pad(tensor, (0, 0, 0, padding_diff), mode='constant', value=0)
1132+
return logits_tensor
1133+
11201134
@property
11211135
def get_supported_models_list(self):
11221136
"""Supported model list"""
@@ -1200,16 +1214,24 @@ def triton_infer_fn(self, **inputs: np.ndarray):
12001214
infer_input["output_context_logits"] = inputs.pop("output_context_logits")[0][0]
12011215

12021216
if generation_logits_available:
1217+
# generation_logits is a 4d torch tensor of dim [BS,1,#generated_tokens,vocab_size]
12031218
output_texts, generation_logits = self.forward(**infer_input)
1204-
# generation_logits is a 4d tensor of dim [1,1,#generated_tokens, vocab_size], return just the 3d tensor
1205-
# in output dict.
1206-
output_dict["generation_logits"] = np.array(generation_logits[0].cpu().numpy())
1219+
# convert generation_logits to numpy array. Note: from my understanding since generation_logits is
1220+
# returned as a torch tensor it won't have varying number of tokens across multiple sequences,
1221+
# likely due to TRTLLM taking care of padding hence no addtnl padding is needed.
1222+
output_dict["generation_logits"] = np.array(
1223+
[generation_logit.cpu().numpy() for generation_logit in generation_logits]
1224+
)
1225+
12071226
elif context_logits_available:
12081227
output_texts, context_logits = self.forward(**infer_input)
1209-
# convert context logits to 3d tensor from list since its avaiable as a list of tensor shaped
1210-
# [#tokens, vocab_size]
1211-
context_logits = context_logits[0].unsqueeze(0)
1212-
output_dict["context_logits"] = np.array(context_logits.cpu().numpy())
1228+
# context_logits is a list of tensors shaped [#tokens, vocab_size] and the len of the list is BS
1229+
# In case of batched inputs (i.e multiple prompts sent as a list) context_logits returned can have
1230+
# different seq_len. Following code pads them as it can otherwise error while converting to numpy array
1231+
context_logits = self._pad_logits(context_logits)
1232+
# Convert context_Logits to numpy array of shape [bS, 1, padding_len, vocab_size],.
1233+
context_logits = np.array([logit_tensor.unsqueeze(0).cpu().numpy() for logit_tensor in context_logits])
1234+
output_dict["context_logits"] = context_logits
12131235
else:
12141236
output_texts = self.forward(**infer_input)
12151237
output_dict["outputs"] = cast_output(output_texts, np.bytes_)

0 commit comments

Comments
 (0)