Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions python/ray/llm/_internal/batch/stages/prepare_image_stage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Prepare Image Stage"""

import asyncio
import base64
import importlib
Expand Down Expand Up @@ -311,13 +312,17 @@ def __init__(self, data_column: str, expected_input_keys: List[str]):
self.image_processor = ImageProcessor()

def extract_image_info(self, messages: List[Dict]) -> List[_ImageType]:
"""Extract vision information such as image and video from chat messages.
"""Extract image information from chat messages.
Args:
messages: List of chat messages.
Returns:
List of _ImageType.
Note:
The optional 'detail' parameter from the OpenAI schema is not
passed forward to downstream templates.
"""

image_info: List[_ImageType] = []
Expand All @@ -336,7 +341,20 @@ def extract_image_info(self, messages: List[Dict]) -> List[_ImageType]:
for content_item in content:
if content_item["type"] not in ("image", "image_url"):
continue
image = content_item[content_item["type"]]

image_data = content_item[content_item["type"]]

if content_item["type"] == "image_url" and isinstance(image_data, dict):
# OpenAI nested format: {"image_url": {"url": "..."}}
image = image_data.get("url")
if not isinstance(image, str) or not image:
raise ValueError(
"image_url must be an object with a non-empty 'url' string"
)
else:
# Simple format: {"image": "..."} or {"image_url": "..."}
image = image_data

if not isinstance(image, str) and not isinstance(
image, self.Image.Image
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@ async def test_prepare_image_udf_invalid_image_type(mock_image_processor):
["https://example.com/image.jpg"],
"image_url_format_no_system_prompt",
),
# Test OpenAI nested format without system prompt
# https://github.com/openai/openai-openapi/blob/manual_spec/openapi.yaml#L1937-L1940
(
[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.jpg"},
},
{"type": "text", "text": "Describe this image"},
],
}
],
["https://example.com/image.jpg"],
"openai_image_url_format_no_system_prompt",
),
],
ids=lambda x: x if isinstance(x, str) else None,
)
Expand All @@ -262,5 +280,37 @@ def test_extract_image_info(messages, expected_images, test_description):
assert image_info == expected_images


@pytest.mark.parametrize(
"image_url_value,test_description",
[
({}, "missing_url"),
({"url": 12345}, "non_string_url"),
({"url": ""}, "empty_string_url"),
],
ids=lambda x: x if isinstance(x, str) else None,
)
def test_extract_image_info_invalid_nested_image_url(image_url_value, test_description):
"""Test that invalid nested image_url objects raise ValueError with proper message."""
udf = PrepareImageUDF(data_column="__data", expected_input_keys=["messages"])

messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": image_url_value,
},
{"type": "text", "text": "Describe this image"},
],
}
]

with pytest.raises(
ValueError, match="image_url must be an object with a non-empty 'url' string"
):
udf.extract_image_info(messages)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))