
Commit 7d9265f

XinyuYe-Intel, pre-commit-ci[bot], and lkk12014402 authored
Support rerank model finetuning (#578)
* support rerank model finetuning. Signed-off-by: Ye, Xinyu <[email protected]>
* adapt rerank model to transformers' scheme. Signed-off-by: Ye, Xinyu <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix typo. Signed-off-by: Ye, Xinyu <[email protected]>
* refined readme. Signed-off-by: Ye, Xinyu <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* modify command due to api change. Signed-off-by: Ye, Xinyu <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Signed-off-by: Ye, Xinyu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: lkk <[email protected]>
1 parent 89197e5 commit 7d9265f

File tree

5 files changed: +237 −68 lines

comps/finetuning/README.md

Lines changed: 38 additions & 11 deletions
````diff
@@ -1,6 +1,6 @@
-# LLM Fine-tuning Microservice
+# Fine-tuning Microservice
 
-LLM Fine-tuning microservice involves adapting a base model to a specific task or dataset to improve its performance on that task.
+The fine-tuning microservice adapts a model to a specific task or dataset to improve its performance on that task. It currently supports instruction tuning for LLMs as well as fine-tuning for reranking and embedding models.
 
 ## 🚀1. Start Microservice with Python (Optional 1)
 
@@ -86,14 +86,22 @@ docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p 8015:8015 -e OMPI_M
 
 ## 🚀3. Consume Finetuning Service
 
-### 3.1 Create fine-tuning job
+### 3.1 Upload a training file
 
-Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model:
+Download a training file, such as `alpaca_data.json` for instruction tuning (available [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)), and upload it to the server with the command below:
 
 ```bash
 # upload a training file
 curl http://${your_ip}:8015/v1/files -X POST -H "Content-Type: multipart/form-data" -F "file=@./alpaca_data.json" -F purpose="fine-tune"
+```
+
+For reranking and embedding model fine-tuning, the training file [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/toy_finetune_data.jsonl) is a toy example.
+
+### 3.2 Create fine-tuning job
 
+After a training file such as `alpaca_data.json` is uploaded, use the following command to launch a finetuning job with `meta-llama/Llama-2-7b-chat-hf` as the base model:
+
+```bash
 # create a finetuning job
 curl http://${your_ip}:8015/v1/fine_tuning/jobs \
   -X POST \
@@ -102,22 +110,41 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
     "training_file": "alpaca_data.json",
     "model": "meta-llama/Llama-2-7b-chat-hf"
   }'
+```
 
+Use the following command to launch a finetuning job for a reranking model such as `BAAI/bge-reranker-large`:
+
+```bash
+# create a finetuning job
+curl http://${your_ip}:8015/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "toy_finetune_data.jsonl",
+    "model": "BAAI/bge-reranker-large",
+    "General":{
+      "task":"rerank",
+      "lora_config":null
+    }
+  }'
+```
+
+### 3.3 Manage fine-tuning jobs
+
+The commands below show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job, and list the checkpoints of a finetuning job.
+
+```bash
 # list finetuning jobs
-curl http://${your_ip}:8015/v1/fine_tuning/jobs -X GET
+curl http://${your_ip}:8015/v1/fine_tuning/jobs -X GET
 
 # retrieve one finetuning job
-curl http://localhost:8015/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{
-  "fine_tuning_job_id": ${fine_tuning_job_id}}'
+curl http://localhost:8015/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
 
 # cancel one finetuning job
-curl http://localhost:8015/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
-  "fine_tuning_job_id": ${fine_tuning_job_id}}'
+curl http://localhost:8015/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
 
 # list checkpoints of a finetuning job
 curl http://${your_ip}:8015/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
-
-
 ```
 
 ## 🚀4. Descriptions for Finetuning parameters
````
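
For reference, each line in a FlagEmbedding-style training file such as `toy_finetune_data.jsonl` is a JSON object with a `query`, a list of positive passages under `pos`, and a list of negative passages under `neg`; these are exactly the keys the new `TrainDatasetForCE` class reads. The snippet below is a minimal sketch of how such a file could be assembled (the example strings are invented, not taken from the actual toy file):

```python
import json

# Hypothetical records in the FlagEmbedding-style format consumed by
# TrainDatasetForCE: one query, a list of positive passages ("pos"),
# and a list of negative passages ("neg").
records = [
    {
        "query": "What is a rerank model used for?",
        "pos": ["A reranker scores query-passage pairs to reorder retrieved results."],
        "neg": [
            "Bananas are rich in potassium.",
            "The Eiffel Tower is located in Paris.",
        ],
    }
]

# Write one JSON object per line, matching the .jsonl layout referenced above.
with open("toy_finetune_data.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```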

comps/finetuning/finetune_config.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -48,12 +48,18 @@ class GeneralConfig(BaseModel):
     config: LoadConfig = LoadConfig()
     lora_config: Optional[LoraConfig] = LoraConfig()
     enable_gradient_checkpointing: bool = False
+    task: str = "instruction_tuning"
 
     @validator("report_to")
     def check_report_to(cls, v: str):
         assert v in ["none", "tensorboard"]
         return v
 
+    @validator("task")
+    def check_task(cls, v: str):
+        assert v in ["instruction_tuning", "rerank", "embedding"]
+        return v
+
 
 class DatasetConfig(BaseModel):
     train_file: str = None
@@ -74,6 +80,7 @@ class DatasetConfig(BaseModel):
     data_preprocess_type: str = "neural_chat"
     max_train_samples: int = 0
     max_eval_samples: int = 0
+    train_group_size: int = 8
 
 
 class RayResourceConfig(BaseModel):
```
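
The effect of the new `task` field is that any value outside the three supported tasks is rejected when the job config is parsed. Below is a minimal standalone sketch of that behavior, using a trimmed-down stand-in for `GeneralConfig` (the real class in `comps/finetuning/finetune_config.py` carries many more fields):

```python
from pydantic import BaseModel, validator


class GeneralConfigSketch(BaseModel):
    # Stand-in for the new field; defaults to instruction tuning.
    task: str = "instruction_tuning"

    @validator("task")
    def check_task(cls, v: str):
        # Only the three supported task types pass validation.
        assert v in ["instruction_tuning", "rerank", "embedding"]
        return v


print(GeneralConfigSketch().task)               # "instruction_tuning" (default)
print(GeneralConfigSketch(task="rerank").task)  # "rerank"
# GeneralConfigSketch(task="classification")    # would raise a ValidationError
```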

comps/finetuning/llm_on_ray/finetune/data_process.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -4,10 +4,16 @@
 # Copyright 2023 The LLM-on-Ray Authors.
 
 import copy
+import math
+import random
 import re
+from dataclasses import dataclass
 from itertools import chain
+from typing import Dict, List, Tuple
 
 import torch
+from torch.utils.data import Dataset
+from transformers import BatchEncoding, DataCollatorWithPadding
 
 IGNORE_INDEX = -100
 
@@ -194,3 +200,49 @@ def tokenize(self, examples):
             examples["labels"].append(labels)
             examples["attention_mask"].append(results["attention_mask"])
         return examples
+
+
+class TrainDatasetForCE(Dataset):
+    def __init__(self, dataset, args, tokenizer):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.args = args
+        self.total_len = len(self.dataset)
+
+    def create_one_example(self, qry_encoding: str, doc_encoding: str):
+        item = self.tokenizer.encode_plus(
+            qry_encoding,
+            doc_encoding,
+            truncation=True,
+            max_length=self.args.get("max_length", 512),
+            padding=False,
+        )
+        return item
+
+    def __len__(self):
+        return self.total_len
+
+    def __getitem__(self, item) -> List[BatchEncoding]:
+        query = self.dataset[item]["query"]
+        pos = random.choice(self.dataset[item]["pos"])
+        train_group_size = self.args.get("train_group_size", 8)
+        if len(self.dataset[item]["neg"]) < train_group_size - 1:
+            num = math.ceil((train_group_size - 1) / len(self.dataset[item]["neg"]))
+            negs = random.sample(self.dataset[item]["neg"] * num, train_group_size - 1)
+        else:
+            negs = random.sample(self.dataset[item]["neg"], train_group_size - 1)
+
+        batch_data = []
+        batch_data.append(self.create_one_example(query, pos))
+        for neg in negs:
+            batch_data.append(self.create_one_example(query, neg))
+
+        return batch_data
+
+
+@dataclass
+class GroupCollator(DataCollatorWithPadding):
+    def __call__(self, features) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+        if isinstance(features[0], list):
+            features = sum(features, [])
+        return super().__call__(features)
```
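
For each training example, `TrainDatasetForCE.__getitem__` returns `train_group_size` tokenized pairs: one (query, positive) pair plus `train_group_size - 1` (query, negative) pairs, over-sampling the negatives when fewer are available; `GroupCollator` then flattens these per-example lists into a single batch before padding. Below is a rough standalone sketch of just the sampling logic (string pairs instead of tokenized features):

```python
import math
import random


def sample_group(example: dict, train_group_size: int = 8):
    """Mirror of the negative-sampling logic in TrainDatasetForCE.__getitem__,
    returning (query, passage) string pairs instead of tokenized features."""
    query = example["query"]
    pos = random.choice(example["pos"])
    if len(example["neg"]) < train_group_size - 1:
        # Not enough negatives: tile the list so sampling without replacement works.
        num = math.ceil((train_group_size - 1) / len(example["neg"]))
        negs = random.sample(example["neg"] * num, train_group_size - 1)
    else:
        negs = random.sample(example["neg"], train_group_size - 1)
    # One positive pair followed by train_group_size - 1 negative pairs.
    return [(query, pos)] + [(query, neg) for neg in negs]


example = {"query": "q", "pos": ["p1"], "neg": ["n1", "n2", "n3"]}
group = sample_group(example)
print(len(group))  # 8: 1 positive + 7 negatives (negatives repeat when the pool is small)
```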

comps/finetuning/llm_on_ray/finetune/finetune.py

Lines changed: 89 additions & 57 deletions
```diff
@@ -22,11 +22,13 @@
 from ray.air import FailureConfig, RunConfig
 from ray.air.config import ScalingConfig
 from ray.train.torch import TorchTrainer
+from transformers import Trainer, TrainingArguments
 
 from comps import CustomLogger
 from comps.finetuning.finetune_config import FinetuneConfig
 from comps.finetuning.llm_on_ray import common
-from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor
+from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor, GroupCollator, TrainDatasetForCE
+from comps.finetuning.llm_on_ray.finetune.modeling import CrossEncoder
 
 logger = CustomLogger("llm_on_ray/finetune")
 
@@ -186,74 +188,106 @@ def local_load(name, **load_config):
 
 
 def tokenize_dataset(config: Dict, tokenizer, dataset):
-    group = config["Dataset"].get("group", True)
-    block_size = config["Dataset"].get("block_size", 512)
-    tokenizer.pad_token = tokenizer.eos_token
-
-    processor = DataProcessor(config, tokenizer)
-
-    for key in dataset:
-        prompts = processor.make_prompt(dataset[key])
-        dataset[key] = datasets.Dataset.from_dict(prompts)
-
-    column_names = list(dataset["train"].features)
-    tokenize_fn = (
-        processor.tokenize_by_neural_chat
-        if config["Dataset"].get("data_preprocess_type", "") == "neural_chat"
-        else processor.tokenize
-    )
-
-    tokenized_dataset = dataset.map(
-        tokenize_fn,
-        remove_columns=column_names,
-        batched=True,
-        load_from_cache_file=False,
-        desc="Tokenize dataset",
-    )
-
-    if group:
-
-        def group_texts(examples):
-            # Concatenate all texts.
-            concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
-            total_length = len(concatenated_examples[list(examples.keys())[0]])
-            # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-            # customize this part to your needs.
-            if total_length >= block_size:
-                total_length = (total_length // block_size) * block_size
-            # Split by chunks of max_len.
-            result = {
-                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
-                for k, t in concatenated_examples.items()
-            }
-            return result
+    task = config["General"].get("task", "instruction_tuning")
+    if task == "instruction_tuning":
+        group = config["Dataset"].get("group", True)
+        block_size = config["Dataset"].get("block_size", 512)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        processor = DataProcessor(config, tokenizer)
+
+        for key in dataset:
+            prompts = processor.make_prompt(dataset[key])
+            dataset[key] = datasets.Dataset.from_dict(prompts)
+
+        column_names = list(dataset["train"].features)
+        tokenize_fn = (
+            processor.tokenize_by_neural_chat
+            if config["Dataset"].get("data_preprocess_type", "") == "neural_chat"
+            else processor.tokenize
+        )
 
-        tokenized_dataset = tokenized_dataset.map(
-            group_texts,
+        tokenized_dataset = dataset.map(
+            tokenize_fn,
+            remove_columns=column_names,
             batched=True,
             load_from_cache_file=False,
-            desc=f"Grouping texts in chunks of {block_size}",
+            desc="Tokenize dataset",
         )
 
-    return tokenized_dataset
+        if group:
+
+            def group_texts(examples):
+                # Concatenate all texts.
+                concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+                total_length = len(concatenated_examples[list(examples.keys())[0]])
+                # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+                # customize this part to your needs.
+                if total_length >= block_size:
+                    total_length = (total_length // block_size) * block_size
+                # Split by chunks of max_len.
+                result = {
+                    k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+                    for k, t in concatenated_examples.items()
+                }
+                return result
+
+            tokenized_dataset = tokenized_dataset.map(
+                group_texts,
+                batched=True,
+                load_from_cache_file=False,
+                desc=f"Grouping texts in chunks of {block_size}",
+            )
+
+        return tokenized_dataset
+    elif task == "rerank":
+        dataset["train"] = TrainDatasetForCE(dataset["train"], config["Dataset"], tokenizer)
+        return dataset
+    elif task == "embedding":
+        pass
+    else:
+        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")
 
 
 def prepare_data_collator(config: Dict, tokenizer):
-    return transformers.DataCollatorForLanguageModeling(
-        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
-    )
+    task = config["General"].get("task", "instruction_tuning")
+    if task == "instruction_tuning":
+        return transformers.DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
+        )
+    elif task == "rerank":
+        return GroupCollator(tokenizer)
+    elif task == "embedding":
+        pass
+    else:
+        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")
 
 
 def load_model(config: Dict):
     model_name = config["General"]["base_model"]
     model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no"))
     model_config = config["General"].get("config", {})
-    model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config)
-
-    lora_config = config["General"].get("lora_config", None)
-    if lora_config:
-        peft_config = LoraConfig(**lora_config)
-        model = get_peft_model(model, peft_config)
+    task = config["General"].get("task", "instruction_tuning")
+    training_args = convert_to_training_args(TrainingArguments, config)
+    if task == "instruction_tuning":
+        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config)
+
+        lora_config = config["General"].get("lora_config", None)
+        if lora_config:
+            peft_config = LoraConfig(**lora_config)
+            model = get_peft_model(model, peft_config)
+    elif task == "rerank":
+        model = CrossEncoder.from_pretrained(
+            config["Dataset"],
+            training_args,
+            model_name,
+            from_tf=bool(".ckpt" in model_name),
+            config=model_config,
+        )
+    elif task == "embedding":
+        pass
+    else:
+        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")
 
     egc = config["General"].get("enable_gradient_checkpointing", False)
     if egc:
@@ -269,8 +303,6 @@ def load_model(config: Dict):
 def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator):
     device = config["Training"]["device"]
     if device in ["cpu", "gpu"]:
-        from transformers import Trainer, TrainingArguments
-
         training_args = convert_to_training_args(TrainingArguments, config)
         trainer = Trainer(
             model=model,
```
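
All three modified functions branch on the same `config["General"]["task"]` value, so a rerank job is driven by a handful of config keys: `General.task`, `General.lora_config`, and `Dataset.train_group_size` / `Dataset.max_length`. The dict below is a hypothetical, hand-written illustration of the shape these functions consume; only the keys touched by the new code paths are shown, and the values are examples rather than the service's actual resolved config:

```python
# Hypothetical config dict in the form read by tokenize_dataset,
# prepare_data_collator, and load_model for a rerank job. Real jobs
# carry many more fields (see finetune_config.py for the defaults).
rerank_config = {
    "General": {
        "base_model": "BAAI/bge-reranker-large",
        "task": "rerank",       # selects TrainDatasetForCE + GroupCollator + CrossEncoder
        "lora_config": None,    # LoRA is not applied on the rerank path
        "config": {},           # extra kwargs forwarded to from_pretrained
    },
    "Dataset": {
        "train_file": "toy_finetune_data.jsonl",
        "train_group_size": 8,  # 1 positive + 7 sampled negatives per query
        "max_length": 512,      # truncation length used by TrainDatasetForCE
    },
    "Training": {
        "device": "cpu",
        "mixed_precision": "no",
    },
}

task = rerank_config["General"].get("task", "instruction_tuning")
assert task in ["instruction_tuning", "rerank", "embedding"]
```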
