3
3
from asyncio import Event , Task
4
4
from typing import Callable , Dict , List , Optional , TYPE_CHECKING , Union
5
5
from jina ._docarray import docarray_v2
6
- import contextlib
7
6
8
7
if not docarray_v2 :
9
8
from docarray import DocumentArray
@@ -25,18 +24,13 @@ def __init__(
25
24
response_docarray_cls ,
26
25
output_array_type : Optional [str ] = None ,
27
26
params : Optional [Dict ] = None ,
28
- allow_concurrent : bool = False ,
29
27
flush_all : bool = False ,
30
28
preferred_batch_size : int = 4 ,
31
29
timeout : int = 10_000 ,
32
30
custom_metric : Optional [Callable [['DocumentArray' ], Union [int , float ]]] = None ,
33
31
use_custom_metric : bool = False ,
34
32
) -> None :
35
33
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
36
- if allow_concurrent and flush_all :
37
- self ._data_lock = contextlib .AsyncExitStack ()
38
- else :
39
- self ._data_lock = asyncio .Lock ()
40
34
self .func = func
41
35
if params is None :
42
36
params = dict ()
@@ -64,7 +58,7 @@ def __str__(self) -> str:
64
58
def _reset (self ) -> None :
65
59
"""Set all events and reset the batch queue."""
66
60
self ._requests : List [DataRequest ] = []
67
- # a list of every request ID
61
+ # a list of every request idx inside self._requests
68
62
self ._request_idxs : List [int ] = []
69
63
self ._request_lens : List [int ] = []
70
64
self ._docs_metrics : List [int ] = []
@@ -116,26 +110,24 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
116
110
# this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
117
111
# before the `flush` task processes it.
118
112
self ._start_timer ()
119
- async with self ._data_lock :
120
- if not self ._flush_task :
121
- self ._flush_task = asyncio .create_task (self ._await_then_flush (http ))
122
-
123
- self ._big_doc .extend (docs )
124
- next_req_idx = len (self ._requests )
125
- num_docs = len (docs )
126
- metric_value = num_docs
127
- if self ._custom_metric is not None :
128
- metrics = [self ._custom_metric (doc ) for doc in docs ]
129
- metric_value += sum (metrics )
130
- self ._docs_metrics .extend (metrics )
131
- self ._metric_value += metric_value
132
- self ._request_idxs .extend ([next_req_idx ] * num_docs )
133
- self ._request_lens .append (num_docs )
134
- self ._requests .append (request )
135
- queue = asyncio .Queue ()
136
- self ._requests_completed .append (queue )
137
- if self ._metric_value >= self ._preferred_batch_size :
138
- self ._flush_trigger .set ()
113
+ if not self ._flush_task :
114
+ self ._flush_task = asyncio .create_task (self ._await_then_flush (http ))
115
+ self ._big_doc .extend (docs )
116
+ next_req_idx = len (self ._requests )
117
+ num_docs = len (docs )
118
+ metric_value = num_docs
119
+ if self ._custom_metric is not None :
120
+ metrics = [self ._custom_metric (doc ) for doc in docs ]
121
+ metric_value += sum (metrics )
122
+ self ._docs_metrics .extend (metrics )
123
+ self ._metric_value += metric_value
124
+ self ._request_idxs .extend ([next_req_idx ] * num_docs )
125
+ self ._request_lens .append (num_docs )
126
+ self ._requests .append (request )
127
+ queue = asyncio .Queue ()
128
+ self ._requests_completed .append (queue )
129
+ if self ._metric_value >= self ._preferred_batch_size :
130
+ self ._flush_trigger .set ()
139
131
140
132
return queue
141
133
@@ -271,96 +263,76 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
271
263
272
264
await self ._flush_trigger .wait ()
273
265
# writes to shared data between tasks need to be mutually exclusive
274
- async with self ._data_lock :
275
- big_doc_in_batch = copy .copy (self ._big_doc )
276
- requests_idxs_in_batch = copy .copy (self ._request_idxs )
277
- requests_lens_in_batch = copy .copy (self ._request_lens )
278
- docs_metrics_in_batch = copy .copy (self ._docs_metrics )
279
- requests_in_batch = copy .copy (self ._requests )
280
- requests_completed_in_batch = copy .copy (self ._requests_completed )
281
-
282
- self ._reset ()
283
-
284
- # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
285
- # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to
286
- # communicate that the request has been processed properly.
287
-
288
- if not docarray_v2 :
289
- non_assigned_to_response_docs : DocumentArray = DocumentArray .empty ()
290
- else :
291
- non_assigned_to_response_docs = self ._response_docarray_cls ()
266
+ big_doc_in_batch = copy .copy (self ._big_doc )
267
+ requests_idxs_in_batch = copy .copy (self ._request_idxs )
268
+ requests_lens_in_batch = copy .copy (self ._request_lens )
269
+ docs_metrics_in_batch = copy .copy (self ._docs_metrics )
270
+ requests_in_batch = copy .copy (self ._requests )
271
+ requests_completed_in_batch = copy .copy (self ._requests_completed )
292
272
293
- non_assigned_to_response_request_idxs = []
294
- sum_from_previous_first_req_idx = 0
295
- for docs_inner_batch , req_idxs in batch (
296
- big_doc_in_batch , requests_idxs_in_batch ,
297
- self ._preferred_batch_size if not self ._flush_all else None , docs_metrics_in_batch if self ._custom_metric is not None else None
298
- ):
299
- involved_requests_min_indx = req_idxs [0 ]
300
- involved_requests_max_indx = req_idxs [- 1 ]
301
- input_len_before_call : int = len (docs_inner_batch )
302
- batch_res_docs = None
303
- try :
304
- batch_res_docs = await self .func (
305
- docs = docs_inner_batch ,
306
- parameters = self .params ,
307
- docs_matrix = None , # joining manually with batch queue is not supported right now
308
- tracing_context = None ,
309
- )
310
- # Output validation
311
- if (docarray_v2 and isinstance (batch_res_docs , DocList )) or (
312
- not docarray_v2
313
- and isinstance (batch_res_docs , DocumentArray )
314
- ):
315
- if not len (batch_res_docs ) == input_len_before_call :
316
- raise ValueError (
317
- f'Dynamic Batching requires input size to equal output size. Expected output size { input_len_before_call } , but got { len (batch_res_docs )} '
318
- )
319
- elif batch_res_docs is None :
320
- if not len (docs_inner_batch ) == input_len_before_call :
321
- raise ValueError (
322
- f'Dynamic Batching requires input size to equal output size. Expected output size { input_len_before_call } , but got { len (docs_inner_batch )} '
323
- )
324
- else :
325
- array_name = (
326
- 'DocumentArray' if not docarray_v2 else 'DocList'
273
+ self ._reset ()
274
+
275
+ # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
276
+ # requests_idxs_in_batch with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
277
+ # communicate that the request has been processed properly.
278
+
279
+ if not docarray_v2 :
280
+ non_assigned_to_response_docs : DocumentArray = DocumentArray .empty ()
281
+ else :
282
+ non_assigned_to_response_docs = self ._response_docarray_cls ()
283
+
284
+ non_assigned_to_response_request_idxs = []
285
+ sum_from_previous_first_req_idx = 0
286
+ for docs_inner_batch , req_idxs in batch (
287
+ big_doc_in_batch , requests_idxs_in_batch ,
288
+ self ._preferred_batch_size if not self ._flush_all else None , docs_metrics_in_batch if self ._custom_metric is not None else None
289
+ ):
290
+ involved_requests_min_indx = req_idxs [0 ]
291
+ involved_requests_max_indx = req_idxs [- 1 ]
292
+ input_len_before_call : int = len (docs_inner_batch )
293
+ batch_res_docs = None
294
+ try :
295
+ batch_res_docs = await self .func (
296
+ docs = docs_inner_batch ,
297
+ parameters = self .params ,
298
+ docs_matrix = None , # joining manually with batch queue is not supported right now
299
+ tracing_context = None ,
300
+ )
301
+ # Output validation
302
+ if (docarray_v2 and isinstance (batch_res_docs , DocList )) or (
303
+ not docarray_v2
304
+ and isinstance (batch_res_docs , DocumentArray )
305
+ ):
306
+ if not len (batch_res_docs ) == input_len_before_call :
307
+ raise ValueError (
308
+ f'Dynamic Batching requires input size to equal output size. Expected output size { input_len_before_call } , but got { len (batch_res_docs )} '
327
309
)
328
- raise TypeError (
329
- f'The return type must be { array_name } / `None` when using dynamic batching, '
330
- f'but getting { batch_res_docs !r} '
310
+ elif batch_res_docs is None :
311
+ if not len (docs_inner_batch ) == input_len_before_call :
312
+ raise ValueError (
313
+ f'Dynamic Batching requires input size to equal output size. Expected output size { input_len_before_call } , but got { len (docs_inner_batch )} '
331
314
)
332
- except Exception as exc :
333
- # All the requests containing docs in this Exception should be raising it
334
- for request_full in requests_completed_in_batch [
335
- involved_requests_min_indx : involved_requests_max_indx + 1
336
- ]:
337
- await request_full .put (exc )
338
315
else :
339
- # We need to attribute the docs to their requests
340
- non_assigned_to_response_docs .extend (
341
- batch_res_docs or docs_inner_batch
316
+ array_name = (
317
+ 'DocumentArray' if not docarray_v2 else 'DocList'
342
318
)
343
- non_assigned_to_response_request_idxs .extend (req_idxs )
344
- num_assigned_docs = await _assign_results (
345
- non_assigned_to_response_docs ,
346
- non_assigned_to_response_request_idxs ,
347
- sum_from_previous_first_req_idx ,
348
- requests_lens_in_batch ,
349
- requests_in_batch ,
350
- requests_completed_in_batch ,
319
+ raise TypeError (
320
+ f'The return type must be { array_name } / `None` when using dynamic batching, '
321
+ f'but getting { batch_res_docs !r} '
351
322
)
352
-
353
- sum_from_previous_first_req_idx = (
354
- len (non_assigned_to_response_docs ) - num_assigned_docs
355
- )
356
- non_assigned_to_response_docs = non_assigned_to_response_docs [
357
- num_assigned_docs :
358
- ]
359
- non_assigned_to_response_request_idxs = (
360
- non_assigned_to_response_request_idxs [num_assigned_docs :]
361
- )
362
- if len (non_assigned_to_response_request_idxs ) > 0 :
363
- _ = await _assign_results (
323
+ except Exception as exc :
324
+ # Every request that has docs in this failed batch must receive the exception
325
+ for request_full in requests_completed_in_batch [
326
+ involved_requests_min_indx : involved_requests_max_indx + 1
327
+ ]:
328
+ await request_full .put (exc )
329
+ else :
330
+ # We need to attribute the docs to their requests
331
+ non_assigned_to_response_docs .extend (
332
+ batch_res_docs or docs_inner_batch
333
+ )
334
+ non_assigned_to_response_request_idxs .extend (req_idxs )
335
+ num_assigned_docs = await _assign_results (
364
336
non_assigned_to_response_docs ,
365
337
non_assigned_to_response_request_idxs ,
366
338
sum_from_previous_first_req_idx ,
@@ -369,6 +341,26 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
369
341
requests_completed_in_batch ,
370
342
)
371
343
344
+ sum_from_previous_first_req_idx = (
345
+ len (non_assigned_to_response_docs ) - num_assigned_docs
346
+ )
347
+ non_assigned_to_response_docs = non_assigned_to_response_docs [
348
+ num_assigned_docs :
349
+ ]
350
+ non_assigned_to_response_request_idxs = (
351
+ non_assigned_to_response_request_idxs [num_assigned_docs :]
352
+ )
353
+ if len (non_assigned_to_response_request_idxs ) > 0 :
354
+ _ = await _assign_results (
355
+ non_assigned_to_response_docs ,
356
+ non_assigned_to_response_request_idxs ,
357
+ sum_from_previous_first_req_idx ,
358
+ requests_lens_in_batch ,
359
+ requests_in_batch ,
360
+ requests_completed_in_batch ,
361
+ )
362
+
363
+
372
364
async def close (self ):
373
365
"""Closes the batch queue by flushing pending requests."""
374
366
if not self ._is_closed :
0 commit comments