1
1
import asyncio
2
2
import copy
3
3
from asyncio import Event , Task
4
- from typing import Callable , Dict , List , Optional , TYPE_CHECKING
4
+ from typing import Callable , Dict , List , Optional , TYPE_CHECKING , Union
5
5
from jina ._docarray import docarray_v2
6
6
import contextlib
7
+
7
8
if not docarray_v2 :
8
9
from docarray import DocumentArray
9
10
else :
@@ -18,16 +19,18 @@ class BatchQueue:
18
19
"""A batch queue that holds the data request and the callable to batch requests to."""
19
20
20
21
def __init__ (
21
- self ,
22
- func : Callable ,
23
- request_docarray_cls ,
24
- response_docarray_cls ,
25
- output_array_type : Optional [str ] = None ,
26
- params : Optional [Dict ] = None ,
27
- allow_concurrent : bool = False ,
28
- flush_all : bool = False ,
29
- preferred_batch_size : int = 4 ,
30
- timeout : int = 10_000 ,
22
+ self ,
23
+ func : Callable ,
24
+ request_docarray_cls ,
25
+ response_docarray_cls ,
26
+ output_array_type : Optional [str ] = None ,
27
+ params : Optional [Dict ] = None ,
28
+ allow_concurrent : bool = False ,
29
+ flush_all : bool = False ,
30
+ preferred_batch_size : int = 4 ,
31
+ timeout : int = 10_000 ,
32
+ custom_metric : Optional [Callable [['DocumentArray' ], Union [int , float ]]] = None ,
33
+ use_custom_metric : bool = False ,
31
34
) -> None :
32
35
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
33
36
if allow_concurrent and flush_all :
@@ -44,6 +47,8 @@ def __init__(
44
47
self ._response_docarray_cls = response_docarray_cls
45
48
self ._flush_all = flush_all
46
49
self ._preferred_batch_size : int = preferred_batch_size
50
+ self ._custom_metric = None if not use_custom_metric else custom_metric
51
+ self ._metric_value = 0
47
52
self ._timeout : int = timeout
48
53
self ._reset ()
49
54
self ._flush_trigger : Event = Event ()
@@ -62,20 +67,22 @@ def _reset(self) -> None:
62
67
# a list of every request ID
63
68
self ._request_idxs : List [int ] = []
64
69
self ._request_lens : List [int ] = []
70
+ self ._docs_metrics : List [int ] = []
65
71
self ._requests_completed : List [asyncio .Queue ] = []
66
72
if not docarray_v2 :
67
73
self ._big_doc : DocumentArray = DocumentArray .empty ()
68
74
else :
69
75
self ._big_doc = self ._request_docarray_cls ()
76
+ self ._metric_value = 0
70
77
71
78
self ._flush_task : Optional [Task ] = None
72
79
self ._flush_trigger : Event = Event ()
73
80
74
81
def _cancel_timer_if_pending (self ):
75
82
if (
76
- self ._timer_task
77
- and not self ._timer_task .done ()
78
- and not self ._timer_task .cancelled ()
83
+ self ._timer_task
84
+ and not self ._timer_task .done ()
85
+ and not self ._timer_task .cancelled ()
79
86
):
80
87
self ._timer_finished = False
81
88
self ._timer_task .cancel ()
@@ -91,7 +98,7 @@ async def _sleep_then_set(self):
91
98
self ._flush_trigger .set ()
92
99
self ._timer_finished = True
93
100
94
- async def push (self , request : DataRequest , http = False ) -> asyncio .Queue :
101
+ async def push (self , request : DataRequest , http = False ) -> asyncio .Queue :
95
102
"""Append request to the list of requests to be processed.
96
103
97
104
This method creates an asyncio Queue for that request and keeps track of it. It returns
@@ -116,12 +123,18 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
116
123
self ._big_doc .extend (docs )
117
124
next_req_idx = len (self ._requests )
118
125
num_docs = len (docs )
126
+ metric_value = num_docs
127
+ if self ._custom_metric is not None :
128
+ metrics = [self ._custom_metric (doc ) for doc in docs ]
129
+ metric_value += sum (metrics )
130
+ self ._docs_metrics .extend (metrics )
131
+ self ._metric_value += metric_value
119
132
self ._request_idxs .extend ([next_req_idx ] * num_docs )
120
- self ._request_lens .append (len ( docs ) )
133
+ self ._request_lens .append (num_docs )
121
134
self ._requests .append (request )
122
135
queue = asyncio .Queue ()
123
136
self ._requests_completed .append (queue )
124
- if len ( self ._big_doc ) >= self ._preferred_batch_size :
137
+ if self ._metric_value >= self ._preferred_batch_size :
125
138
self ._flush_trigger .set ()
126
139
127
140
return queue
@@ -132,10 +145,10 @@ async def _await_then_flush(self, http=False) -> None:
132
145
"""
133
146
134
147
def _get_docs_groups_completed_request_indexes (
135
- non_assigned_docs ,
136
- non_assigned_docs_reqs_idx ,
137
- sum_from_previous_mini_batch_in_first_req_idx ,
138
- requests_lens_in_batch ,
148
+ non_assigned_docs ,
149
+ non_assigned_docs_reqs_idx ,
150
+ sum_from_previous_mini_batch_in_first_req_idx ,
151
+ requests_lens_in_batch ,
139
152
):
140
153
"""
141
154
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -160,9 +173,9 @@ def _get_docs_groups_completed_request_indexes(
160
173
)
161
174
if req_idx > min_involved_req_idx :
162
175
request_bucket = non_assigned_docs [
163
- num_distributed_docs : num_distributed_docs
164
- + num_docs_in_req_idx
165
- ]
176
+ num_distributed_docs : num_distributed_docs
177
+ + num_docs_in_req_idx
178
+ ]
166
179
num_distributed_docs += num_docs_in_req_idx
167
180
completed_req_idx .append (min_involved_req_idx )
168
181
min_involved_req_idx = req_idx
@@ -171,25 +184,25 @@ def _get_docs_groups_completed_request_indexes(
171
184
num_docs_in_req_idx += 1
172
185
173
186
if (
174
- req_idx not in completed_req_idx
175
- and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
176
- == requests_lens_in_batch [req_idx ]
187
+ req_idx not in completed_req_idx
188
+ and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
189
+ == requests_lens_in_batch [req_idx ]
177
190
):
178
191
completed_req_idx .append (req_idx )
179
192
request_bucket = non_assigned_docs [
180
- num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
181
- ]
193
+ num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
194
+ ]
182
195
distributed_requests .append (request_bucket )
183
196
184
197
return distributed_requests , completed_req_idx
185
198
186
199
async def _assign_results (
187
- non_assigned_docs ,
188
- non_assigned_docs_reqs_idx ,
189
- sum_from_previous_mini_batch_in_first_req_idx ,
190
- requests_lens_in_batch ,
191
- requests_in_batch ,
192
- requests_completed_in_batch ,
200
+ non_assigned_docs ,
201
+ non_assigned_docs_reqs_idx ,
202
+ sum_from_previous_mini_batch_in_first_req_idx ,
203
+ requests_lens_in_batch ,
204
+ requests_in_batch ,
205
+ requests_completed_in_batch ,
193
206
):
194
207
"""
195
208
This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -220,7 +233,7 @@ async def _assign_results(
220
233
request = requests_in_batch [request_idx ]
221
234
request_completed = requests_completed_in_batch [request_idx ]
222
235
if http is False or self ._output_array_type is not None :
223
- request .direct_docs = None # batch queue will work in place, therefore result will need to read from data.
236
+ request .direct_docs = None # batch queue will work in place, therefore result will need to read from data.
224
237
request .data .set_docs_convert_arrays (
225
238
docs_group , ndarray_type = self ._output_array_type
226
239
)
@@ -230,22 +243,39 @@ async def _assign_results(
230
243
231
244
return num_assigned_docs
232
245
233
def batch(
    iterable_1,
    iterable_2,
    n: Optional[int] = 1,
    iterable_metrics: Optional[List] = None,
):
    """Yield aligned mini-batches sliced from two parallel sequences.

    Both sequences are always sliced with the same bounds, so the i-th item of
    each yielded pair of slices refers to the same position in the inputs.

    :param iterable_1: sliceable sequence that drives the batching.
    :param iterable_2: sliceable sequence sliced with the same bounds as `iterable_1`.
    :param n: when `iterable_metrics` is None, the number of items per batch;
        otherwise the accumulated weight at which the current batch is closed.
        `None` yields both inputs whole, as one single batch.
    :param iterable_metrics: optional per-item weights, paired positionally
        with `iterable_1`.
    :yield: tuples `(slice of iterable_1, slice of iterable_2)`.
    """
    if n is None:
        yield iterable_1, iterable_2
        return

    if iterable_metrics is None:
        # Plain fixed-size batching: every batch holds `n` items except possibly the last.
        items = len(iterable_1)
        for ndx in range(0, items, n):
            end = min(ndx + n, items)
            yield iterable_1[ndx:end], iterable_2[ndx:end]
        return

    # Weighted batching: close a batch as soon as its accumulated weight reaches `n`.
    batch_idx = 0
    batch_weight = 0
    for i, (_, weight) in enumerate(zip(iterable_1, iterable_metrics)):
        batch_weight += weight
        if batch_weight >= n:
            yield iterable_1[batch_idx : i + 1], iterable_2[batch_idx : i + 1]
            batch_idx = i + 1
            batch_weight = 0

    # Yield any remaining items. The check is on the position, not on the
    # accumulated weight: trailing items whose weights add up to zero must
    # still be emitted, otherwise they would silently be dropped.
    items = len(iterable_1)
    if batch_idx < items:
        yield iterable_1[batch_idx:items], iterable_2[batch_idx:items]
242
271
243
272
await self ._flush_trigger .wait ()
244
273
# writes to shared data between tasks need to be mutually exclusive
245
274
async with self ._data_lock :
246
275
big_doc_in_batch = copy .copy (self ._big_doc )
247
276
requests_idxs_in_batch = copy .copy (self ._request_idxs )
248
277
requests_lens_in_batch = copy .copy (self ._request_lens )
278
+ docs_metrics_in_batch = copy .copy (self ._docs_metrics )
249
279
requests_in_batch = copy .copy (self ._requests )
250
280
requests_completed_in_batch = copy .copy (self ._requests_completed )
251
281
@@ -263,7 +293,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
263
293
non_assigned_to_response_request_idxs = []
264
294
sum_from_previous_first_req_idx = 0
265
295
for docs_inner_batch , req_idxs in batch (
266
- big_doc_in_batch , requests_idxs_in_batch , self ._preferred_batch_size if not self ._flush_all else None
296
+ big_doc_in_batch , requests_idxs_in_batch ,
297
+ self ._preferred_batch_size if not self ._flush_all else None , docs_metrics_in_batch if self ._custom_metric is not None else None
267
298
):
268
299
involved_requests_min_indx = req_idxs [0 ]
269
300
involved_requests_max_indx = req_idxs [- 1 ]
@@ -278,8 +309,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
278
309
)
279
310
# Output validation
280
311
if (docarray_v2 and isinstance (batch_res_docs , DocList )) or (
281
- not docarray_v2
282
- and isinstance (batch_res_docs , DocumentArray )
312
+ not docarray_v2
313
+ and isinstance (batch_res_docs , DocumentArray )
283
314
):
284
315
if not len (batch_res_docs ) == input_len_before_call :
285
316
raise ValueError (
@@ -301,8 +332,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
301
332
except Exception as exc :
302
333
# All the requests containing docs in this Exception should be raising it
303
334
for request_full in requests_completed_in_batch [
304
- involved_requests_min_indx : involved_requests_max_indx + 1
305
- ]:
335
+ involved_requests_min_indx : involved_requests_max_indx + 1
336
+ ]:
306
337
await request_full .put (exc )
307
338
else :
308
339
# We need to attribute the docs to their requests
@@ -320,11 +351,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
320
351
)
321
352
322
353
sum_from_previous_first_req_idx = (
323
- len (non_assigned_to_response_docs ) - num_assigned_docs
354
+ len (non_assigned_to_response_docs ) - num_assigned_docs
324
355
)
325
356
non_assigned_to_response_docs = non_assigned_to_response_docs [
326
- num_assigned_docs :
327
- ]
357
+ num_assigned_docs :
358
+ ]
328
359
non_assigned_to_response_request_idxs = (
329
360
non_assigned_to_response_request_idxs [num_assigned_docs :]
330
361
)
0 commit comments