
Commit 7e1a2e5

enable embedding finetuning (#639)
* add embedding finetuning.
* enhance training on HPU.
* add embedding training unit test.
1 parent 99be1bd commit 7e1a2e5

6 files changed: +470 additions, -22 deletions

comps/finetuning/README.md

Lines changed: 45 additions & 1 deletion
@@ -99,6 +99,8 @@ For reranking and embedding models finetuning, the training file [toy_finetune_d

## 3.2 Create fine-tuning job

+### 3.2.1 Instruction Tuning
+
After a training file like `alpaca_data.json` is uploaded, use the following command to launch a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model:

```bash
@@ -112,6 +114,8 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
  }'
```

+### 3.2.2 Reranking Model Training
+
Use the following command to launch a finetuning job for reranking model finetuning, such as `BAAI/bge-reranker-large`:

```bash
@@ -129,6 +133,46 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
  }'
```

+### 3.2.3 Embedding Model Training
+
+Use the following command to launch a finetuning job for embedding model finetuning, such as `BAAI/bge-base-en-v1.5`:
+
+```bash
+# create a finetuning job
+curl http://${your_ip}:8015/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "toy_finetune_data.jsonl",
+    "model": "BAAI/bge-base-en-v1.5",
+    "General":{
+      "task":"embedding",
+      "lora_config":null
+    }
+  }'
+
+
+# If training on Gaudi2, set --padding "max_length" and make --query_max_len equal to --passage_max_len so that shapes stay static during training. For example:
+curl http://${your_ip}:8015/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "toy_finetune_data.jsonl",
+    "model": "BAAI/bge-base-en-v1.5",
+    "General":{
+      "task":"embedding",
+      "lora_config":null
+    },
+    "Dataset":{
+      "query_max_len":128,
+      "passage_max_len":128,
+      "padding":"max_length"
+    }
+  }'
+
+
+```
+
## 3.3 Manage fine-tuning job

Below commands show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job and list checkpoints of a finetuning job.
@@ -149,4 +193,4 @@ curl http://${your_ip}:8015/v1/finetune/list_checkpoints -X POST -H "Content-Typ

## 🚀4. Descriptions for Finetuning parameters

-We utilize [OpenAI finetuning parameters](https://platform.openai.com/docs/api-reference/fine-tuning) and extend it with more customizable parameters.
+We utilize [OpenAI finetuning parameters](https://platform.openai.com/docs/api-reference/fine-tuning) and extend them with more customizable parameters; see the definitions at [finetune_config](https://github.com/opea-project/GenAIComps/blob/main/comps/finetuning/finetune_config.py).
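For reference, the embedding job above can also be submitted from Python instead of curl. This is a minimal sketch only; it assumes the finetuning microservice from this README is already running, and the host placeholder is illustrative:

```python
# Minimal sketch: submit the embedding finetuning job shown in the README via Python.
# Assumes the finetuning service is reachable at the same host/port as the curl examples.
import requests

payload = {
    "training_file": "toy_finetune_data.jsonl",
    "model": "BAAI/bge-base-en-v1.5",
    "General": {"task": "embedding", "lora_config": None},
    # On Gaudi2/HPU, keep query_max_len == passage_max_len and padding="max_length"
    # so tensor shapes stay static during training.
    "Dataset": {"query_max_len": 128, "passage_max_len": 128, "padding": "max_length"},
}

resp = requests.post("http://<your_ip>:8015/v1/fine_tuning/jobs", json=payload, timeout=30)
print(resp.status_code, resp.json())
```

The JSON body mirrors the second curl example; the `Dataset` section is only needed for the Gaudi2 case.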

comps/finetuning/finetune_config.py

Lines changed: 27 additions & 2 deletions
@@ -5,7 +5,7 @@

from typing import List, Optional, Union

-from pydantic import BaseModel, validator
+from pydantic import BaseModel, Field, validator

from comps.cores.proto.api_protocol import FineTuningJobsRequest

@@ -74,13 +74,29 @@ class DatasetConfig(BaseModel):
    truncation_side: str = "right"
    max_seq_length: int = 512
    truncation: bool = True
-    padding: bool = True
+    padding: Union[bool, str] = True
    mask_input: bool = True
    mask_response: bool = True
    data_preprocess_type: str = "neural_chat"
    max_train_samples: int = 0
    max_eval_samples: int = 0
    train_group_size: int = 8
+    query_max_len: int = Field(
+        default=128,
+        description=(
+            "The maximum total input sequence length after tokenization for query. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        ),
+    )
+    passage_max_len: int = Field(
+        default=128,
+        description=(
+            "The maximum total input sequence length after tokenization for passage. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        ),
+    )
+    query_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for query")
+    passage_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for passage")


class RayResourceConfig(BaseModel):
@@ -89,6 +105,14 @@ class RayResourceConfig(BaseModel):
    HPU: int = 0


+class EmbeddingTrainingConfig(BaseModel):
+    negatives_cross_device: bool = Field(default=False, description="share negatives across devices")
+    temperature: Optional[float] = Field(default=0.02)
+    sentence_pooling_method: str = Field(default="cls", description="the pooling method, should be cls or mean")
+    normalized: bool = Field(default=True)
+    use_inbatch_neg: bool = Field(default=True, description="use passages in the same batch as negatives")
+
+
class TrainingConfig(BaseModel):
    optimizer: str = "adamw_torch"
    batch_size: int = 2
@@ -106,6 +130,7 @@ class TrainingConfig(BaseModel):
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    deepspeed_config_file: str = ""
+    embedding_training_config: Optional[EmbeddingTrainingConfig] = EmbeddingTrainingConfig()

    @validator("device")
    def check_device(cls, v: str):
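As a quick illustration of the new config surface (a sketch only; it assumes the module path shown in this repository), every embedding-specific field added above carries a default, so a bare `EmbeddingTrainingConfig` validates and individual options can be overridden:

```python
# Sketch: exercising the EmbeddingTrainingConfig added in this commit.
# Assumes comps.finetuning.finetune_config is importable as laid out in this repository.
from comps.finetuning.finetune_config import EmbeddingTrainingConfig

# Every field has a default, so a bare instance is valid.
cfg = EmbeddingTrainingConfig()
assert cfg.sentence_pooling_method == "cls"
assert cfg.use_inbatch_neg is True

# Override the contrastive temperature and pooling method explicitly.
cfg = EmbeddingTrainingConfig(temperature=0.05, sentence_pooling_method="mean", negatives_cross_device=True)
print(cfg.dict())
```

`TrainingConfig` instantiates this class by default, which is why `load_model` can read `config["Training"]["embedding_training_config"]` without further checks.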

comps/finetuning/llm_on_ray/finetune/data_process.py

Lines changed: 71 additions & 0 deletions
@@ -246,3 +246,74 @@ def __call__(self, features) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.T
        if isinstance(features[0], list):
            features = sum(features, [])
        return super().__call__(features)
+
+
+class TrainDatasetForEmbedding(Dataset):
+    def __init__(self, dataset, args, tokenizer):
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.args = args
+        self.total_len = len(self.dataset)
+
+    def __len__(self):
+        return self.total_len
+
+    def __getitem__(self, item) -> Tuple[str, List[str]]:
+        query = self.dataset[item]["query"]
+        if self.args["query_instruction_for_retrieval"] is not None:
+            query = self.args["query_instruction_for_retrieval"] + query
+
+        passages = []
+
+        assert isinstance(self.dataset[item]["pos"], list)
+        pos = random.choice(self.dataset[item]["pos"])
+        passages.append(pos)
+
+        train_group_size = self.args.get("train_group_size", 8)
+        if len(self.dataset[item]["neg"]) < train_group_size - 1:
+            num = math.ceil((train_group_size - 1) / len(self.dataset[item]["neg"]))
+            negs = random.sample(self.dataset[item]["neg"] * num, train_group_size - 1)
+        else:
+            negs = random.sample(self.dataset[item]["neg"], train_group_size - 1)
+        passages.extend(negs)
+
+        if self.args["passage_instruction_for_retrieval"] is not None:
+            passages = [self.args["passage_instruction_for_retrieval"] + p for p in passages]
+        return query, passages
+
+
+@dataclass
+class EmbedCollator(DataCollatorWithPadding):
+    """Wrapper that converts List[Tuple[encode_qry, encode_psg]] into List[qry], List[psg]
+    and passes each batch separately to the actual collator.
+
+    Abstracts data detail away from the model.
+    """
+
+    query_max_len: int = 32
+    passage_max_len: int = 128
+
+    def __call__(self, features):
+        query = [f[0] for f in features]
+        passage = [f[1] for f in features]
+
+        if isinstance(query[0], list):
+            query = sum(query, [])
+        if isinstance(passage[0], list):
+            passage = sum(passage, [])
+
+        q_collated = self.tokenizer(
+            query,
+            padding=self.padding,
+            truncation=True,
+            max_length=self.query_max_len,
+            return_tensors="pt",
+        )
+        d_collated = self.tokenizer(
+            passage,
+            padding=self.padding,
+            truncation=True,
+            max_length=self.passage_max_len,
+            return_tensors="pt",
+        )
+        return {"query": q_collated, "passage": d_collated}

comps/finetuning/llm_on_ray/finetune/finetune.py

Lines changed: 33 additions & 9 deletions
@@ -27,8 +27,14 @@
from comps import CustomLogger
from comps.finetuning.finetune_config import FinetuneConfig
from comps.finetuning.llm_on_ray import common
-from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor, GroupCollator, TrainDatasetForCE
-from comps.finetuning.llm_on_ray.finetune.modeling import CrossEncoder
+from comps.finetuning.llm_on_ray.finetune.data_process import (
+    DataProcessor,
+    EmbedCollator,
+    GroupCollator,
+    TrainDatasetForCE,
+    TrainDatasetForEmbedding,
+)
+from comps.finetuning.llm_on_ray.finetune.modeling import BiEncoderModel, CrossEncoder

logger = CustomLogger("llm_on_ray/finetune")

@@ -244,7 +250,8 @@ def group_texts(examples):
        dataset["train"] = TrainDatasetForCE(dataset["train"], config["Dataset"], tokenizer)
        return dataset
    elif task == "embedding":
-        pass
+        dataset["train"] = TrainDatasetForEmbedding(dataset["train"], config["Dataset"], tokenizer)
+        return dataset
    else:
        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")

@@ -258,7 +265,12 @@ def prepare_data_collator(config: Dict, tokenizer):
    elif task == "rerank":
        return GroupCollator(tokenizer)
    elif task == "embedding":
-        pass
+        return EmbedCollator(
+            tokenizer=tokenizer,
+            padding=config["Dataset"]["padding"],
+            query_max_len=config["Dataset"]["query_max_len"],
+            passage_max_len=config["Dataset"]["passage_max_len"],
+        )
    else:
        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")

@@ -268,24 +280,36 @@ def load_model(config: Dict):
    model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no"))
    model_config = config["General"].get("config", {})
    task = config["General"].get("task", "instruction_tuning")
-    training_args = convert_to_training_args(TrainingArguments, config)
    if task == "instruction_tuning":
        model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config)
-
        lora_config = config["General"].get("lora_config", None)
        if lora_config:
            peft_config = LoraConfig(**lora_config)
            model = get_peft_model(model, peft_config)
    elif task == "rerank":
        model = CrossEncoder.from_pretrained(
-            config["Dataset"],
-            training_args,
+            config["Dataset"].get("train_group_size", 8),
+            config["Training"]["batch_size"],
            model_name,
            from_tf=bool(".ckpt" in model_name),
            config=model_config,
        )
    elif task == "embedding":
-        pass
+        should_concat = False
+        if (
+            config["Dataset"]["query_max_len"] == config["Dataset"]["passage_max_len"]
+            and config["Dataset"]["padding"] == "max_length"
+        ):
+            should_concat = True
+        if config["Training"]["device"] == "hpu" and not should_concat:
+            raise ValueError("please set query_max_len==passage_max_len and padding='max_length' for hpu.")
+
+        if config["Training"].get("embedding_training_config", None) is not None:
+            model = BiEncoderModel(
+                model_name=model_name, should_concat=should_concat, **config["Training"]["embedding_training_config"]
+            )
+        else:
+            model = BiEncoderModel(model_name=model_name, should_concat=should_concat)
    else:
        raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.")
