# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import inspect
import json
import warnings

import aiohttp
from bigcode_eval import tasks
from bigcode_eval.evaluator import Evaluator


class APIEvaluator(Evaluator):
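    """Evaluator that collects generations from a remote codegen REST endpoint
    instead of a locally loaded model; the actual requests are issued by
    parallel_generations_by_api below."""
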
    def generate_text(self, task_name, intermediate_generations=None):
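        """Build references for the selected slice of the task dataset and collect
        generations from the API. Returns (solutions, references) when
        args.check_references is set, otherwise (generations, references)."""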
        task = tasks.get_task(task_name, self.args)
        dataset = task.get_dataset()
        # if args.limit is None, use all samples
        # if args.limit is used, make sure args.limit_start + args.limit <= len(dataset)
        n_tasks = min(self.args.limit, len(dataset) - self.args.limit_start) if self.args.limit else len(dataset)
        print(f"n_tasks: {n_tasks}")
        # when args.limit is None,
        # adjust n_tasks by args.limit_start to prevent out-of-bounds indexing
        if not self.args.limit:
            n_tasks -= self.args.limit_start
        references = [
            task.get_reference(dataset[i]) for i in range(self.args.limit_start, self.args.limit_start + n_tasks)
        ]

        if self.args.check_references:
            if "get_solution" in inspect.signature(task.get_reference).parameters:
                solutions = [
                    [task.get_reference(dataset[i], get_solution=True)]
                    for i in range(self.args.limit_start, self.args.limit_start + n_tasks)
                ]
            else:
                solutions = [[ref] for ref in references]
            return solutions, references

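        # when intermediate generations are provided, only the prompts that still lack
        # completions are counted towards n_tasks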
        if intermediate_generations:
            curr_generations = [gen for gen in intermediate_generations if gen]
            n_tasks -= len(curr_generations)

        generations = parallel_generations_by_api(
            task,
            dataset,
            self.accelerator,
            n_tasks=n_tasks,
            args=self.args,
        )

        if len(generations[0]) > self.args.n_samples:
            generations = [gen[: self.args.n_samples] for gen in generations]
            warnings.warn(
                f"Number of tasks wasn't proportional to the number of devices; extra predictions were removed to keep only n_samples={self.args.n_samples}"
            )
        return generations, references


def parallel_generations_by_api(
    task,
    dataset,
    accelerator,
    n_tasks,
    args,
):
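    """Return one list of candidate generations per prompt.

    Previously generated code is loaded from args.load_generations_path when given;
    otherwise every prompt in the dataset is posted asynchronously to the OPEA
    codegen-compatible endpoint at args.codegen_url."""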
    if args.load_generations_path:
        # load generated code
        with open(args.load_generations_path) as fp:
            generations = json.load(fp)
        if accelerator.is_main_process:
            print(
                f"generations loaded, {n_tasks} selected from {len(generations)} with {len(generations[0])} candidates"
            )
        return generations[:n_tasks]

    if codegen_url := args.codegen_url:
        assert "/codegen" in codegen_url, "Only OPEA codegen compatible APIs are supported"
        import asyncio

        from tqdm.asyncio import tqdm

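        # one POST per prompt; the payload carries the prompt plus the sampling
        # parameters expected by the codegen endpoint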
        async def get_res(prompt):
            headers = {"Content-Type": "application/json"}
            data = {
                "messages": prompt,
                "max_tokens": 2048,
                "stream": False,
                "temperature": args.temperature,
                "top_p": args.top_p,
                "top_k": args.top_k,
            }
            async with aiohttp.ClientSession() as session:
                async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response:
                    text = await response.text()
            return text

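        # fire all requests concurrently, then turn every JSON response into
        # post-processed candidate generations (one list per prompt)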
        prompts = [task.get_prompt(doc) for doc in dataset]
        awaitables = [get_res(prompt=prompt) for prompt in prompts]
        responses = asyncio.run(tqdm.gather(*awaitables))
        generations = []
        for i, (prompt, response) in enumerate(zip(prompts, responses)):
            texts = [prompt + choice["message"]["content"] for choice in json.loads(response)["choices"]]
            generations.append([task.postprocess_generation(text, i) for text in texts])
        return generations