
Commit a924579

lkk12014402 authored
remove finetuning models limitation. (#573)
* remove finetuning models limitation.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add ut.
* update ut and add dashboard.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update ut port.
* update finetuning params for customization.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* change name.

Co-authored-by: root <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 445c9b1 · commit a924579

File tree

10 files changed: +171, -150 lines


comps/finetuning/docker/Dockerfile_cpu

Lines changed: 3 additions & 2 deletions

@@ -32,7 +32,8 @@ WORKDIR /home/user/comps/finetuning
 RUN echo PKGPATH=$(python3 -c "import pkg_resources; print(pkg_resources.get_distribution('oneccl-bind-pt').location)") >> run.sh && \
     echo 'export LD_LIBRARY_PATH=$PKGPATH/oneccl_bindings_for_pytorch/opt/mpi/lib/:$LD_LIBRARY_PATH' >> run.sh && \
     echo 'source $PKGPATH/oneccl_bindings_for_pytorch/env/setvars.sh' >> run.sh && \
-    echo ray start --head >> run.sh && \
+    echo ray start --head --dashboard-host=0.0.0.0 >> run.sh && \
+    echo export RAY_ADDRESS=http://localhost:8265 >> run.sh && \
     echo python finetuning_service.py >> run.sh
 
-CMD bash run.sh
+CMD bash run.sh
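
Note: the new RAY_ADDRESS export is what lets the microservice reach the Ray head, since the handlers talk to it through ray.job_submission.JobSubmissionClient. A minimal connectivity sketch, assuming the environment set up by run.sh above (resolving the address from RAY_ADDRESS is standard Ray behavior, not something this commit adds):

    # Minimal sketch: confirm the container can reach the Ray head started by run.sh.
    # Assumes RAY_ADDRESS=http://localhost:8265 is exported, as in the Dockerfile above.
    from ray.job_submission import JobSubmissionClient

    client = JobSubmissionClient()  # no explicit address: Ray falls back to RAY_ADDRESS
    print(client.list_jobs())       # an empty list on a freshly started head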

comps/finetuning/llm_on_ray/finetune/finetune_config.py renamed to comps/finetuning/finetune_config.py

Lines changed: 40 additions & 30 deletions

@@ -7,6 +7,8 @@
 
 from pydantic import BaseModel, validator
 
+from comps.cores.proto.api_protocol import FineTuningJobsRequest
+
 PRECISION_BF16 = "bf16"
 PRECISION_FP16 = "fp16"
 PRECISION_NO = "no"
@@ -20,30 +22,31 @@
 ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED"
 
 
-class GeneralConfig(BaseModel):
-    trust_remote_code: bool
-    use_auth_token: Optional[str]
+class LoadConfig(BaseModel):
+    trust_remote_code: bool = False
+    # set Huggingface token to access dataset/model
+    token: Optional[str] = None
 
 
 class LoraConfig(BaseModel):
-    task_type: str
-    r: int
-    lora_alpha: int
-    lora_dropout: float
+    task_type: str = "CAUSAL_LM"
+    r: int = 8
+    lora_alpha: int = 32
+    lora_dropout: float = 0.1
     target_modules: Optional[List[str]] = None
 
 
-class General(BaseModel):
-    base_model: str
+class GeneralConfig(BaseModel):
+    base_model: str = None
     tokenizer_name: Optional[str] = None
     gaudi_config_name: Optional[str] = None
-    gpt_base_model: bool
-    output_dir: str
+    gpt_base_model: bool = False
+    output_dir: str = "./tmp"
     report_to: str = "none"
     resume_from_checkpoint: Optional[str] = None
     save_strategy: str = "no"
-    config: GeneralConfig
-    lora_config: Optional[LoraConfig] = None
+    config: LoadConfig = LoadConfig()
+    lora_config: Optional[LoraConfig] = LoraConfig()
     enable_gradient_checkpointing: bool = False
 
     @validator("report_to")
@@ -52,10 +55,10 @@ def check_report_to(cls, v: str):
         return v
 
 
-class Dataset(BaseModel):
-    train_file: str
-    validation_file: Optional[str]
-    validation_split_percentage: int
+class DatasetConfig(BaseModel):
+    train_file: str = None
+    validation_file: Optional[str] = None
+    validation_split_percentage: int = 5
     max_length: int = 512
     group: bool = True
     block_size: int = 512
@@ -74,23 +77,23 @@ class Dataset(BaseModel):
 
 
 class RayResourceConfig(BaseModel):
-    CPU: int
+    CPU: int = 32
     GPU: int = 0
     HPU: int = 0
 
 
-class Training(BaseModel):
-    optimizer: str
-    batch_size: int
-    epochs: int
+class TrainingConfig(BaseModel):
+    optimizer: str = "adamw_torch"
+    batch_size: int = 2
+    epochs: int = 1
     max_train_steps: Optional[int] = None
-    learning_rate: float
-    lr_scheduler: str
-    weight_decay: float
+    learning_rate: float = 5.0e-5
+    lr_scheduler: str = "linear"
+    weight_decay: float = 0.0
     device: str = DEVICE_CPU
     hpu_execution_mode: str = "lazy"
-    num_training_workers: int
-    resources_per_worker: RayResourceConfig
+    num_training_workers: int = 1
+    resources_per_worker: RayResourceConfig = RayResourceConfig()
     accelerate_mode: str = ACCELERATE_STRATEGY_DDP
     mixed_precision: str = PRECISION_NO
     gradient_accumulation_steps: int = 1
@@ -151,6 +154,13 @@ def check_logging_steps(cls, v: int):
 
 
 class FinetuneConfig(BaseModel):
-    General: General
-    Dataset: Dataset
-    Training: Training
+    General: GeneralConfig = GeneralConfig()
+    Dataset: DatasetConfig = DatasetConfig()
+    Training: TrainingConfig = TrainingConfig()
+
+
+class FineTuningParams(FineTuningJobsRequest):
+    # priority use FineTuningJobsRequest params
+    General: GeneralConfig = GeneralConfig()
+    Dataset: DatasetConfig = DatasetConfig()
+    Training: TrainingConfig = TrainingConfig()
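
Since every section now carries defaults, a usable configuration can be built in code (or parsed from a minimal YAML) without the per-model files deleted below. A small sketch using only names from this diff; the model id is one of the previously allow-listed examples:

    # Sketch: all sections default, so only fields that differ need to be set.
    from comps.finetuning.finetune_config import FinetuneConfig, GeneralConfig

    config = FinetuneConfig(General=GeneralConfig(base_model="meta-llama/Llama-2-7b-chat-hf"))
    print(config.Training.optimizer)                   # "adamw_torch"
    print(config.Dataset.validation_split_percentage)  # 5
    print(config.General.lora_config.r)                # 8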

comps/finetuning/finetune_runner.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 from pydantic_yaml import parse_yaml_raw_as
 from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
 
-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
+from comps.finetuning.finetune_config import FinetuneConfig
 
 
 class FineTuneCallback(TrainerCallback):

comps/finetuning/finetuning_service.py

Lines changed: 3 additions & 2 deletions

@@ -8,7 +8,8 @@
 from fastapi import BackgroundTasks, File, UploadFile
 
 from comps import opea_microservices, register_microservice
-from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest
+from comps.cores.proto.api_protocol import FineTuningJobIDRequest
+from comps.finetuning.finetune_config import FineTuningParams
 from comps.finetuning.handlers import (
     DATASET_BASE_PATH,
     handle_cancel_finetuning_job,
@@ -21,7 +22,7 @@
 
 
 @register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8015)
-def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+def create_finetuning_jobs(request: FineTuningParams, background_tasks: BackgroundTasks):
     return handle_create_finetuning_jobs(request, background_tasks)
 
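
Because the endpoint now accepts FineTuningParams, a request can combine the standard fine-tuning fields with optional General/Dataset/Training overrides. A hypothetical client sketch (the training file name and the override values are illustrative; port 8015 comes from the decorator above):

    # Hypothetical payload: "model" and "training_file" come from FineTuningJobsRequest,
    # while "Training" overrides the new TrainingConfig defaults.
    import requests

    payload = {
        "model": "meta-llama/Llama-2-7b-chat-hf",  # any HF model id; no allow-list anymore
        "training_file": "alpaca_data.json",       # must already be uploaded under datasets/
        "Training": {"epochs": 2, "batch_size": 4},
    }
    resp = requests.post("http://localhost:8015/v1/fine_tuning/jobs", json=payload)
    print(resp.json())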

comps/finetuning/handlers.py

Lines changed: 16 additions & 25 deletions

@@ -13,28 +13,21 @@
 from ray.job_submission import JobSubmissionClient
 
 from comps import CustomLogger
-from comps.cores.proto.api_protocol import (
-    FineTuningJob,
-    FineTuningJobIDRequest,
-    FineTuningJobList,
-    FineTuningJobsRequest,
-)
-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
+from comps.cores.proto.api_protocol import FineTuningJob, FineTuningJobIDRequest, FineTuningJobList
+from comps.finetuning.finetune_config import FinetuneConfig, FineTuningParams
 
 logger = CustomLogger("finetuning_handlers")
 
-MODEL_CONFIG_FILE_MAP = {
-    "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
-    "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
-}
-
 DATASET_BASE_PATH = "datasets"
 JOBS_PATH = "jobs"
+OUTPUT_DIR = "output"
+
 if not os.path.exists(DATASET_BASE_PATH):
     os.mkdir(DATASET_BASE_PATH)
-
 if not os.path.exists(JOBS_PATH):
     os.mkdir(JOBS_PATH)
+if not os.path.exists(OUTPUT_DIR):
+    os.mkdir(OUTPUT_DIR)
 
 FineTuningJobID = str
 CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
@@ -62,23 +55,17 @@ def update_job_status(job_id: FineTuningJobID):
         time.sleep(CHECK_JOB_STATUS_INTERVAL)
 
 
-def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+def handle_create_finetuning_jobs(request: FineTuningParams, background_tasks: BackgroundTasks):
     base_model = request.model
     train_file = request.training_file
     train_file_path = os.path.join(DATASET_BASE_PATH, train_file)
 
-    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
-    if not model_config_file:
-        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
-
     if not os.path.exists(train_file_path):
         raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")
 
-    with open(model_config_file) as f:
-        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
-
+    finetune_config = FinetuneConfig(General=request.General, Dataset=request.Dataset, Training=request.Training)
+    finetune_config.General.base_model = base_model
     finetune_config.Dataset.train_file = train_file_path
-
     if request.hyperparameters is not None:
         if request.hyperparameters.epochs != "auto":
             finetune_config.Training.epochs = request.hyperparameters.epochs
@@ -90,7 +77,7 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
             finetune_config.Training.learning_rate = request.hyperparameters.learning_rate_multiplier
 
     if os.getenv("HF_TOKEN", None):
-        finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None)
+        finetune_config.General.config.token = os.getenv("HF_TOKEN", None)
 
     job = FineTuningJob(
         id=f"ft-job-{uuid.uuid4()}",
@@ -105,12 +92,16 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
         status="running",
         seed=random.randint(0, 1000) if request.seed is None else request.seed,
     )
-    finetune_config.General.output_dir = os.path.join(JOBS_PATH, job.id)
+    finetune_config.General.output_dir = os.path.join(OUTPUT_DIR, job.id)
     if os.getenv("DEVICE", ""):
         logger.info(f"specific device: {os.getenv('DEVICE')}")
 
         finetune_config.Training.device = os.getenv("DEVICE")
+        if finetune_config.Training.device == "hpu":
+            if finetune_config.Training.resources_per_worker.HPU == 0:
+                # set 1
+                finetune_config.Training.resources_per_worker.HPU = 1
 
     finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
     to_yaml_file(finetune_config_file, finetune_config)
@@ -122,7 +113,7 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
         # Entrypoint shell command to execute
         entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
         # Path to the local directory that contains the script.py file
-        runtime_env={"working_dir": "./"},
+        runtime_env={"working_dir": "./", "excludes": [f"{OUTPUT_DIR}"]},
     )
 
     logger.info(f"Submitted Ray job: {ray_job_id} ...")

comps/finetuning/launch.sh

Lines changed: 3 additions & 3 deletions

@@ -2,11 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
 if [[ -n "$RAY_PORT" ]];then
-    ray start --head --port $RAY_PORT
+    ray start --head --port $RAY_PORT --dashboard-host=0.0.0.0
 else
-    ray start --head
+    ray start --head --dashboard-host=0.0.0.0
     export RAY_PORT=8265
 fi
 
-export RAY_ADDRESS=http://127.0.0.1:$RAY_PORT
+export RAY_ADDRESS=http://localhost:$RAY_PORT
 python finetuning_service.py

comps/finetuning/llm_on_ray/finetune/finetune.py

Lines changed: 3 additions & 3 deletions

@@ -24,9 +24,9 @@
 from ray.train.torch import TorchTrainer
 
 from comps import CustomLogger
+from comps.finetuning.finetune_config import FinetuneConfig
 from comps.finetuning.llm_on_ray import common
 from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor
-from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
 
 logger = CustomLogger("llm_on_ray/finetune")
 
@@ -171,8 +171,8 @@ def local_load(name, **load_config):
     else:
         # try to download and load dataset from huggingface.co
         load_config = config["General"].get("config", {})
-        use_auth_token = load_config.get("use_auth_token", None)
-        raw_dataset = datasets.load_dataset(dataset_file, use_auth_token=use_auth_token)
+        use_auth_token = load_config.get("token", None)
+        raw_dataset = datasets.load_dataset(dataset_file, token=use_auth_token)
 
     validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0)
     if "validation" not in raw_dataset.keys() and (

comps/finetuning/models/llama-2-7b-chat-hf.yaml

Lines changed: 0 additions & 39 deletions
This file was deleted.

comps/finetuning/models/mistral-7b-v0.1.yaml

Lines changed: 0 additions & 45 deletions
This file was deleted.
