# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import random
import re
import time
import urllib.parse
import uuid
from pathlib import Path
from typing import Dict, Optional

from fastapi import BackgroundTasks, File, Form, HTTPException, UploadFile
from pydantic_yaml import to_yaml_file
from ray.job_submission import JobSubmissionClient

from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry
from comps.cores.proto.api_protocol import (
    FileObject,
    FineTuningJob,
    FineTuningJobCheckpoint,
    FineTuningJobIDRequest,
    FineTuningJobList,
    UploadFileRequest,
)
from comps.finetuning.src.integrations.finetune_config import FinetuneConfig, FineTuningParams

logger = CustomLogger("opea")

DATASET_BASE_PATH = "datasets"
JOBS_PATH = "jobs"
OUTPUT_DIR = "output"

if not os.path.exists(DATASET_BASE_PATH):
    os.mkdir(DATASET_BASE_PATH)
if not os.path.exists(JOBS_PATH):
    os.mkdir(JOBS_PATH)
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

FineTuningJobID = str
CheckpointID = str
CheckpointPath = str

CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs

ray_client: Optional[JobSubmissionClient] = None
| 49 | + |
| 50 | +running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {} |
| 51 | +finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {} |
| 52 | +checkpoint_id_to_checkpoint_path: Dict[CheckpointID, CheckpointPath] = {} |
| 53 | + |
| 54 | + |
| 55 | +# Add a background task to periodicly update job status |
| 56 | +def update_job_status(job_id: FineTuningJobID): |
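    """Poll the Ray job status for the given fine-tuning job until it reaches a terminal state.

    Ray's "stopped" state is reported as the OpenAI-style "cancelled" status.
    """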
    while True:
        job_status = ray_client.get_job_status(finetuning_job_to_ray_job[job_id])
        status = str(job_status).lower()
        # Ray status "stopped" is OpenAI status "cancelled"
        status = "cancelled" if status == "stopped" else status
        logger.info(f"Status of job {job_id} is '{status}'")
        running_finetuning_jobs[job_id].status = status
        if status in ("succeeded", "cancelled", "failed"):
            break
        time.sleep(CHECK_JOB_STATUS_INTERVAL)


async def save_content_to_local_disk(save_path: str, content):
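    """Persist either a plain string or an uploaded file object to the given local path."""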
    save_path = Path(save_path)
    try:
        if isinstance(content, str):
            with open(save_path, "w", encoding="utf-8") as file:
                file.write(content)
        else:
            with save_path.open("wb") as fout:
                content = await content.read()
                fout.write(content)
    except Exception as e:
        logger.info(f"Write file failed. Exception: {e}")
        raise HTTPException(status_code=500, detail=f"Write file {save_path} failed. Exception: {e}")


async def upload_file(purpose: str = Form(...), file: UploadFile = File(...)):
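    # Wraps the multipart form fields into an UploadFileRequest (used as a FastAPI dependency).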
    return UploadFileRequest(purpose=purpose, file=file)

@OpeaComponentRegistry.register("XTUNE_FINETUNING")
class XtuneFinetuning(OpeaComponent):
    """A specialized finetuning component derived from OpeaComponent for finetuning services."""

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, "finetuning", description, config)

    def create_finetuning_jobs(self, request: FineTuningParams, background_tasks: BackgroundTasks):
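        """Create a fine-tuning job and submit the corresponding Ray job.

        The fine-tuning configuration is written to a YAML file under JOBS_PATH, the Ray
        entrypoint is chosen based on the configured tool and device, and a background task
        is started to track the job status.
        """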
        model = request.model
        train_file = request.training_file
        finetune_config = FinetuneConfig(General=request.General)
        flag = 1 if finetune_config.General.xtune_config.device == "XPU" else 0
        if os.getenv("HF_TOKEN", None):
            finetune_config.General.config.token = os.getenv("HF_TOKEN", None)

        job = FineTuningJob(
            id=f"ft-job-{uuid.uuid4()}",
            model=model,
            created_at=int(time.time()),
            training_file=train_file,
            hyperparameters={},
            status="running",
            seed=random.randint(0, 1000) if request.seed is None else request.seed,
        )

        finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
        to_yaml_file(finetune_config_file, finetune_config)

        global ray_client
        ray_client = JobSubmissionClient() if ray_client is None else ray_client
        if finetune_config.General.xtune_config.tool == "clip":
            ray_job_id = ray_client.submit_job(
                # Entrypoint shell command to execute
                entrypoint=f"cd integrations/xtune/src/llamafactory/clip_finetune && export DATA={finetune_config.General.xtune_config.dataset_root} && bash scripts/clip_finetune/{finetune_config.General.xtune_config.trainer}.sh {finetune_config.General.xtune_config.dataset} {finetune_config.General.xtune_config.model} 0 {finetune_config.General.xtune_config.device} > /tmp/test.log 2>&1 || true",
            )

        else:
            if flag == 1:
                ray_job_id = ray_client.submit_job(
                    # Entrypoint shell command to execute
                    entrypoint=f"cd integrations/xtune/src/llamafactory/adaclip_finetune && python train.py --config {finetune_config.General.xtune_config.config_file} --frames_dir {finetune_config.General.xtune_config.dataset_root}{finetune_config.General.xtune_config.dataset}/frames --top_k 16 --freeze_cnn --frame_agg mlp --resume {finetune_config.General.xtune_config.model} --xpu --batch_size 8 > /tmp/test.log 2>&1 || true",
                )
            else:
                ray_job_id = ray_client.submit_job(
                    # Entrypoint shell command to execute
                    entrypoint=f"cd integrations/xtune/src/llamafactory/adaclip_finetune && python train.py --config {finetune_config.General.xtune_config.config_file} --frames_dir {finetune_config.General.xtune_config.dataset_root}{finetune_config.General.xtune_config.dataset}/frames --top_k 16 --freeze_cnn --frame_agg mlp --resume {finetune_config.General.xtune_config.model} --batch_size 8 > /tmp/test.log 2>&1 || true",
                )

        logger.info(f"Submitted Ray job: {ray_job_id} ...")

        running_finetuning_jobs[job.id] = job
        finetuning_job_to_ray_job[job.id] = ray_job_id

        background_tasks.add_task(update_job_status, job.id)

        return job

    def list_finetuning_jobs(self):
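        """Return the list of all fine-tuning jobs tracked by this service."""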
        finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)

        return finetuning_jobs_list

    def retrieve_finetuning_job(self, request: FineTuningJobIDRequest):
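        """Look up a fine-tuning job by ID, raising 404 if it is unknown."""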
        fine_tuning_job_id = request.fine_tuning_job_id

        job = running_finetuning_jobs.get(fine_tuning_job_id)
        if job is None:
            raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
        return job

    def cancel_finetuning_job(self, request: FineTuningJobIDRequest):
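        """Stop the underlying Ray job and mark the fine-tuning job as cancelled."""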
        fine_tuning_job_id = request.fine_tuning_job_id

        ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
        if ray_job_id is None:
            raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")

        global ray_client
        ray_client = JobSubmissionClient() if ray_client is None else ray_client
        ray_client.stop_job(ray_job_id)

        job = running_finetuning_jobs.get(fine_tuning_job_id)
        job.status = "cancelled"
        return job

    def list_finetuning_checkpoints(self, request: FineTuningJobIDRequest):
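        """List checkpoint directories produced under OUTPUT_DIR for the given fine-tuning job."""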
        fine_tuning_job_id = request.fine_tuning_job_id

        job = running_finetuning_jobs.get(fine_tuning_job_id)
        if job is None:
            raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
        output_dir = os.path.join(OUTPUT_DIR, job.id)
        checkpoints = []
        if os.path.exists(output_dir):
            # Iterate over the contents of the directory and add an entry for each
            files = os.listdir(output_dir)
            for file in files:  # Loop over directory contents
                file_path = os.path.join(output_dir, file)
                if os.path.isdir(file_path) and file.startswith("checkpoint"):
                    steps = re.findall(r"\d+", file)[0]
                    checkpointsResponse = FineTuningJobCheckpoint(
                        id=f"ftckpt-{uuid.uuid4()}",  # Generate a unique ID
                        created_at=int(time.time()),  # Use the current timestamp
                        fine_tuned_model_checkpoint=file_path,  # Directory path itself
                        fine_tuning_job_id=fine_tuning_job_id,
                        object="fine_tuning.job.checkpoint",
                        step_number=int(steps),
                    )
                    checkpoints.append(checkpointsResponse)
            if job.status == "succeeded":
                checkpointsResponse = FineTuningJobCheckpoint(
                    id=f"ftckpt-{uuid.uuid4()}",  # Generate a unique ID
                    created_at=int(time.time()),  # Use the current timestamp
                    fine_tuned_model_checkpoint=output_dir,  # Directory path itself
                    fine_tuning_job_id=fine_tuning_job_id,
                    object="fine_tuning.job.checkpoint",
                )
                checkpoints.append(checkpointsResponse)

        return checkpoints

    async def upload_training_files(self, request: UploadFileRequest):
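        """Save an uploaded training file under DATASET_BASE_PATH and return its FileObject metadata."""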
        file = request.file
        if file is None:
            raise HTTPException(status_code=404, detail="upload file failed!")
        filename = urllib.parse.quote(file.filename, safe="")
        save_path = os.path.join(DATASET_BASE_PATH, filename)
        await save_content_to_local_disk(save_path, file)

        fileBytes = os.path.getsize(save_path)
        fileInfo = FileObject(
            id=f"file-{uuid.uuid4()}",
            object="file",
            bytes=fileBytes,
            created_at=int(time.time()),
            filename=filename,
            purpose="fine-tune",
        )

        return fileInfo

    def invoke(self, *args, **kwargs):
        pass

    def check_health(self) -> bool:
        """Checks the health of the component.

        Returns:
            bool: True if the component is healthy, False otherwise.
        """
        return True