
Commit 4df6438

Add FaqGen Accuracy scripts & Refine Ragas (#91)
* fix ragas to align latest code
  Signed-off-by: Xinyao Wang <[email protected]>
* add FaqGen Accuracy scripts
  Signed-off-by: Xinyao Wang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix bug
  Signed-off-by: Xinyao Wang <[email protected]>

---------

Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 514a6d6 commit 4df6438

File tree

8 files changed: +227 -14 lines changed

- evals/metrics/ragas/ragas.py
- examples/FaqGen/README.md
- examples/FaqGen/evaluate.py
- examples/FaqGen/generate_FAQ.py
- examples/FaqGen/get_context.py
- examples/FaqGen/launch_tgi.sh
- examples/FaqGen/post_process_FAQ.py
- tests/test_ragas.py

evals/metrics/ragas/ragas.py

Lines changed: 17 additions & 10 deletions
```diff
@@ -32,13 +32,14 @@ def __init__(
         self.embeddings = embeddings
         self.metrics = metrics
         self.validated_list = [
-            "answer_relevancy",
-            "faithfulness",
             "answer_correctness",
+            "answer_relevancy",
             "answer_similarity",
             "context_precision",
-            "context_relevancy",
             "context_recall",
+            "faithfulness",
+            "context_utilization",
+            "reference_free_rubrics_score",
         ]

     async def a_measure(self, test_case: Dict):
@@ -55,8 +56,9 @@ def measure(self, test_case: Dict):
                 answer_similarity,
                 context_precision,
                 context_recall,
-                context_relevancy,
+                context_utilization,
                 faithfulness,
+                reference_free_rubrics_score,
             )

         except ModuleNotFoundError:
@@ -67,8 +69,14 @@ def measure(self, test_case: Dict):
         except ModuleNotFoundError:
             raise ModuleNotFoundError("Please install dataset")
         self.metrics_instance = {
+            "answer_correctness": answer_correctness,
             "answer_relevancy": answer_relevancy,
+            "answer_similarity": answer_similarity,
+            "context_precision": context_precision,
+            "context_recall": context_recall,
             "faithfulness": faithfulness,
+            "context_utilization": context_utilization,
+            "reference_free_rubrics_score": reference_free_rubrics_score,
         }

         # Set LLM model
@@ -101,7 +109,7 @@ def measure(self, test_case: Dict):
                 else:
                     if metric == "answer_relevancy" and self.embeddings is None:
                         raise ValueError("answer_relevancy metric need provide embeddings model.")
-                    tmp_metrics.append(metric)
+                    tmp_metrics.append(self.metrics_instance[metric])
             self.metrics = tmp_metrics
         else:
             self.metrics = [
@@ -110,15 +118,14 @@ def measure(self, test_case: Dict):
                 answer_correctness,
                 answer_similarity,
                 context_precision,
-                context_relevancy,
                 context_recall,
             ]

         data = {
-            "question": test_case["input"],
-            "contexts": test_case["retrieval_context"],
-            "answer": test_case["actual_output"],
-            "ground_truth": test_case["expected_output"],
+            "question": test_case["question"],
+            "contexts": test_case["contexts"],
+            "answer": test_case["answer"],
+            "ground_truth": test_case["ground_truth"],
         }
         dataset = Dataset.from_dict(data)
```
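The key behavioral fix is in the last two hunks: user-supplied metric names are now resolved through `self.metrics_instance` before being handed to ragas (instead of appending the raw strings), and the test_case keys now follow ragas' native column names. A minimal sketch of that lookup pattern, with illustrative placeholder names rather than the actual class:

```python
# Sketch of the name-to-metric resolution the patch introduces: validate
# each requested metric name, then append the ragas metric object (not
# the string) to the list that gets evaluated. Identifiers here are
# illustrative placeholders, not the real module.
def resolve_metrics(requested, registry, validated):
    resolved = []
    for name in requested:
        if name not in validated:
            raise ValueError(f"metric should be in supported list {validated}")
        resolved.append(registry[name])  # append the metric object, not its name
    return resolved
```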

examples/FaqGen/README.md

Lines changed: 61 additions & 0 deletions
## Dataset

We evaluate performance on the QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2). FAQs are generated from the "context" column of the validation split, which contains 1204 unique records.

First, download the dataset and place it under "./data".

Extract the unique "context" entries, which will be saved to 'data/sqv2_context.json':
```
python get_context.py
```
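If the Squad_v2 parquet file is not in place yet, the following is a minimal sketch for producing it at the path that `get_context.py` reads; using the Hugging Face `datasets` package and this exact local layout are assumptions, not part of this commit.

```python
# Sketch: download the Squad_v2 validation split and save it where
# get_context.py expects it. The datasets package and the target path
# are assumptions; adjust to your local layout.
import os

from datasets import load_dataset

os.makedirs("data/squad_v2/squad_v2", exist_ok=True)
ds = load_dataset("rajpurkar/squad_v2", split="validation")
ds.to_parquet("data/squad_v2/squad_v2/validation-00000-of-00001.parquet")
```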

## Generate FAQs

### Launch FaqGen microservice
Please refer to the [FaqGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi) to set up a microservice endpoint.
```
export FAQ_ENDPOINT="http://${your_ip}:9000/v1/faqgen"
```
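Before launching the full generation loop, the endpoint can be sanity-checked with a single request; this minimal sketch mirrors the payload that generate_FAQ.py sends, and the sample text is only a placeholder.

```python
# Sketch: send one context to the FaqGen endpoint and print the raw
# response. The payload shape matches generate_FAQ.py; the sample text
# below is only a placeholder.
import os

import requests

faq_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")
data = {"query": "Deep learning is a subset of machine learning ...", "max_new_tokens": 128}
response = requests.post(faq_endpoint, json=data, headers={"Content-Type": "application/json"})
print(response.content.decode("utf-8"))
```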

### Generate FAQs with the microservice
Use the microservice endpoint to generate FAQs for the dataset.
```
python generate_FAQ.py
```

Post-process the output to extract the generated FAQs, which will be saved to 'data/sqv2_faq.json':
```
python post_process_FAQ.py
```

## Evaluate with Ragas

### Launch TGI service
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as the judge LLM for evaluation. First, launch an LLM endpoint on Gaudi.
```
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token"
bash launch_tgi.sh
```
Get the endpoint:
```
export LLM_ENDPOINT="http://${ip_address}:8082"
```

Verify the service:
```bash
curl http://${ip_address}:8082/generate \
  -X POST \
  -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \
  -H 'Content-Type: application/json'
```

### Evaluate
Evaluate the performance with the judge LLM:
```
python evaluate.py
```

### Performance Result
Here are the tested results for reference:

| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score |
| ---------------- | ------------ | ------------------- | ---------------------------- |
| 0.7191           | 0.9681       | 0.8964              | 4.4125                       |

examples/FaqGen/evaluate.py

Lines changed: 45 additions & 0 deletions
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

from langchain_community.embeddings import HuggingFaceBgeEmbeddings

from evals.metrics.ragas import RagasMetric

llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

f = open("data/sqv2_faq.json", "r")
sqv2_faq = json.load(f)

templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
Do not use any prefix or suffix to the FAQ.
"""

number = 1204
question = []
answer = []
ground_truth = ["None"] * number
contexts = []
for i in range(number):
    inputs = sqv2_context[str(i)]
    inputs_faq = templ.format_map({"text": inputs})
    actual_output = sqv2_faq[str(i)]

    question.append(inputs_faq)
    answer.append(actual_output)
    contexts.append([inputs_faq])

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"]
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq)

test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts}

metric.measure(test_case)
print(metric.score)
```
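Before scoring all 1204 records with evaluate.py, it can help to confirm the judge endpoint and embeddings are wired correctly on a small slice; a minimal sketch reusing the lists built above (the slice size is arbitrary):

```python
# Sketch: score only the first few records to smoke-test the judge LLM
# and embeddings before the full run. The slice size is arbitrary.
n = 5
small_case = {
    "question": question[:n],
    "answer": answer[:n],
    "ground_truth": ground_truth[:n],
    "contexts": contexts[:n],
}
metric.measure(small_case)
print(metric.score)
```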

examples/FaqGen/generate_FAQ.py

Lines changed: 28 additions & 0 deletions
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
import time

import requests

llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

start_time = time.time()
headers = {"Content-Type": "application/json"}
for i in range(1204):
    start_time_tmp = time.time()
    print(i)
    inputs = sqv2_context[str(i)]
    data = {"query": inputs, "max_new_tokens": 128}
    response = requests.post(llm_endpoint, json=data, headers=headers)
    f = open(f"data/result/sqv2_faq_{i}", "w")
    f.write(inputs)
    f.write(str(response.content, encoding="utf-8"))
    f.close()
    print(f"Cost {time.time()-start_time_tmp} seconds")
print(f"\n Finished! \n Totally Cost {time.time()-start_time} seconds\n")
```

examples/FaqGen/get_context.py

Lines changed: 17 additions & 0 deletions
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

import pandas as pd

data_path = "./data"
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet"))
sq_context = list(data["context"].unique())
sq_context_d = dict()
for i in range(len(sq_context)):
    sq_context_d[i] = sq_context[i]

with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile:
    json.dump(sq_context_d, outfile)
```
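Note that json.dump turns the integer keys of sq_context_d into strings, which is why generate_FAQ.py and evaluate.py index the file with str(i). A rough sketch of the resulting file shape, with placeholder values:

```python
# Sketch: approximate shape of data/sqv2_context.json. Keys become
# strings when serialized; the values below are placeholders.
example = {
    "0": "<first unique context paragraph>",
    "1": "<second unique context paragraph>",
    # ... up to "1203"
}
```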

examples/FaqGen/launch_tgi.sh

Lines changed: 28 additions & 0 deletions
```bash
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

max_input_tokens=3072
max_total_tokens=4096
port_number=8082
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
volume="./data"
docker run -it --rm \
    --name="tgi_Mixtral" \
    -p $port_number:80 \
    -v $volume:/data \
    --runtime=habana \
    --restart always \
    -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
    -e HABANA_VISIBLE_DEVICES=all \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
    --cap-add=sys_nice \
    --ipc=host \
    -e HTTPS_PROXY=$https_proxy \
    -e HTTP_PROXY=$https_proxy \
    ghcr.io/huggingface/tgi-gaudi:2.0.1 \
    --model-id $model_name \
    --max-input-tokens $max_input_tokens \
    --max-total-tokens $max_total_tokens \
    --sharded true \
    --num-shard 2
```
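Mixtral-8x7B takes a while to shard and load across the two Gaudi cards, so it can help to poll the endpoint before starting evaluation. Note also that docker may refuse to combine --rm with --restart always, so one of the two flags may need to be dropped when running the script. Below is a hedged readiness-check sketch that reuses the same /generate call the README uses for verification; the endpoint address and retry/timeout values are assumptions.

```python
# Sketch: poll the TGI /generate endpoint until the model is loaded and
# responding. Endpoint address and timeout values are assumptions.
import time

import requests

endpoint = "http://0.0.0.0:8082/generate"
payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 16}}

for _ in range(60):
    try:
        r = requests.post(endpoint, json=payload, timeout=10)
        if r.status_code == 200:
            print("TGI is ready:", r.json())
            break
    except requests.RequestException:
        pass
    time.sleep(10)
```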

examples/FaqGen/post_process_FAQ.py

Lines changed: 27 additions & 0 deletions
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json

faq_dict = {}
fails = []
for i in range(1204):
    data = open(f"data/result/sqv2_faq_{i}", "r").readlines()
    result = data[-6][6:]
    # print(result)
    if "LLMChain/final_output" not in result:
        print(f"error1: fail for {i}")
        fails.append(i)
        continue
    try:
        result2 = json.loads(result)
        result3 = result2["ops"][0]["value"]["text"]
        faq_dict[str(i)] = result3
    except:
        print(f"error2: fail for {i}")
        fails.append(i)
        continue
with open("data/sqv2_faq.json", "w") as outfile:
    json.dump(faq_dict, outfile)
print("Failure index:")
print(fails)
```
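The script assumes the record holding the final FAQ text is always the sixth line from the end of each response file. If the streamed output ever shifts, a slightly more defensive variant is to search for the line containing the LLMChain/final_output marker instead; a hedged sketch, not part of the commit:

```python
# Sketch: locate the streamed line containing the final FAQ output by
# searching for the "LLMChain/final_output" marker rather than relying
# on a fixed offset from the end of the file. The 6-character slice
# still strips the stream prefix, as in the committed script.
result = None
for line in data:
    if "LLMChain/final_output" in line:
        result = line[6:]
        break
```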

tests/test_ragas.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -26,10 +26,10 @@ def test_ragas(self):
         embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
         metric = RagasMetric(threshold=0.5, model="http://localhost:8008", embeddings=embeddings)
         test_case = {
-            "input": ["What if these shoes don't fit?"],
-            "actual_output": [actual_output],
-            "expected_output": [expected_output],
-            "retrieval_context": [retrieval_context],
+            "question": ["What if these shoes don't fit?"],
+            "answer": [actual_output],
+            "ground_truth": [expected_output],
+            "contexts": [retrieval_context],
         }

         metric.measure(test_case)
```
