Skip to content

Commit 02b60b5

Browse files
Support bigcode eval for codegen v0.1 (#94)
* Support bigcode eval for codegen v0.1 Signed-off-by: Yao, Qing <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix for bigcode eval UT Signed-off-by: Yao, Qing <[email protected]> --------- Signed-off-by: Yao, Qing <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4df6438 commit 02b60b5

File tree

4 files changed

+123
-1
lines changed

4 files changed

+123
-1
lines changed

evals/benchmark/stresscli/locust/codegenbench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
def getUrl():
23-
return "/v1/chatqna"
23+
return "/v1/codegen"
2424

2525

2626
def getReqData():

evals/evaluation/bigcode_evaluation_harness/accuracy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from bigcode_eval.tasks import ALL_TASKS
2323
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
2424

25+
from evals.evaluation.bigcode_evaluation_harness.api_evaluator import APIEvaluator
26+
2527

2628
def pattern_match(patterns, source_list):
2729
"""Returns a list containing all values of the source_list that
@@ -68,6 +70,13 @@ def evaluate(args):
6870
evaluator = Evaluator(accelerator, None, None, args)
6971
for task in task_names:
7072
results[task] = evaluator.evaluate(task)
73+
elif args.codegen_url:
74+
# here we generate code using an OPEA codegen API
75+
if accelerator.is_main_process:
76+
print("OPEA codegen API generation mode")
77+
evaluator = APIEvaluator(accelerator, args.model, None, args)
78+
for task in task_names:
79+
results[task] = evaluator.evaluate(task)
7180
else:
7281
# here we generate code and save it (evaluation is optional but True by default)
7382
dict_precisions = {
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import inspect
5+
import json
6+
import warnings
7+
8+
import aiohttp
9+
from bigcode_eval import tasks
10+
from bigcode_eval.evaluator import Evaluator
11+
12+
13+
class APIEvaluator(Evaluator):
    """Evaluator that produces code generations through a remote OPEA codegen
    REST API (via ``parallel_generations_by_api``) instead of a local model.

    Task loading, reference extraction and metric computation are inherited
    from ``bigcode_eval.evaluator.Evaluator``; only the generation step is
    overridden.
    """

    def generate_text(self, task_name, intermediate_generations=None):
        """Build ``(generations, references)`` for *task_name*.

        Args:
            task_name: name of a task registered in ``bigcode_eval.tasks``.
            intermediate_generations: optional list of partially completed
                generations from a previous run; falsy entries mean
                "not generated yet".

        Returns:
            ``(generations, references)`` — or ``(solutions, references)``
            when ``args.check_references`` is set (no generation happens in
            that case).
        """
        task = tasks.get_task(task_name, self.args)
        dataset = task.get_dataset()
        # if args.limit is None, use all samples
        # if args.limit is used, make sure args.limit_start + args.limit <= len(dataset)
        n_tasks = min(self.args.limit, len(dataset) - self.args.limit_start) if self.args.limit else len(dataset)
        print(n_tasks)  # NOTE(review): looks like a leftover debug print — consider removing or logging
        # when args.limit is None
        # adjust n_tasks by args.limit_start to prevent out of bounds issues
        if not self.args.limit:
            n_tasks -= self.args.limit_start
        # References cover exactly the [limit_start, limit_start + n_tasks) window.
        references = [
            task.get_reference(dataset[i]) for i in range(self.args.limit_start, self.args.limit_start + n_tasks)
        ]

        if self.args.check_references:
            # Benchmark ground-truth solutions instead of generating (debug aid).
            if "get_solution" in inspect.signature(task.get_reference).parameters:
                solutions = [
                    [task.get_reference(dataset[i], get_solution=True)]
                    for i in range(self.args.limit_start, self.args.limit_start + n_tasks)
                ]
            else:
                solutions = [[ref] for ref in references]
            return solutions, references

        if intermediate_generations:
            # NOTE(review): curr_generations is computed but never merged back
            # into the returned generations, while n_tasks is still reduced —
            # when resuming, generations and references can end up with
            # different lengths. Verify against the upstream bigcode_eval
            # resume logic before relying on intermediate_generations.
            curr_generations = [gen for gen in intermediate_generations if gen]
            n_tasks -= len(curr_generations)

        generations = parallel_generations_by_api(
            task,
            dataset,
            self.accelerator,
            n_tasks=n_tasks,
            args=self.args,
        )

        # Trim over-generation so each task keeps exactly n_samples candidates.
        if len(generations[0]) > self.args.n_samples:
            generations = [l[: self.args.n_samples] for l in generations]
            warnings.warn(
                f"Number of tasks wasn't proportional to number of devices, we removed extra predictions to only keep nsamples={self.args.n_samples}"
            )
        return generations, references
57+
58+
59+
def parallel_generations_by_api(
    task,
    dataset,
    accelerator,
    n_tasks,
    args,
):
    """Return ``n_tasks`` code generations, either loaded from disk or
    produced by querying an OPEA codegen REST API concurrently.

    Args:
        task: bigcode_eval task object; used for prompt construction and
            post-processing of the raw completions.
        dataset: indexable dataset of task samples.
        accelerator: accelerate.Accelerator-like object; only
            ``is_main_process`` is read here.
        n_tasks: number of samples to generate for.
        args: parsed CLI arguments; reads ``load_generations_path``,
            ``codegen_url``, ``limit_start``, ``temperature``, ``top_p``
            and ``top_k``.

    Returns:
        A list of ``n_tasks`` lists of post-processed generation strings
        (or ``None`` if neither a generations file nor an API URL is given,
        matching the previous behavior).
    """
    if args.load_generations_path:
        # Reuse generations saved by a previous run instead of calling the API.
        with open(args.load_generations_path) as fp:
            generations = json.load(fp)
        if accelerator.is_main_process:
            print(
                f"generations loaded, {n_tasks} selected from {len(generations)} with {len(generations[0])} candidates"
            )
        return generations[:n_tasks]

    if codegen_url := args.codegen_url:
        assert "/codegen" in codegen_url, "Only OPEA codegen compatible APIs are supported"
        # Local imports keep the asyncio/tqdm dependency out of the
        # load-from-disk path. (Removed the unused `os` and `requests`
        # imports from the original.)
        import asyncio

        from tqdm.asyncio import tqdm

        async def get_res(prompt):
            # One POST per prompt; the OPEA codegen API returns an
            # OpenAI-chat-style JSON payload with a "choices" list.
            headers = {"Content-Type": "application/json"}
            data = {
                "messages": prompt,
                "max_tokens": 2048,
                "stream": False,
                "temperature": args.temperature,
                "top_p": args.top_p,
                "top_k": args.top_k,
            }
            async with aiohttp.ClientSession() as session:
                async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response:
                    return await response.text()

        # Bug fix: only query the API for the [limit_start, limit_start + n_tasks)
        # window. The previous code sent one request per dataset row, which
        # misaligned generations with the references built in
        # APIEvaluator.generate_text whenever --limit/--limit_start was used.
        limit_start = getattr(args, "limit_start", 0) or 0
        sample_indices = range(limit_start, limit_start + n_tasks)
        prompts = [task.get_prompt(dataset[i]) for i in sample_indices]
        awaitables = [get_res(prompt=prompt) for prompt in prompts]
        responses = asyncio.run(tqdm.gather(*awaitables))

        generations = []
        for idx, prompt, response in zip(sample_indices, prompts, responses):
            # Prepend the prompt so postprocess_generation sees the full text;
            # idx is the dataset-absolute index the task expects.
            texts = [prompt + choice["message"]["content"] for choice in json.loads(response)["choices"]]
            generations.append([task.postprocess_generation(text, idx) for text in texts])
        return generations

evals/evaluation/bigcode_evaluation_harness/arguments.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def setup_parser():
204204
action="store_true",
205205
help="Don't run generation but benchmark groundtruth (useful for debugging)",
206206
)
207+
parser.add_argument(
208+
"--codegen_url",
209+
default=None,
210+
help="Base URL to use OPEA Codegen API,",
211+
)
207212
return parser.parse_args()
208213

209214

@@ -253,6 +258,7 @@ def __init__(
253258
check_references=False,
254259
user_model=None, # used for pass model object
255260
tokenizer=None, # used for pass tokenizer object
261+
codegen_url=None,
256262
):
257263
self.prefix = prefix
258264
self.do_sample = do_sample
@@ -295,3 +301,4 @@ def __init__(
295301
self.check_references = check_references
296302
self.user_model = user_model
297303
self.tokenizer = tokenizer
304+
self.codegen_url = codegen_url

0 commit comments

Comments
 (0)