 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import StandardLoggingPayload
+from unittest.mock import patch, MagicMock, AsyncMock


 vertex_finetune_api = VertexFineTuningAPI()

@@ -237,29 +238,75 @@ async def test_azure_create_fine_tune_jobs_async():


 @pytest.mark.asyncio()
-@pytest.mark.skip(reason="skipping until we can cancel fine tuning jobs")
-async def test_create_vertex_fine_tune_jobs():
-    try:
-        verbose_logger.setLevel(logging.DEBUG)
-        load_vertex_ai_credentials()
+async def test_create_vertex_fine_tune_jobs_mocked():
+    load_vertex_ai_credentials()
+    # Define reusable variables for the test
+    project_id = "633608382793"
+    location = "us-central1"
+    job_id = "3978211980451250176"
+    base_model = "gemini-1.0-pro-002"
+    tuned_model_name = f"{base_model}-f9259f2c-3fdf-4dd3-9413-afef2bfd24f5"
+    training_file = (
+        "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+    )
+    create_time = "2024-12-31T22:40:20.211140Z"
+
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    mock_response.json = MagicMock(
+        return_value={
+            "name": f"projects/{project_id}/locations/{location}/tuningJobs/{job_id}",
+            "tunedModelDisplayName": tuned_model_name,
+            "baseModel": base_model,
+            "supervisedTuningSpec": {"trainingDatasetUri": training_file},
+            "state": "JOB_STATE_PENDING",
+            "createTime": create_time,
+            "updateTime": create_time,
+        }
+    )
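+    # Mocked payload mirrors Vertex AI's tuningJobs response shape (camelCase keys,
+    # e.g. trainingDatasetUri), unlike the snake_case request body asserted below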

-        vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
-        print("creating fine tuning job")
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
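+        # Every AsyncHTTPHandler.post call in this block returns mock_response,
+        # so no real Vertex AI request is made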
         create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
-            model="gemini-1.0-pro-002",
+            model=base_model,
             custom_llm_provider="vertex_ai",
-            training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
-            vertex_project="adroit-crow-413218",
-            vertex_location="us-central1",
-            vertex_credentials=vertex_credentials,
+            training_file=training_file,
+            vertex_project=project_id,
+            vertex_location=location,
         )
-        print("vertex ai create fine tuning response=", create_fine_tuning_response)

-        assert create_fine_tuning_response.id is not None
-        assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
-        assert create_fine_tuning_response.object == "fine_tuning.job"
-    except Exception:
-        pass
+        # Verify exactly one request was sent
+        mock_post.assert_called_once()
+
+        # Validate the JSON body sent to Vertex AI
+        assert mock_post.call_args.kwargs["json"] == {
+            "baseModel": base_model,
+            "supervisedTuningSpec": {"training_dataset_uri": training_file},
+            "tunedModelDisplayName": None,
+        }
+
+        # Verify the OpenAI-format response
+        response_json = json.loads(create_fine_tuning_response.model_dump_json())
+        assert (
+            response_json["id"]
+            == f"projects/{project_id}/locations/{location}/tuningJobs/{job_id}"
+        )
+        assert response_json["model"] == base_model
+        assert response_json["object"] == "fine_tuning.job"
+        assert response_json["fine_tuned_model"] == tuned_model_name
+        assert response_json["status"] == "queued"
+        assert response_json["training_file"] == training_file
+        assert (
+            response_json["created_at"] == 1735684820
+        )  # Unix timestamp for create_time
+        assert response_json["error"] is None
+        assert response_json["finished_at"] is None
+        assert response_json["validation_file"] is None
+        assert response_json["trained_tokens"] is None
+        assert response_json["estimated_finish"] is None
+        assert response_json["integrations"] == []


 # Testing OpenAI -> Vertex AI param mapping
@@ -276,7 +323,7 @@ def test_convert_openai_request_to_vertex_basic():

     result = vertex_finetune_api.convert_openai_request_to_vertex(openai_data)

-    print("converted vertex ai result=", result)
+    print("converted vertex ai result=", json.dumps(result, indent=4))

     assert result["baseModel"] == "text-davinci-002"
     assert result["tunedModelDisplayName"] == "my_fine_tuned_model"
@@ -303,15 +350,57 @@ def test_convert_openai_request_to_vertex_with_adapter_size():
         openai_data, adapter_size="SMALL"
     )

-    print("converted vertex ai result=", result)
+    print("converted vertex ai result=", json.dumps(result, indent=4))

     assert result["baseModel"] == "text-davinci-002"
     assert result["tunedModelDisplayName"] == "custom_model"
     assert (
         result["supervisedTuningSpec"]["training_dataset_uri"]
         == "gs://bucket/train.jsonl"
     )
-    assert result["supervisedTuningSpec"]["validation_dataset"] is None
     assert result["supervisedTuningSpec"]["epoch_count"] == 5
     assert result["supervisedTuningSpec"]["learning_rate_multiplier"] == 0.2
     assert result["supervisedTuningSpec"]["adapter_size"] == "SMALL"
+
+
+def test_convert_basic_openai_request_to_vertex_request():
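+    # Minimal request: only required fields set, so optional params
+    # should fall back to their defaults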
+    openai_data = FineTuningJobCreate(
+        training_file="gs://bucket/train.jsonl",
+        model="gemini-1.0-pro-002",
+    )
+
+    result = vertex_finetune_api.convert_openai_request_to_vertex(
+        openai_data, adapter_size="SMALL"
+    )
+
+    print("converted vertex ai result=", json.dumps(result, indent=4))
+
+    assert result["baseModel"] == "gemini-1.0-pro-002"
+    assert result["tunedModelDisplayName"] is None
+    assert (
+        result["supervisedTuningSpec"]["training_dataset_uri"]
+        == "gs://bucket/train.jsonl"
+    )
+
+
+@pytest.mark.asyncio()
+@pytest.mark.skip(reason="skipping - we run mock tests for vertex ai")
+async def test_create_vertex_fine_tune_jobs():
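+    # Live end-to-end test; requires real GCP credentials via GCS_PATH_SERVICE_ACCOUNT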
+    verbose_logger.setLevel(logging.DEBUG)
+    # load_vertex_ai_credentials()
+
+    vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
+    print("creating fine tuning job")
+    create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
+        model="gemini-1.0-pro-002",
+        custom_llm_provider="vertex_ai",
+        training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
+        vertex_project="adroit-crow-413218",
+        vertex_location="us-central1",
+        vertex_credentials=vertex_credentials,
+    )
+    print("vertex ai create fine tuning response=", create_fine_tuning_response)
+
+    assert create_fine_tuning_response.id is not None
+    assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
+    assert create_fine_tuning_response.object == "fine_tuning.job"