Skip to content

Commit 67223d1

Browse files
GuyStone, nrghosh
authored and committed
[Data][LLM] Support OpenAI's nested image_url schema in PrepareImageStage (ray-project#56584)
Signed-off-by: Guy Stone <guys@spotify.com>
Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
Co-authored-by: Nikhil Ghosh <nikhil@anyscale.com>
Co-authored-by: Nikhil G <nrghosh@users.noreply.github.com>
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
1 parent b41dd8b commit 67223d1

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

python/ray/llm/_internal/batch/stages/prepare_image_stage.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Prepare Image Stage"""
2+
23
import asyncio
34
import base64
45
import importlib
@@ -311,13 +312,17 @@ def __init__(self, data_column: str, expected_input_keys: List[str]):
311312
self.image_processor = ImageProcessor()
312313

313314
def extract_image_info(self, messages: List[Dict]) -> List[_ImageType]:
314-
"""Extract vision information such as image and video from chat messages.
315+
"""Extract image information from chat messages.
315316
316317
Args:
317318
messages: List of chat messages.
318319
319320
Returns:
320321
List of _ImageType.
322+
323+
Note:
324+
The optional 'detail' parameter from the OpenAI schema is not
325+
passed forward to downstream templates.
321326
"""
322327

323328
image_info: List[_ImageType] = []
@@ -336,7 +341,20 @@ def extract_image_info(self, messages: List[Dict]) -> List[_ImageType]:
336341
for content_item in content:
337342
if content_item["type"] not in ("image", "image_url"):
338343
continue
339-
image = content_item[content_item["type"]]
344+
345+
image_data = content_item[content_item["type"]]
346+
347+
if content_item["type"] == "image_url" and isinstance(image_data, dict):
348+
# OpenAI nested format: {"image_url": {"url": "..."}}
349+
image = image_data.get("url")
350+
if not isinstance(image, str) or not image:
351+
raise ValueError(
352+
"image_url must be an object with a non-empty 'url' string"
353+
)
354+
else:
355+
# Simple format: {"image": "..."} or {"image_url": "..."}
356+
image = image_data
357+
340358
if not isinstance(image, str) and not isinstance(
341359
image, self.Image.Image
342360
):

python/ray/llm/tests/batch/cpu/stages/test_prepare_image_stage.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -250,6 +250,24 @@ async def test_prepare_image_udf_invalid_image_type(mock_image_processor):
250250
["https://example.com/image.jpg"],
251251
"image_url_format_no_system_prompt",
252252
),
253+
# Test OpenAI nested format without system prompt
254+
# https://github.com/openai/openai-openapi/blob/manual_spec/openapi.yaml#L1937-L1940
255+
(
256+
[
257+
{
258+
"role": "user",
259+
"content": [
260+
{
261+
"type": "image_url",
262+
"image_url": {"url": "https://example.com/image.jpg"},
263+
},
264+
{"type": "text", "text": "Describe this image"},
265+
],
266+
}
267+
],
268+
["https://example.com/image.jpg"],
269+
"openai_image_url_format_no_system_prompt",
270+
),
253271
],
254272
ids=lambda x: x if isinstance(x, str) else None,
255273
)
@@ -262,5 +280,37 @@ def test_extract_image_info(messages, expected_images, test_description):
262280
assert image_info == expected_images
263281

264282

283+
@pytest.mark.parametrize(
284+
"image_url_value,test_description",
285+
[
286+
({}, "missing_url"),
287+
({"url": 12345}, "non_string_url"),
288+
({"url": ""}, "empty_string_url"),
289+
],
290+
ids=lambda x: x if isinstance(x, str) else None,
291+
)
292+
def test_extract_image_info_invalid_nested_image_url(image_url_value, test_description):
293+
"""Test that invalid nested image_url objects raise ValueError with proper message."""
294+
udf = PrepareImageUDF(data_column="__data", expected_input_keys=["messages"])
295+
296+
messages = [
297+
{
298+
"role": "user",
299+
"content": [
300+
{
301+
"type": "image_url",
302+
"image_url": image_url_value,
303+
},
304+
{"type": "text", "text": "Describe this image"},
305+
],
306+
}
307+
]
308+
309+
with pytest.raises(
310+
ValueError, match="image_url must be an object with a non-empty 'url' string"
311+
):
312+
udf.extract_image_info(messages)
313+
314+
265315
if __name__ == "__main__":
266316
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)