
Commit 021193f

Support Longbench (#179)
* add longbench
  Signed-off-by: Xinyao Wang <[email protected]>
* refine readme
  Signed-off-by: Xinyao Wang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

---------
Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b466abd commit 021193f

File tree

2 files changed, +229 -0 lines changed


evals/evaluation/longbench/README.md

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
[LongBench](https://github.com/THUDM/LongBench) is a benchmark for bilingual, multitask, and comprehensive assessment of the long-context understanding capabilities of large language models. LongBench covers two languages (Chinese and English) to provide a more comprehensive evaluation of large models' multilingual capabilities on long contexts. It is composed of six major categories and twenty-one tasks, covering key long-text application scenarios such as single-document QA, multi-document QA, summarization, few-shot learning, synthetic tasks, and code completion.

In this guide, we evaluate the LongBench dataset with OPEA services on Intel hardware.

# 🚀 QuickStart

## Installation

```
pip install -r ../../../requirements.txt
```

## Launch an LLM Service

To set up an LLM service, you can use [tgi-gaudi](https://github.com/huggingface/tgi-gaudi) or [OPEA microservices](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/text-generation).

### Example 1: TGI

For example, the following command launches the [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) model on Gaudi:

```
model=meta-llama/Llama-2-7b-hf
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all \
  -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HF_TOKEN=$hf_token \
  -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true -e USE_FLASH_ATTENTION=true \
  -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice --ipc=host \
  ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 \
  --max-total-tokens 2048
```
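Once the container is running, you can optionally check that the endpoint responds before starting the benchmark. Below is a minimal sketch, assuming the service is reachable at http://localhost:8080; it sends the same request body that pred.py uses for the `tgi` backend.

```
# Minimal sanity check for the TGI endpoint (assumes http://localhost:8080).
import requests

endpoint = "http://localhost:8080/generate"  # adjust host/port to your deployment
payload = {
    "inputs": "What is the capital of France?",
    "parameters": {"max_new_tokens": 32, "do_sample": False},
}
res = requests.post(endpoint, headers={"Content-Type": "application/json"}, json=payload)
res.raise_for_status()
print(res.json()["generated_text"])
```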
### Example 2: OPEA LLM

You can also set up a service with OPEA microservices.

For example, you can refer to [native LLM](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/text-generation/native/langchain) for deployment on native Gaudi without any serving framework.
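Whichever service you use, the only difference on the client side is the request body that pred.py sends (see get_query() in pred.py below). As a rough reference, with placeholder prompt and token values:

```
# Request bodies used by pred.py for each backend (prompt and token values are placeholders).
tgi_query = {"inputs": "<prompt>", "parameters": {"max_new_tokens": 64, "do_sample": False}}
llm_query = {"query": "<prompt>", "max_tokens": 64}
```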
## Predict

Please set up the environment variables first.

```
export ENDPOINT="http://{host_ip}:8080/generate" # your LLM serving endpoint
export LLM_MODEL="meta-llama/Llama-2-7b-hf"
export BACKEND="tgi" # "tgi" or "llm"
export DATASET="narrativeqa" # see https://github.com/THUDM/LongBench/blob/main/task.md for the full list
export MAX_INPUT_LENGTH=2048 # set the max input length according to your LLM service
```
Then get the predictions on the dataset.

```
python pred.py \
    --endpoint ${ENDPOINT} \
    --model_name ${LLM_MODEL} \
    --backend ${BACKEND} \
    --dataset ${DATASET} \
    --max_input_length ${MAX_INPUT_LENGTH}
```

The predictions will be saved to "pred/{LLM_MODEL}/{DATASET}.jsonl".

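Each line of that file is a JSON record written by pred.py, holding the model output ("pred") together with the reference answers and metadata. Here is a small sketch for inspecting the predictions, using the example model and dataset from above:

```
# Inspect a prediction file produced by pred.py (path uses the example values above).
import json

path = "pred/meta-llama/Llama-2-7b-hf/narrativeqa.jsonl"
with open(path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

print(f"{len(records)} predictions")
print(records[0]["pred"][:200])  # model output
print(records[0]["answers"])     # reference answers
```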
## Evaluate

Evaluate the predictions with the LongBench metrics.

```
git clone https://github.com/THUDM/LongBench
cd LongBench
pip install -r requirements.txt
python eval.py --model ${LLM_MODEL}
```

The evaluation results will then be saved to "pred/{LLM_MODEL}/result.jsonl".

evals/evaluation/longbench/pred.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import os
import random
import time

import numpy as np
import requests
from datasets import load_dataset
from requests.exceptions import RequestException
from tqdm import tqdm
from transformers import AutoTokenizer


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--endpoint", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--backend", type=str, default="tgi", choices=["tgi", "llm"])
    parser.add_argument(
        "--dataset", type=str, help="dataset name; if not given, will evaluate on all datasets", default=None
    )
    parser.add_argument("--e", action="store_true", help="Evaluate on LongBench-E")
    parser.add_argument("--max_input_length", type=int, default=2048, help="max input length")
    return parser.parse_args(args)


def get_query(backend, prompt, max_new_length):
    # Build the request header and body expected by the chosen serving backend.
    header = {"Content-Type": "application/json"}
    query = {
        "tgi": {"inputs": prompt, "parameters": {"max_new_tokens": max_new_length, "do_sample": False}},
        "llm": {"query": prompt, "max_tokens": max_new_length},
    }
    return header, query[backend]


def get_pred(
    data, dataset_name, backend, endpoint, model_name, max_input_length, max_new_length, prompt_format, out_path
):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)

        # truncate to fit max_input_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(tokenized_prompt) > max_input_length:
            half = int(max_input_length / 2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(
                tokenized_prompt[-half:], skip_special_tokens=True
            )

        header, query = get_query(backend, prompt, max_new_length)
        print("query: ", query)
        try:
            start_time = time.perf_counter()
            res = requests.post(endpoint, headers=header, json=query)
            res.raise_for_status()
            res = res.json()
            cost = time.perf_counter() - start_time
        except RequestException as e:
            raise Exception(f"An unexpected error occurred: {str(e)}")

        if backend == "tgi":
            result = res["generated_text"]
        else:
            result = res["text"]
        print("result: ", result)

        # Append one JSON record per sample: the model prediction plus reference answers and metadata.
        with open(out_path, "a", encoding="utf-8") as f:
            json.dump(
                {
                    "pred": result,
                    "answers": json_obj["answers"],
                    "all_classes": json_obj["all_classes"],
                    "length": json_obj["length"],
                },
                f,
                ensure_ascii=False,
            )
            f.write("\n")


if __name__ == "__main__":
    args = parse_args()
    endpoint = args.endpoint
    model_name = args.model_name
    backend = args.backend
    dataset = args.dataset
    max_input_length = args.max_input_length

    dataset_list = [
        "narrativeqa",
        "qasper",
        "multifieldqa_en",
        "multifieldqa_zh",
        "hotpotqa",
        "2wikimqa",
        "musique",
        "dureader",
        "gov_report",
        "qmsum",
        "multi_news",
        "vcsum",
        "trec",
        "triviaqa",
        "samsum",
        "lsht",
        "passage_count",
        "passage_retrieval_en",
        "passage_retrieval_zh",
        "lcc",
        "repobench-p",
    ]
    datasets_e_list = [
        "qasper",
        "multifieldqa_en",
        "hotpotqa",
        "2wikimqa",
        "gov_report",
        "multi_news",
        "trec",
        "triviaqa",
        "samsum",
        "passage_count",
        "passage_retrieval_en",
        "lcc",
        "repobench-p",
    ]
    if args.e:
        if dataset is not None:
            if dataset in datasets_e_list:
                datasets = [dataset]
            else:
                raise NotImplementedError(f"{dataset} is not supported in the LongBench-E dataset list: {datasets_e_list}")
        else:
            datasets = datasets_e_list
        if not os.path.exists(f"pred_e/{model_name}"):
            os.makedirs(f"pred_e/{model_name}")
    else:
        datasets = [dataset] if dataset is not None else dataset_list
        if not os.path.exists(f"pred/{model_name}"):
            os.makedirs(f"pred/{model_name}")

    for dataset in datasets:
        if args.e:
            out_path = f"pred_e/{model_name}/{dataset}.jsonl"
            data = load_dataset("THUDM/LongBench", f"{dataset}_e", split="test")
        else:
            out_path = f"pred/{model_name}/{dataset}.jsonl"
            data = load_dataset("THUDM/LongBench", dataset, split="test")

        # we design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
        dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
        dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r"))
        prompt_format = dataset2prompt[dataset]
        max_new_length = dataset2maxlen[dataset]

        data_all = [data_sample for data_sample in data]
        get_pred(
            data_all, dataset, backend, endpoint, model_name, max_input_length, max_new_length, prompt_format, out_path
        )
