from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor

+from transformers import StoppingCriteria
+
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BasePageModel
-from docling.utils.api_image_request import api_image_request
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
from docling.utils.profiling import TimeRecorder

@@ -41,19 +47,43 @@ def _vlm_request(page):
            assert page._backend is not None
            if not page._backend.is_valid():
                return page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None

-                    hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
-                    )
-                    assert hi_res_image is not None
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
+            with TimeRecorder(conv_res, "vlm"):
+                assert page.size is not None
+
+                hi_res_image = page.get_image(
+                    scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                )
+                assert hi_res_image is not None
+                if hi_res_image and hi_res_image.mode != "RGB":
+                    hi_res_image = hi_res_image.convert("RGB")

-                    prompt = self.vlm_options.build_prompt(page.parsed_page)
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
+
+                if self.vlm_options.custom_stopping_criteria:
+                    # Instantiate any GenerationStopper classes before passing to streaming
+                    instantiated_stoppers = []
+                    for criteria in self.vlm_options.custom_stopping_criteria:
+                        if isinstance(criteria, GenerationStopper):
+                            instantiated_stoppers.append(criteria)
+                        elif isinstance(criteria, type) and issubclass(
+                            criteria, GenerationStopper
+                        ):
+                            instantiated_stoppers.append(criteria())
+                        # Skip non-GenerationStopper criteria (should have been caught in validation)
+
+                    # Streaming path with early abort support
+                    page_tags = api_image_request_streaming(
+                        image=hi_res_image,
+                        prompt=prompt,
+                        url=self.vlm_options.url,
+                        timeout=self.timeout,
+                        headers=self.vlm_options.headers,
+                        generation_stoppers=instantiated_stoppers,
+                        **self.params,
+                    )
+                else:
+                    # Non-streaming fallback (existing behavior)
                page_tags = api_image_request(
                    image=hi_res_image,
                    prompt=prompt,
@@ -63,10 +93,10 @@ def _vlm_request(page):
                    **self.params,
                )

-                    page_tags = self.vlm_options.decode_response(page_tags)
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                page_tags = self.vlm_options.decode_response(page_tags)
+                page.predictions.vlm_response = VlmPrediction(text=page_tags)

-                return page
+            return page

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            yield from executor.map(_vlm_request, page_batch)
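
For orientation (not part of the diff itself): a minimal sketch of how a caller might feed a stopper into the streaming path added above. `GenerationStopper`, `ApiVlmOptions`, and the `custom_stopping_criteria` field come straight from the diff; the `should_stop_generation` hook name, the endpoint URL, and the constructor fields shown are assumptions made for illustration, so check `docling/models/utils/generation_utils.py` for the actual abstract interface before relying on them.

```python
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
from docling.models.utils.generation_utils import GenerationStopper


class DegenerateOutputStopper(GenerationStopper):
    # Assumed hook name; the real abstract method(s) are defined in
    # docling/models/utils/generation_utils.py.
    def should_stop_generation(self, generated_text: str) -> bool:
        # Abort the remote generation once the stream tail degenerates into
        # the same token repeated 20 times in a row.
        tail = generated_text.split()[-20:]
        return len(tail) == 20 and len(set(tail)) == 1


# The instantiation loop in the diff accepts either an instance or a class,
# so both forms below would be handled. Constructor fields are illustrative.
options = ApiVlmOptions(
    url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
    prompt="Convert this page to DocTags.",
    custom_stopping_criteria=[DegenerateOutputStopper],
    # custom_stopping_criteria=[DegenerateOutputStopper()],  # instance also works
)
```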