
Commit c5b7363

Add streaming support and custom GenerationStopper support for ApiVlmModel

Signed-off-by: Christoph Auer <[email protected]>
1 parent: 1c781a1

3 files changed: +147 −16 lines

docling/datamodel/pipeline_options_vlm_model.py

Lines changed: 3 additions & 0 deletions
@@ -104,3 +104,6 @@ class ApiVlmOptions(BaseVlmOptions):
     timeout: float = 60
     concurrency: int = 1
     response_format: ResponseFormat
+
+    stop_strings: List[str] = []
+    custom_stopping_criteria: List[Union[GenerationStopper]] = []
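
Not part of the diff: a minimal sketch of what a custom stopper for the new custom_stopping_criteria field could look like. It assumes only what the streaming helper further down actually calls on a GenerationStopper, namely should_stop() and lookback_tokens(); the class name and the heuristic are invented for illustration.

from docling.models.utils.generation_utils import GenerationStopper


class LoopMarkerStopper(GenerationStopper):  # hypothetical example class
    """Abort generation once a degenerate repetition marker shows up in the partial output."""

    def should_stop(self, text: str) -> bool:
        # 'text' is the trailing window of the streamed output (see api_image_request_streaming).
        return text.count("<page_break>") > 3  # invented heuristic

    def lookback_tokens(self) -> int:
        return 500  # size of the trailing window handed to should_stop()


# Either an instance or the bare class may be listed; api_vlm_model.py (next file)
# normalizes both forms before streaming, e.g.:
# vlm_options.custom_stopping_criteria = [LoopMarkerStopper()]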

docling/models/api_vlm_model.py

Lines changed: 45 additions & 15 deletions
@@ -1,12 +1,18 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor

+from transformers import StoppingCriteria
+
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
 from docling.models.base_model import BasePageModel
-from docling.utils.api_image_request import api_image_request
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
 from docling.utils.profiling import TimeRecorder


@@ -41,19 +47,43 @@ def _vlm_request(page):
             assert page._backend is not None
             if not page._backend.is_valid():
                 return page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None

-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
-                    )
-                    assert hi_res_image is not None
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
+            with TimeRecorder(conv_res, "vlm"):
+                assert page.size is not None
+
+                hi_res_image = page.get_image(
+                    scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                )
+                assert hi_res_image is not None
+                if hi_res_image and hi_res_image.mode != "RGB":
+                    hi_res_image = hi_res_image.convert("RGB")

-                    prompt = self.vlm_options.build_prompt(page.parsed_page)
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
+
+                if self.vlm_options.custom_stopping_criteria:
+                    # Instantiate any GenerationStopper classes before passing to streaming
+                    instantiated_stoppers = []
+                    for criteria in self.vlm_options.custom_stopping_criteria:
+                        if isinstance(criteria, GenerationStopper):
+                            instantiated_stoppers.append(criteria)
+                        elif isinstance(criteria, type) and issubclass(
+                            criteria, GenerationStopper
+                        ):
+                            instantiated_stoppers.append(criteria())
+                        # Skip non-GenerationStopper criteria (should have been caught in validation)
+
+                    # Streaming path with early abort support
+                    page_tags = api_image_request_streaming(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        generation_stoppers=instantiated_stoppers,
+                        **self.params,
+                    )
+                else:
+                    # Non-streaming fallback (existing behavior)
                     page_tags = api_image_request(
                         image=hi_res_image,
                         prompt=prompt,
@@ -63,10 +93,10 @@ def _vlm_request(page):
                         **self.params,
                     )

-                    page_tags = self.vlm_options.decode_response(page_tags)
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                page_tags = self.vlm_options.decode_response(page_tags)
+                page.predictions.vlm_response = VlmPrediction(text=page_tags)

-                    return page
+                return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_vlm_request, page_batch)
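
As a side note (not in the diff): tracing the normalization loop above shows that custom_stopping_criteria accepts both already-constructed stoppers and bare GenerationStopper subclasses, and that both end up as instances. A small self-contained illustration, using a made-up NoopStopper:

from docling.models.utils.generation_utils import GenerationStopper


class NoopStopper(GenerationStopper):  # made-up subclass used only for illustration
    def should_stop(self, text: str) -> bool:
        return False

    def lookback_tokens(self) -> int:
        return 1


criteria_in = [NoopStopper(), NoopStopper]  # an instance and a class, both accepted
instantiated = []
for criteria in criteria_in:
    if isinstance(criteria, GenerationStopper):
        instantiated.append(criteria)      # instances are kept as-is
    elif isinstance(criteria, type) and issubclass(criteria, GenerationStopper):
        instantiated.append(criteria())    # classes are instantiated with no arguments

assert all(isinstance(s, GenerationStopper) for s in instantiated)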

docling/utils/api_image_request.py

Lines changed: 99 additions & 1 deletion
@@ -1,13 +1,15 @@
 import base64
+import json
 import logging
 from io import BytesIO
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import requests
 from PIL import Image
 from pydantic import AnyUrl

 from docling.datamodel.base_models import OpenAiApiResponse
+from docling.models.utils.generation_utils import GenerationStopper

 _log = logging.getLogger(__name__)

@@ -59,3 +61,99 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     return generated_text
+
+
+def api_image_request_streaming(
+    image: Image.Image,
+    prompt: str,
+    url: AnyUrl,
+    *,
+    timeout: float = 20,
+    headers: Optional[Dict[str, str]] = None,
+    generation_stoppers: List[GenerationStopper] = [],
+    **params,
+) -> str:
+    """
+    Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
+    Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
+    Accumulates text and calls stopper.should_stop(window) as partials arrive.
+    If a stopper triggers, the HTTP connection is closed to abort server-side generation.
+    """
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
+                },
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        "stream": True,  # critical for SSE streaming
+        **params,
+    }
+
+    # Some servers require Accept: text/event-stream for SSE.
+    # It's safe to set; OpenAI-compatible servers tolerate it.
+    hdrs = {"Accept": "text/event-stream", **(headers or {})}
+
+    # Stream the HTTP response
+    with requests.post(
+        str(url), headers=hdrs, json=payload, timeout=timeout, stream=True
+    ) as r:
+        if not r.ok:
+            _log.error(f"Error calling the API (streaming). Response was {r.text}")
+        r.raise_for_status()
+
+        full_text = []
+        for raw_line in r.iter_lines(decode_unicode=True):
+            if not raw_line:  # keep-alives / blank lines
+                continue
+            if not raw_line.startswith("data:"):
+                # Some proxies inject comments; ignore anything not starting with 'data:'
+                continue
+
+            data = raw_line[len("data:") :].strip()
+            if data == "[DONE]":
+                break
+
+            try:
+                obj = json.loads(data)
+            except json.JSONDecodeError:
+                _log.info("Skipping non-JSON SSE chunk: %r", data[:200])
+                continue
+
+            # OpenAI-compatible delta format:
+            # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
+            try:
+                delta = obj["choices"][0].get("delta") or {}
+                piece = delta.get("content") or ""
+            except (KeyError, IndexError) as e:
+                _log.debug("Unexpected SSE chunk shape: %s", e)
+                piece = ""
+
+            if piece:
+                full_text.append(piece)
+                for stopper in generation_stoppers:
+                    # Respect the stopper's lookback window. A simple string window
+                    # works for regex-based stoppers; no tokenizer needed.
+                    lookback = max(1, stopper.lookback_tokens())
+                    window = "".join(full_text)[-lookback:]
+                    if stopper.should_stop(window):
+                        # Closing the socket signals cancellation to vLLM/OpenAI-compatible
+                        # servers, which abort the request when the client disconnects.
+                        try:
+                            r.close()
+                        finally:
+                            break
+
+    return "".join(full_text)
