Commit 4526dfb

Merge pull request #20 from XinyaoWa/vllm_gaudi
Support model evaluation on Intel Gaudi
2 parents 96c692d + 403386c commit 4526dfb

11 files changed: +194 -6 lines changed

README.md

Lines changed: 17 additions & 0 deletions

@@ -54,6 +54,8 @@ source env/bin/activate
 pip install -r requirements.txt
 ```

+For evaluating on NVIDIA GPUs, please install `flash-attn` by referring to the [flash attention repo](https://github.com/Dao-AILab/flash-attention).
+
 Additionally, if you wish to use the API models, you will need to install the package corresponding to the API you wish to use
 ```bash
 pip install openai # OpenAI API (GPT)

@@ -105,6 +107,21 @@ sbatch scripts/run_short_slurm.sh # 8k-64k
 # for the API models, note that API results may vary due to the randomness in the API calls
 bash scripts/run_api.sh
 ```
+### Run on Intel Gaudi
+To run the evaluation with vLLM on Intel Gaudi, use the following commands:
+```bash
+## Build the vLLM docker image
+cd scripts/vllm-gaudi
+bash build_image.sh
+
+## Launch the vLLM container; change `LLM_MODEL_ID` and `NUM_CARDS` as needed
+bash launch_container.sh
+
+## Evaluate
+cd ../../
+bash scripts/run_eval_vllm_gaudi.sh
+```
+
 Check out the script file for more details!
 See [Others](#others) for the slurm scripts, easily collecting all the results, and using VLLM.
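As a usage sketch of the new README section (assuming `launch_container.sh` forwards these variables to `scripts/vllm-gaudi/compose.yaml`, which reads `LLM_MODEL_ID`, `NUM_CARDS`, and `LLM_ENDPOINT_PORT`), the served model and card count would be selected like this:

```bash
# Hypothetical example: pick the served model and number of Gaudi cards before launching.
# compose.yaml falls back to NUM_CARDS=1 and port 8008 if these are left unset.
export LLM_MODEL_ID=meta-llama/Llama-3.3-70B-Instruct
export NUM_CARDS=4
export LLM_ENDPOINT_PORT=8010
cd scripts/vllm-gaudi
bash launch_container.sh
```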

arguments.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,8 @@ def parse_arguments():
     parser.add_argument("--model_name_or_path", type=str, default=None)
     parser.add_argument("--use_vllm", action="store_true", help="whether to use vllm engine")
     parser.add_argument("--use_sglang", action="store_true", help="whether to use sglang engine")
+    parser.add_argument("--use_tgi_or_vllm_serving", action="store_true", help="whether to use tgi or vllm serving engine")
+    parser.add_argument("--endpoint_url", type=str, default="http://localhost:8080/v1/", help="endpoint url for tgi or vllm serving engine")

     # data settings
     parser.add_argument("--datasets", type=str, default=None, help="comma separated list of dataset names")
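For reference, a minimal sketch of a direct command-line run using the two new flags; the data arguments below are illustrative placeholders and assume the config keys (`input_max_length`, `test_files`, ...) map one-to-one to CLI arguments, as they do in the configs added by this PR:

```bash
# Hypothetical direct run against an already-running OpenAI-compatible server.
# In practice the repo's --config YAMLs set the dataset/file arguments.
python eval.py \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct \
    --use_tgi_or_vllm_serving \
    --endpoint_url http://localhost:8010/v1/ \
    --datasets kilt_nq \
    --test_files data/kilt/nq-dev-multikilt_1000_k1000_dep6.jsonl \
    --input_max_length 131072 \
    --generation_max_length 20 \
    --no_cuda
```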

configs/rag_vllm.yaml

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+input_max_length: 131072,131072,131072,131072
+datasets: kilt_nq,kilt_triviaqa,kilt_hotpotqa,kilt_popqa_3
+generation_max_length: 20,20,20,20
+test_files: data/kilt/nq-dev-multikilt_1000_k1000_dep6.jsonl,data/kilt/triviaqa-dev-multikilt_1000_k1000_dep6.jsonl,data/kilt/hotpotqa-dev-multikilt_1000_k1000_dep3.jsonl,data/kilt/popqa_test_1000_k1000_dep6.jsonl
+demo_files: data/kilt/nq-train-multikilt_1000_k3_dep6.jsonl,data/kilt/triviaqa-train-multikilt_1000_k3_dep6.jsonl,data/kilt/hotpotqa-train-multikilt_1000_k3_dep3.jsonl,data/kilt/popqa_test_1000_k3_dep6.jsonl
+use_chat_template: false
+max_test_samples: 100
+shots: 2
+stop_new_line: true
+model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
+output_dir: output/vllm-gaudi/Llama-3.3-70B-Instruct
+use_tgi_or_vllm_serving: true

configs/recall_vllm.yaml

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+input_max_length: 131072,131072,131072,131072
+datasets: ruler_niah_mk_2,ruler_niah_mk_3,ruler_niah_mv,json_kv
+generation_max_length: 50,100,50,100
+test_files: data/ruler/niah_multikey_2/validation_131072.jsonl,data/ruler/niah_multikey_3/validation_131072.jsonl,data/ruler/niah_multivalue/validation_131072.jsonl,data/json_kv/test_k1800_dep6.jsonl
+demo_files: ',,,'
+use_chat_template: false
+max_test_samples: 100
+shots: 2
+stop_new_line: false
+model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
+output_dir: output/vllm-gaudi/Llama-3.3-70B-Instruct
+use_tgi_or_vllm_serving: true
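Either config can also be run on its own against a running endpoint, mirroring the loop in `scripts/run_eval_vllm_gaudi.sh`:

```bash
# Run only the RAG config against a vLLM server listening on port 8010 of this host
host_ip=$(hostname -I | awk '{print $1}')
python eval.py --config configs/rag_vllm.yaml \
    --endpoint_url "http://${host_ip}:8010/v1" \
    --overwrite --no_cuda
```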

eval.py

Lines changed: 7 additions & 5 deletions

@@ -12,7 +12,7 @@
 from torch.utils.data import DataLoader

 from arguments import parse_arguments
-from model_utils import load_LLM, OpenAIModel, AnthropicModel
+from model_utils import load_LLM, OpenAIModel, AnthropicModel, TgiVllmModel

 from data import (
     load_data,

@@ -77,7 +77,7 @@ def run_test(args, model, dataset, test_file, demo_file):
     logger.info("Running generation...")
     start_time = time.time()
     # generate all outputs
-    if isinstance(model, OpenAIModel) or isinstance(model, AnthropicModel):
+    if (isinstance(model, OpenAIModel) or isinstance(model, AnthropicModel)) and (not isinstance(model, TgiVllmModel)):
         # using the batch API makes it cheaper and faster
         logger.info(f"Using the OpenAI/Anthropic batch API by default, if you want to use the iterative API, please change the code")
         all_outputs = model.generate_batch(all_inputs, batch_file=output_path+".batch")

@@ -138,8 +138,9 @@ def run_test(args, model, dataset, test_file, demo_file):
     if args.debug:
         import pdb; pdb.set_trace()

-    mem_usage = sum([torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())])
-    logger.info(f"Memory usage: {mem_usage/1000**3:.02f} GB")
+    if not args.no_cuda:
+        mem_usage = sum([torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())])
+        logger.info(f"Memory usage: {mem_usage/1000**3:.02f} GB")
     logger.info(f"Total time: {end_time - start_time:.02f} s")
     logger.info(f"Throughput: {len(results) / (end_time - start_time):.02f} samples/s")

@@ -162,9 +163,10 @@ def run_test(args, model, dataset, test_file, demo_file):
         "data": results,
         "metrics": metrics,
         "averaged_metrics": averaged_metrics,
-        "memory_usage": mem_usage,
         "throughput": len(results) / (end_time - start_time),
     }
+    if not args.no_cuda:
+        output["memory_usage"] = mem_usage

     if args.output_dir is not None:
         with open(output_path, "w") as f:

model_utils.py

Lines changed: 62 additions & 0 deletions

@@ -326,6 +326,64 @@ def generate_batch(self, inputs=None, prompt=None, **kwargs):

         return outputs

+class TgiVllmModel(OpenAIModel):
+    def __init__(
+        self,
+        model_name,
+        temperature=0.9,
+        top_p=0.9,
+        max_length=32768,
+        generation_max_length=2048,
+        generation_min_length=0,
+        do_sample=True,
+        stop_newline=False,
+        use_chat_template=True,
+        system_message=None,
+        seed=42,
+        **kwargs
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_length = max_length
+        self.generation_max_length = generation_max_length
+        self.generation_min_length = generation_min_length
+        self.do_sample = do_sample
+        self.use_chat_template = use_chat_template
+        self.system_message = system_message
+        self.stops = None
+        if stop_newline:
+            self.stops = ["\n", "\n\n"]
+
+        from openai import OpenAI
+        from transformers import AutoTokenizer
+
+        endpoint_url = kwargs["endpoint_url"]
+        print(f"** Endpoint URL: {endpoint_url}")
+
+        self.model = OpenAI(
+            base_url=endpoint_url,
+            api_key="EMPTY_KEY"
+        )
+        self.model_name = model_name
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.seed = seed
+        self.API_MAX_LENGTH = float('inf')
+
+    def generate_batch(self, inputs=None, prompt=None, **kwargs):
+        if inputs is None:
+            inputs = [None for _ in prompt]
+        else:
+            prompt = [None for _ in inputs]
+
+        # we don't support kwargs here for now
+        if len(kwargs) > 0:
+            logger.warning("kwargs are not supported for batch generation")
+        # use thread_map instead of process_map since the bottleneck is the api call
+        outputs = thread_map(self.generate, inputs, prompt, max_workers=32)
+
+        return outputs
+

 class AnthropicModel(LLM):
     def __init__(

@@ -1203,6 +1261,10 @@ def load_LLM(args):
     elif args.use_vllm:
         model_cls = VLLMModel
         kwargs['seed'] = args.seed
+    elif args.use_tgi_or_vllm_serving:
+        model_cls = TgiVllmModel
+        kwargs['seed'] = args.seed
+        kwargs["endpoint_url"] = args.endpoint_url
     elif args.use_sglang:
         model_cls = SGLangModel
         kwargs['seed'] = args.seed
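Since `TgiVllmModel` is just an OpenAI-compatible client pointed at the serving endpoint, the endpoint itself can be sanity-checked outside the harness. A minimal sketch with curl, assuming the server is published on port 8010 as in `scripts/run_eval_vllm_gaudi.sh` and serves the model named in the configs:

```bash
# Quick smoke test of the OpenAI-compatible completions route exposed by vLLM
curl -s http://localhost:8010/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Llama-3.3-70B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0
      }'
```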

requirements.txt

Lines changed: 0 additions & 1 deletion

@@ -6,6 +6,5 @@ datasets
 transformers
 accelerate
 sentencepiece
-flash-attn
 pytrec_eval
 rouge_score
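With `flash-attn` dropped from `requirements.txt` (it targets CUDA GPUs and is not applicable to Gaudi), NVIDIA GPU users install it separately, as the README change above notes. A hedged sketch of the install command documented in the flash-attention repo:

```bash
# Only needed when evaluating on NVIDIA GPUs; skip on Gaudi.
# --no-build-isolation follows the flash-attention installation instructions.
pip install flash-attn --no-build-isolation
```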

scripts/run_eval_vllm_gaudi.sh

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+export host_ip=$(hostname -I | awk '{print $1}')
+export LLM_ENDPOINT_PORT=8010
+export DATA_PATH="$HOME/.cache/huggingface"
+export LLM_ENDPOINT="http://${host_ip}:${LLM_ENDPOINT_PORT}/v1"
+export HF_HOME=$DATA_PATH
+
+for task in "recall" "rag"; do
+    python eval.py --config configs/${task}_vllm.yaml --endpoint_url $LLM_ENDPOINT --overwrite --no_cuda
+done
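Before the loop starts, it can help to confirm the server is up; a small sketch polling the same `/health` route used by the compose healthcheck:

```bash
# Wait until vllm-gaudi-server reports healthy before launching the evaluation
until curl -sf "http://${host_ip}:${LLM_ENDPOINT_PORT}/health" > /dev/null; do
    echo "waiting for vllm-gaudi-server..."
    sleep 10
done
```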

scripts/vllm-gaudi/build_image.sh

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+export TAG="helmet"
+echo "Building the vllm-gaudi docker image"
+git clone https://github.com/HabanaAI/vllm-fork.git
+cd ./vllm-fork
+git checkout v0.6.6.post1+Gaudi-1.20.0  # habana_main
+
+docker build --no-cache -f Dockerfile.hpu -t ${REGISTRY:-opea}/vllm-gaudi:${TAG:-latest} --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy
+if [ $? -ne 0 ]; then
+    echo "vllm-gaudi build failed"
+    exit 1
+else
+    echo "vllm-gaudi build successful"
+fi
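After a successful build, the resulting tag can be verified before moving on to the container launch:

```bash
# The image should be listed as opea/vllm-gaudi:helmet (or under your $REGISTRY/$TAG overrides)
docker images "${REGISTRY:-opea}/vllm-gaudi"
```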

scripts/vllm-gaudi/compose.yaml

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+services:
+  vllm-gaudi-server:
+    image: ${REGISTRY:-opea}/vllm-gaudi:${TAG:-latest}
+    container_name: vllm-gaudi-server
+    ports:
+      - ${LLM_ENDPOINT_PORT:-8008}:80
+    volumes:
+      - "${DATA_PATH:-./data}:/data"
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      HF_TOKEN: ${HF_TOKEN}
+      HF_HOME: "/data"
+      HABANA_VISIBLE_DEVICES: all
+      OMPI_MCA_btl_vader_single_copy_mechanism: none
+      PT_HPU_ENABLE_LAZY_COLLECTIVES: true
+      LLM_MODEL_ID: ${LLM_MODEL_ID}
+      VLLM_TORCH_PROFILER_DIR: "/mnt"
+      host_ip: ${host_ip}
+      LLM_ENDPOINT_PORT: ${LLM_ENDPOINT_PORT}
+      VLLM_SKIP_WARMUP: ${VLLM_SKIP_WARMUP:-true}
+      VLLM_ALLOW_LONG_MAX_MODEL_LEN: 1
+      MAX_MODEL_LEN: ${MAX_MODEL_LEN:-131072}
+      MAX_SEQ_LEN_TO_CAPTURE: ${MAX_MODEL_LEN:-131072}
+      NUM_CARDS: ${NUM_CARDS:-1}
+    runtime: habana
+    cap_add:
+      - SYS_NICE
+    ipc: host
+    healthcheck:
+      test: ["CMD-SHELL", "curl -f http://${host_ip}:${LLM_ENDPOINT_PORT}/health || exit 1"]
+      interval: 10s
+      timeout: 10s
+      retries: 150
+    command: --model $LLM_MODEL_ID --tensor-parallel-size ${NUM_CARDS} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 256 --max-seq-len-to-capture ${MAX_MODEL_LEN} --max-model-len ${MAX_MODEL_LEN}
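`launch_container.sh` is not shown in this diff, but the compose file above can also be brought up directly; a hedged sketch assuming the same environment variables used elsewhere in the PR:

```bash
# Hypothetical direct launch; launch_container.sh presumably wraps something similar.
export LLM_MODEL_ID=meta-llama/Llama-3.3-70B-Instruct
export NUM_CARDS=1
export LLM_ENDPOINT_PORT=8010
export DATA_PATH=$HOME/.cache/huggingface
export HF_TOKEN=<your_hf_token>   # replace with a real Hugging Face token
export host_ip=$(hostname -I | awk '{print $1}')
docker compose -f scripts/vllm-gaudi/compose.yaml up -d
```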
