Skip to content

Commit 859f6e1

Browse files
authored
(fix) v1/fine_tuning/jobs with VertexAI (#7487)
* update convert_openai_request_to_vertex
* test_create_vertex_fine_tune_jobs_mocked
1 parent b3d4ee9 commit 859f6e1

File tree

3 files changed

+134
-30
lines changed

3 files changed

+134
-30
lines changed

litellm/llms/vertex_ai/fine_tuning/handler.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import traceback
23
from datetime import datetime
34
from typing import Literal, Optional, Union
@@ -67,14 +68,14 @@ def convert_vertex_response_to_open_ai_response(
6768
training_uri = response["supervisedTuningSpec"]["trainingDatasetUri"] or ""
6869

6970
return FineTuningJob(
70-
id=response["name"] or "",
71+
id=response.get("name", "") or "",
7172
created_at=created_at,
72-
fine_tuned_model=response["tunedModelDisplayName"],
73+
fine_tuned_model=response.get("tunedModelDisplayName", ""),
7374
finished_at=None,
7475
hyperparameters=Hyperparameters(
7576
n_epochs=0,
7677
),
77-
model=response["baseModel"] or "",
78+
model=response.get("baseModel", "") or "",
7879
object="fine_tuning.job",
7980
organization_id="",
8081
result_files=[],
@@ -95,12 +96,20 @@ def convert_openai_request_to_vertex(
9596
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
9697
supervised_tuning_spec = FineTunesupervisedTuningSpec(
9798
"""
98-
hyperparameters = create_fine_tuning_job_data.hyperparameters
99+
99100
supervised_tuning_spec = FineTunesupervisedTuningSpec(
100101
training_dataset_uri=create_fine_tuning_job_data.training_file,
101-
validation_dataset=create_fine_tuning_job_data.validation_file,
102102
)
103103

104+
if create_fine_tuning_job_data.validation_file:
105+
supervised_tuning_spec["validation_dataset"] = (
106+
create_fine_tuning_job_data.validation_file
107+
)
108+
109+
if kwargs.get("adapter_size"):
110+
supervised_tuning_spec["adapter_size"] = kwargs.get("adapter_size")
111+
112+
hyperparameters = create_fine_tuning_job_data.hyperparameters
104113
if hyperparameters:
105114
if hyperparameters.n_epochs:
106115
supervised_tuning_spec["epoch_count"] = int(hyperparameters.n_epochs)
@@ -109,8 +118,6 @@ def convert_openai_request_to_vertex(
109118
hyperparameters.learning_rate_multiplier
110119
)
111120

112-
supervised_tuning_spec["adapter_size"] = kwargs.get("adapter_size")
113-
114121
fine_tune_job = FineTuneJobCreate(
115122
baseModel=create_fine_tuning_job_data.model,
116123
supervisedTuningSpec=supervised_tuning_spec,
@@ -130,7 +137,7 @@ async def acreate_fine_tuning_job(
130137
verbose_logger.debug(
131138
"about to create fine tuning job: %s, request_data: %s",
132139
fine_tuning_url,
133-
request_data,
140+
json.dumps(request_data, indent=4),
134141
)
135142
if self.async_handler is None:
136143
raise ValueError(

litellm/proxy/proxy_config.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,12 @@ model_list:
88
model: "openai/*"
99
api_key: os.environ/OPENAI_API_KEY
1010
litellm_settings:
11-
callbacks: ["datadog"]
11+
callbacks: ["datadog"]
12+
13+
14+
# For /fine_tuning/jobs endpoints
15+
finetune_settings:
16+
- custom_llm_provider: "vertex_ai"
17+
vertex_project: "adroit-crow-413218"
18+
vertex_location: "us-central1"
19+
vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"

tests/batches_tests/test_fine_tuning_api.py

Lines changed: 110 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from litellm.integrations.custom_logger import CustomLogger
2828
from litellm.types.utils import StandardLoggingPayload
29+
from unittest.mock import patch, MagicMock, AsyncMock
2930

3031
vertex_finetune_api = VertexFineTuningAPI()
3132

@@ -237,29 +238,75 @@ async def test_azure_create_fine_tune_jobs_async():
237238

238239

239240
@pytest.mark.asyncio()
240-
@pytest.mark.skip(reason="skipping until we can cancel fine tuning jobs")
241-
async def test_create_vertex_fine_tune_jobs():
242-
try:
243-
verbose_logger.setLevel(logging.DEBUG)
244-
load_vertex_ai_credentials()
241+
async def test_create_vertex_fine_tune_jobs_mocked():
242+
load_vertex_ai_credentials()
243+
# Define reusable variables for the test
244+
project_id = "633608382793"
245+
location = "us-central1"
246+
job_id = "3978211980451250176"
247+
base_model = "gemini-1.0-pro-002"
248+
tuned_model_name = f"{base_model}-f9259f2c-3fdf-4dd3-9413-afef2bfd24f5"
249+
training_file = (
250+
"gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
251+
)
252+
create_time = "2024-12-31T22:40:20.211140Z"
253+
254+
mock_response = AsyncMock()
255+
mock_response.status_code = 200
256+
mock_response.json = MagicMock(
257+
return_value={
258+
"name": f"projects/{project_id}/locations/{location}/tuningJobs/{job_id}",
259+
"tunedModelDisplayName": tuned_model_name,
260+
"baseModel": base_model,
261+
"supervisedTuningSpec": {"trainingDatasetUri": training_file},
262+
"state": "JOB_STATE_PENDING",
263+
"createTime": create_time,
264+
"updateTime": create_time,
265+
}
266+
)
245267

246-
vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
247-
print("creating fine tuning job")
268+
with patch(
269+
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
270+
return_value=mock_response,
271+
) as mock_post:
248272
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
249-
model="gemini-1.0-pro-002",
273+
model=base_model,
250274
custom_llm_provider="vertex_ai",
251-
training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
252-
vertex_project="adroit-crow-413218",
253-
vertex_location="us-central1",
254-
vertex_credentials=vertex_credentials,
275+
training_file=training_file,
276+
vertex_project=project_id,
277+
vertex_location=location,
255278
)
256-
print("vertex ai create fine tuning response=", create_fine_tuning_response)
257279

258-
assert create_fine_tuning_response.id is not None
259-
assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
260-
assert create_fine_tuning_response.object == "fine_tuning.job"
261-
except Exception:
262-
pass
280+
# Verify the request
281+
mock_post.assert_called_once()
282+
283+
# Validate the request
284+
assert mock_post.call_args.kwargs["json"] == {
285+
"baseModel": base_model,
286+
"supervisedTuningSpec": {"training_dataset_uri": training_file},
287+
"tunedModelDisplayName": None,
288+
}
289+
290+
# Verify the response
291+
response_json = json.loads(create_fine_tuning_response.model_dump_json())
292+
assert (
293+
response_json["id"]
294+
== f"projects/{project_id}/locations/{location}/tuningJobs/{job_id}"
295+
)
296+
assert response_json["model"] == base_model
297+
assert response_json["object"] == "fine_tuning.job"
298+
assert response_json["fine_tuned_model"] == tuned_model_name
299+
assert response_json["status"] == "queued"
300+
assert response_json["training_file"] == training_file
301+
assert (
302+
response_json["created_at"] == 1735684820
303+
) # Unix timestamp for create_time
304+
assert response_json["error"] is None
305+
assert response_json["finished_at"] is None
306+
assert response_json["validation_file"] is None
307+
assert response_json["trained_tokens"] is None
308+
assert response_json["estimated_finish"] is None
309+
assert response_json["integrations"] == []
263310

264311

265312
# Testing OpenAI -> Vertex AI param mapping
@@ -276,7 +323,7 @@ def test_convert_openai_request_to_vertex_basic():
276323

277324
result = vertex_finetune_api.convert_openai_request_to_vertex(openai_data)
278325

279-
print("converted vertex ai result=", result)
326+
print("converted vertex ai result=", json.dumps(result, indent=4))
280327

281328
assert result["baseModel"] == "text-davinci-002"
282329
assert result["tunedModelDisplayName"] == "my_fine_tuned_model"
@@ -303,15 +350,57 @@ def test_convert_openai_request_to_vertex_with_adapter_size():
303350
openai_data, adapter_size="SMALL"
304351
)
305352

306-
print("converted vertex ai result=", result)
353+
print("converted vertex ai result=", json.dumps(result, indent=4))
307354

308355
assert result["baseModel"] == "text-davinci-002"
309356
assert result["tunedModelDisplayName"] == "custom_model"
310357
assert (
311358
result["supervisedTuningSpec"]["training_dataset_uri"]
312359
== "gs://bucket/train.jsonl"
313360
)
314-
assert result["supervisedTuningSpec"]["validation_dataset"] is None
315361
assert result["supervisedTuningSpec"]["epoch_count"] == 5
316362
assert result["supervisedTuningSpec"]["learning_rate_multiplier"] == 0.2
317363
assert result["supervisedTuningSpec"]["adapter_size"] == "SMALL"
364+
365+
366+
def test_convert_basic_openai_request_to_vertex_request():
367+
openai_data = FineTuningJobCreate(
368+
training_file="gs://bucket/train.jsonl",
369+
model="gemini-1.0-pro-002",
370+
)
371+
372+
result = vertex_finetune_api.convert_openai_request_to_vertex(
373+
openai_data, adapter_size="SMALL"
374+
)
375+
376+
print("converted vertex ai result=", json.dumps(result, indent=4))
377+
378+
assert result["baseModel"] == "gemini-1.0-pro-002"
379+
assert result["tunedModelDisplayName"] == None
380+
assert (
381+
result["supervisedTuningSpec"]["training_dataset_uri"]
382+
== "gs://bucket/train.jsonl"
383+
)
384+
385+
386+
@pytest.mark.asyncio()
387+
@pytest.mark.skip(reason="skipping - we run mock tests for vertex ai")
388+
async def test_create_vertex_fine_tune_jobs():
389+
verbose_logger.setLevel(logging.DEBUG)
390+
# load_vertex_ai_credentials()
391+
392+
vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
393+
print("creating fine tuning job")
394+
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
395+
model="gemini-1.0-pro-002",
396+
custom_llm_provider="vertex_ai",
397+
training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
398+
vertex_project="adroit-crow-413218",
399+
vertex_location="us-central1",
400+
vertex_credentials=vertex_credentials,
401+
)
402+
print("vertex ai create fine tuning response=", create_fine_tuning_response)
403+
404+
assert create_fine_tuning_response.id is not None
405+
assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
406+
assert create_fine_tuning_response.object == "fine_tuning.job"

0 commit comments

Comments (0)