Commit 3894353

Author: Joan Fontanals
test: test no data lock in batch queue (#6201)
1 parent 246f596 commit 3894353

File tree: 4 files changed (+125, -221 lines changed)

4 files changed

+125
-221
lines changed

jina/serve/runtimes/worker/batch_queue.py

Lines changed: 103 additions & 111 deletions
@@ -3,7 +3,6 @@
 from asyncio import Event, Task
 from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
 from jina._docarray import docarray_v2
-import contextlib
 
 if not docarray_v2:
     from docarray import DocumentArray
@@ -25,18 +24,13 @@ def __init__(
         response_docarray_cls,
         output_array_type: Optional[str] = None,
         params: Optional[Dict] = None,
-        allow_concurrent: bool = False,
         flush_all: bool = False,
         preferred_batch_size: int = 4,
         timeout: int = 10_000,
         custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
         use_custom_metric: bool = False,
     ) -> None:
         # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
-        if allow_concurrent and flush_all:
-            self._data_lock = contextlib.AsyncExitStack()
-        else:
-            self._data_lock = asyncio.Lock()
         self.func = func
         if params is None:
             params = dict()
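
Note: the branch removed above chose between a real asyncio.Lock() and an empty contextlib.AsyncExitStack(). That works because an empty exit stack is a no-op async context manager, so the same `async with self._data_lock:` line ran whether or not locking was wanted. A minimal, self-contained illustration of that pattern (the `demo` coroutine is made up for this sketch, not code from the commit):

import asyncio
import contextlib


async def demo(use_lock: bool) -> str:
    # An empty AsyncExitStack enters and exits without doing anything,
    # so it can stand in for a lock when no mutual exclusion is needed.
    guard = asyncio.Lock() if use_lock else contextlib.AsyncExitStack()
    async with guard:
        return 'guarded by a real lock' if use_lock else 'no-op guard'


print(asyncio.run(demo(True)), '/', asyncio.run(demo(False)))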
@@ -64,7 +58,7 @@ def __str__(self) -> str:
     def _reset(self) -> None:
         """Set all events and reset the batch queue."""
         self._requests: List[DataRequest] = []
-        # a list of every request ID
+        # a list of every request idx inside self._requests
         self._request_idxs: List[int] = []
         self._request_lens: List[int] = []
         self._docs_metrics: List[int] = []
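
The next hunk drops the `async with self._data_lock:` wrapper around the bookkeeping in `push`. One way to read why removing the lock can be safe: the bookkeeping shown there contains no `await` between its mutations, and on a single asyncio event loop control only switches between coroutines at an `await`. A small sketch of that property, with made-up names (`TinyBuffer`, `producer`), not code from this commit:

import asyncio


class TinyBuffer:
    # Hypothetical stand-in for the BatchQueue bookkeeping updated by `push`.
    def __init__(self):
        self.docs = []
        self.request_lens = []

    def add(self, docs):
        # No `await` between these two mutations: no other coroutine on the
        # same event loop can observe a half-updated buffer.
        self.docs.extend(docs)
        self.request_lens.append(len(docs))


async def producer(buffer: TinyBuffer, n: int):
    for i in range(n):
        buffer.add([f'doc-{i}', f'doc-{i}-bis'])
        await asyncio.sleep(0)  # yield control only *between* whole updates


async def main():
    buffer = TinyBuffer()
    await asyncio.gather(*(producer(buffer, 5) for _ in range(3)))
    # The two views of the buffer always agree, even without a lock.
    assert len(buffer.docs) == sum(buffer.request_lens) == 30
    print(len(buffer.docs), 'docs recorded across', len(buffer.request_lens), 'pushes')


asyncio.run(main())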
@@ -116,26 +110,24 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
         # this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
         # before the `flush` task processes it.
         self._start_timer()
-        async with self._data_lock:
-            if not self._flush_task:
-                self._flush_task = asyncio.create_task(self._await_then_flush(http))
-
-            self._big_doc.extend(docs)
-            next_req_idx = len(self._requests)
-            num_docs = len(docs)
-            metric_value = num_docs
-            if self._custom_metric is not None:
-                metrics = [self._custom_metric(doc) for doc in docs]
-                metric_value += sum(metrics)
-                self._docs_metrics.extend(metrics)
-            self._metric_value += metric_value
-            self._request_idxs.extend([next_req_idx] * num_docs)
-            self._request_lens.append(num_docs)
-            self._requests.append(request)
-            queue = asyncio.Queue()
-            self._requests_completed.append(queue)
-            if self._metric_value >= self._preferred_batch_size:
-                self._flush_trigger.set()
+        if not self._flush_task:
+            self._flush_task = asyncio.create_task(self._await_then_flush(http))
+        self._big_doc.extend(docs)
+        next_req_idx = len(self._requests)
+        num_docs = len(docs)
+        metric_value = num_docs
+        if self._custom_metric is not None:
+            metrics = [self._custom_metric(doc) for doc in docs]
+            metric_value += sum(metrics)
+            self._docs_metrics.extend(metrics)
+        self._metric_value += metric_value
+        self._request_idxs.extend([next_req_idx] * num_docs)
+        self._request_lens.append(num_docs)
+        self._requests.append(request)
+        queue = asyncio.Queue()
+        self._requests_completed.append(queue)
+        if self._metric_value >= self._preferred_batch_size:
+            self._flush_trigger.set()
 
         return queue
 
@@ -271,96 +263,76 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
 
         await self._flush_trigger.wait()
         # writes to shared data between tasks need to be mutually exclusive
-        async with self._data_lock:
-            big_doc_in_batch = copy.copy(self._big_doc)
-            requests_idxs_in_batch = copy.copy(self._request_idxs)
-            requests_lens_in_batch = copy.copy(self._request_lens)
-            docs_metrics_in_batch = copy.copy(self._docs_metrics)
-            requests_in_batch = copy.copy(self._requests)
-            requests_completed_in_batch = copy.copy(self._requests_completed)
-
-            self._reset()
-
-            # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
-            # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to
-            # communicate that the request has been processed properly.
-
-            if not docarray_v2:
-                non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
-            else:
-                non_assigned_to_response_docs = self._response_docarray_cls()
+        big_doc_in_batch = copy.copy(self._big_doc)
+        requests_idxs_in_batch = copy.copy(self._request_idxs)
+        requests_lens_in_batch = copy.copy(self._request_lens)
+        docs_metrics_in_batch = copy.copy(self._docs_metrics)
+        requests_in_batch = copy.copy(self._requests)
+        requests_completed_in_batch = copy.copy(self._requests_completed)
 
-            non_assigned_to_response_request_idxs = []
-            sum_from_previous_first_req_idx = 0
-            for docs_inner_batch, req_idxs in batch(
-                big_doc_in_batch, requests_idxs_in_batch,
-                self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
-            ):
-                involved_requests_min_indx = req_idxs[0]
-                involved_requests_max_indx = req_idxs[-1]
-                input_len_before_call: int = len(docs_inner_batch)
-                batch_res_docs = None
-                try:
-                    batch_res_docs = await self.func(
-                        docs=docs_inner_batch,
-                        parameters=self.params,
-                        docs_matrix=None,  # joining manually with batch queue is not supported right now
-                        tracing_context=None,
-                    )
-                    # Output validation
-                    if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
-                        not docarray_v2
-                        and isinstance(batch_res_docs, DocumentArray)
-                    ):
-                        if not len(batch_res_docs) == input_len_before_call:
-                            raise ValueError(
-                                f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
-                            )
-                    elif batch_res_docs is None:
-                        if not len(docs_inner_batch) == input_len_before_call:
-                            raise ValueError(
-                                f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
-                            )
-                    else:
-                        array_name = (
-                            'DocumentArray' if not docarray_v2 else 'DocList'
+        self._reset()
+
+        # At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
+        # requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to
+        # communicate that the request has been processed properly.
+
+        if not docarray_v2:
+            non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
+        else:
+            non_assigned_to_response_docs = self._response_docarray_cls()
+
+        non_assigned_to_response_request_idxs = []
+        sum_from_previous_first_req_idx = 0
+        for docs_inner_batch, req_idxs in batch(
+            big_doc_in_batch, requests_idxs_in_batch,
+            self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
+        ):
+            involved_requests_min_indx = req_idxs[0]
+            involved_requests_max_indx = req_idxs[-1]
+            input_len_before_call: int = len(docs_inner_batch)
+            batch_res_docs = None
+            try:
+                batch_res_docs = await self.func(
+                    docs=docs_inner_batch,
+                    parameters=self.params,
+                    docs_matrix=None,  # joining manually with batch queue is not supported right now
+                    tracing_context=None,
+                )
+                # Output validation
+                if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
+                    not docarray_v2
+                    and isinstance(batch_res_docs, DocumentArray)
+                ):
+                    if not len(batch_res_docs) == input_len_before_call:
+                        raise ValueError(
+                            f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
                         )
-                        raise TypeError(
-                            f'The return type must be {array_name} / `None` when using dynamic batching, '
-                            f'but getting {batch_res_docs!r}'
+elif batch_res_docs is None:
+                    if not len(docs_inner_batch) == input_len_before_call:
+                        raise ValueError(
+                            f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
                         )
-                except Exception as exc:
-                    # All the requests containing docs in this Exception should be raising it
-                    for request_full in requests_completed_in_batch[
-                        involved_requests_min_indx: involved_requests_max_indx + 1
-                    ]:
-                        await request_full.put(exc)
                 else:
-                    # We need to attribute the docs to their requests
-                    non_assigned_to_response_docs.extend(
-                        batch_res_docs or docs_inner_batch
+                    array_name = (
+                        'DocumentArray' if not docarray_v2 else 'DocList'
                     )
-                    non_assigned_to_response_request_idxs.extend(req_idxs)
-                    num_assigned_docs = await _assign_results(
-                        non_assigned_to_response_docs,
-                        non_assigned_to_response_request_idxs,
-                        sum_from_previous_first_req_idx,
-                        requests_lens_in_batch,
-                        requests_in_batch,
-                        requests_completed_in_batch,
+                    raise TypeError(
+                        f'The return type must be {array_name} / `None` when using dynamic batching, '
+                        f'but getting {batch_res_docs!r}'
                     )
-
-                    sum_from_previous_first_req_idx = (
-                        len(non_assigned_to_response_docs) - num_assigned_docs
-                    )
-                    non_assigned_to_response_docs = non_assigned_to_response_docs[
-                        num_assigned_docs:
-                    ]
-                    non_assigned_to_response_request_idxs = (
-                        non_assigned_to_response_request_idxs[num_assigned_docs:]
-                    )
-            if len(non_assigned_to_response_request_idxs) > 0:
-                _ = await _assign_results(
+            except Exception as exc:
+                # All the requests containing docs in this Exception should be raising it
+                for request_full in requests_completed_in_batch[
+                    involved_requests_min_indx: involved_requests_max_indx + 1
+                ]:
+                    await request_full.put(exc)
+            else:
+                # We need to attribute the docs to their requests
+                non_assigned_to_response_docs.extend(
+                    batch_res_docs or docs_inner_batch
+                )
+                non_assigned_to_response_request_idxs.extend(req_idxs)
+                num_assigned_docs = await _assign_results(
                     non_assigned_to_response_docs,
                     non_assigned_to_response_request_idxs,
                     sum_from_previous_first_req_idx,
@@ -369,6 +341,26 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
                     requests_completed_in_batch,
                 )
 
+                sum_from_previous_first_req_idx = (
+                    len(non_assigned_to_response_docs) - num_assigned_docs
+                )
+                non_assigned_to_response_docs = non_assigned_to_response_docs[
+                    num_assigned_docs:
+                ]
+                non_assigned_to_response_request_idxs = (
+                    non_assigned_to_response_request_idxs[num_assigned_docs:]
+                )
+        if len(non_assigned_to_response_request_idxs) > 0:
+            _ = await _assign_results(
+                non_assigned_to_response_docs,
+                non_assigned_to_response_request_idxs,
+                sum_from_previous_first_req_idx,
+                requests_lens_in_batch,
+                requests_in_batch,
+                requests_completed_in_batch,
+            )
+
+
     async def close(self):
         """Closes the batch queue by flushing pending requests."""
         if not self._is_closed:
jina/serve/runtimes/worker/request_handling.py

Lines changed: 0 additions & 1 deletion
@@ -702,7 +702,6 @@ async def handle(
                         ].response_schema,
                         output_array_type=self.args.output_array_type,
                         params=params,
-                        allow_concurrent=self.args.allow_concurrent,
                         **self._batchqueue_config[exec_endpoint],
                     )
                 # This is necessary because push might need to await for the queue to be emptied
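
Since `allow_concurrent` no longer exists on `BatchQueue.__init__` (see the first file above), this call site has to stop forwarding it; passing a removed keyword would make construction fail. A tiny illustration with a stand-in class (`FakeBatchQueue` is hypothetical, not the real constructor):

class FakeBatchQueue:
    # Stand-in with a reduced signature; the real BatchQueue takes more arguments.
    def __init__(self, func, *, flush_all: bool = False, preferred_batch_size: int = 4):
        self.func = func
        self.flush_all = flush_all
        self.preferred_batch_size = preferred_batch_size


FakeBatchQueue(print, flush_all=True)  # accepted

try:
    FakeBatchQueue(print, allow_concurrent=True)  # keyword was removed
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'allow_concurrent'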
