Skip to content

Commit d17b620

Browse files
Joan Fontanals and jina-bot authored
feat: add custom_metric for dynamic batching (#6189)
Co-authored-by: Jina Dev Bot <[email protected]>
1 parent d4fb94d commit d17b620

File tree

5 files changed

+172
-58
lines changed

5 files changed

+172
-58
lines changed

jina/serve/executors/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,22 @@ def _validate_sagemaker(self):
655655
return
656656

657657
def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
658+
import collections
659+
660+
def deep_update(source, overrides):
661+
for key, value in overrides.items():
662+
if isinstance(value, collections.Mapping) and value:
663+
returned = deep_update(source.get(key, {}), value)
664+
source[key] = returned
665+
else:
666+
source[key] = overrides[key]
667+
return source
668+
658669
if _dynamic_batching:
659670
self.dynamic_batching = getattr(self, 'dynamic_batching', {})
660-
self.dynamic_batching.update(_dynamic_batching)
671+
self.dynamic_batching = deep_update(
672+
self.dynamic_batching, _dynamic_batching
673+
)
661674

662675
def _add_metas(self, _metas: Optional[Dict]):
663676
from jina.serve.executors.metas import get_default_metas

jina/serve/executors/decorators.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,9 @@ def dynamic_batching(
416416
*,
417417
preferred_batch_size: Optional[int] = None,
418418
timeout: Optional[float] = 10_000,
419-
flush_all: bool = False
419+
flush_all: bool = False,
420+
custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
421+
use_custom_metric: bool = False,
420422
):
421423
"""
422424
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -434,6 +436,8 @@ def dynamic_batching(
434436
Default is 10_000ms (10 seconds).
435437
:param flush_all: Determines if once the batches is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated or not.
436438
If this is true, `preferred_batch_size` is used as a trigger mechanism.
439+
:param custom_metric: Potential lambda function to measure the "weight" of each request.
440+
:param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size.
437441
:return: decorated function
438442
"""
439443

@@ -480,6 +484,8 @@ def _inject_owner_attrs(self, owner, name):
480484
] = preferred_batch_size
481485
owner.dynamic_batching[fn_name]['timeout'] = timeout
482486
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
487+
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
488+
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
483489
setattr(owner, name, self.fn)
484490

485491
def __set_name__(self, owner, name):

jina/serve/runtimes/worker/batch_queue.py

Lines changed: 81 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
import copy
33
from asyncio import Event, Task
4-
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
4+
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
55
from jina._docarray import docarray_v2
66
import contextlib
7+
78
if not docarray_v2:
89
from docarray import DocumentArray
910
else:
@@ -18,16 +19,18 @@ class BatchQueue:
1819
"""A batch queue that holds the data request and the callable to batch requests to."""
1920

2021
def __init__(
21-
self,
22-
func: Callable,
23-
request_docarray_cls,
24-
response_docarray_cls,
25-
output_array_type: Optional[str] = None,
26-
params: Optional[Dict] = None,
27-
allow_concurrent: bool = False,
28-
flush_all: bool = False,
29-
preferred_batch_size: int = 4,
30-
timeout: int = 10_000,
22+
self,
23+
func: Callable,
24+
request_docarray_cls,
25+
response_docarray_cls,
26+
output_array_type: Optional[str] = None,
27+
params: Optional[Dict] = None,
28+
allow_concurrent: bool = False,
29+
flush_all: bool = False,
30+
preferred_batch_size: int = 4,
31+
timeout: int = 10_000,
32+
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
33+
use_custom_metric: bool = False,
3134
) -> None:
3235
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
3336
if allow_concurrent and flush_all:
@@ -44,6 +47,8 @@ def __init__(
4447
self._response_docarray_cls = response_docarray_cls
4548
self._flush_all = flush_all
4649
self._preferred_batch_size: int = preferred_batch_size
50+
self._custom_metric = None if not use_custom_metric else custom_metric
51+
self._metric_value = 0
4752
self._timeout: int = timeout
4853
self._reset()
4954
self._flush_trigger: Event = Event()
@@ -62,20 +67,22 @@ def _reset(self) -> None:
6267
# a list of every request ID
6368
self._request_idxs: List[int] = []
6469
self._request_lens: List[int] = []
70+
self._docs_metrics: List[int] = []
6571
self._requests_completed: List[asyncio.Queue] = []
6672
if not docarray_v2:
6773
self._big_doc: DocumentArray = DocumentArray.empty()
6874
else:
6975
self._big_doc = self._request_docarray_cls()
76+
self._metric_value = 0
7077

7178
self._flush_task: Optional[Task] = None
7279
self._flush_trigger: Event = Event()
7380

7481
def _cancel_timer_if_pending(self):
7582
if (
76-
self._timer_task
77-
and not self._timer_task.done()
78-
and not self._timer_task.cancelled()
83+
self._timer_task
84+
and not self._timer_task.done()
85+
and not self._timer_task.cancelled()
7986
):
8087
self._timer_finished = False
8188
self._timer_task.cancel()
@@ -91,7 +98,7 @@ async def _sleep_then_set(self):
9198
self._flush_trigger.set()
9299
self._timer_finished = True
93100

94-
async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
101+
async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
95102
"""Append request to the the list of requests to be processed.
96103
97104
This method creates an asyncio Queue for that request and keeps track of it. It returns
@@ -116,12 +123,18 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
116123
self._big_doc.extend(docs)
117124
next_req_idx = len(self._requests)
118125
num_docs = len(docs)
126+
metric_value = num_docs
127+
if self._custom_metric is not None:
128+
metrics = [self._custom_metric(doc) for doc in docs]
129+
metric_value += sum(metrics)
130+
self._docs_metrics.extend(metrics)
131+
self._metric_value += metric_value
119132
self._request_idxs.extend([next_req_idx] * num_docs)
120-
self._request_lens.append(len(docs))
133+
self._request_lens.append(num_docs)
121134
self._requests.append(request)
122135
queue = asyncio.Queue()
123136
self._requests_completed.append(queue)
124-
if len(self._big_doc) >= self._preferred_batch_size:
137+
if self._metric_value >= self._preferred_batch_size:
125138
self._flush_trigger.set()
126139

127140
return queue
@@ -132,10 +145,10 @@ async def _await_then_flush(self, http=False) -> None:
132145
"""
133146

134147
def _get_docs_groups_completed_request_indexes(
135-
non_assigned_docs,
136-
non_assigned_docs_reqs_idx,
137-
sum_from_previous_mini_batch_in_first_req_idx,
138-
requests_lens_in_batch,
148+
non_assigned_docs,
149+
non_assigned_docs_reqs_idx,
150+
sum_from_previous_mini_batch_in_first_req_idx,
151+
requests_lens_in_batch,
139152
):
140153
"""
141154
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
@@ -160,9 +173,9 @@ def _get_docs_groups_completed_request_indexes(
160173
)
161174
if req_idx > min_involved_req_idx:
162175
request_bucket = non_assigned_docs[
163-
num_distributed_docs : num_distributed_docs
164-
+ num_docs_in_req_idx
165-
]
176+
num_distributed_docs: num_distributed_docs
177+
+ num_docs_in_req_idx
178+
]
166179
num_distributed_docs += num_docs_in_req_idx
167180
completed_req_idx.append(min_involved_req_idx)
168181
min_involved_req_idx = req_idx
@@ -171,25 +184,25 @@ def _get_docs_groups_completed_request_indexes(
171184
num_docs_in_req_idx += 1
172185

173186
if (
174-
req_idx not in completed_req_idx
175-
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
176-
== requests_lens_in_batch[req_idx]
187+
req_idx not in completed_req_idx
188+
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
189+
== requests_lens_in_batch[req_idx]
177190
):
178191
completed_req_idx.append(req_idx)
179192
request_bucket = non_assigned_docs[
180-
num_distributed_docs : num_distributed_docs + num_docs_in_req_idx
181-
]
193+
num_distributed_docs: num_distributed_docs + num_docs_in_req_idx
194+
]
182195
distributed_requests.append(request_bucket)
183196

184197
return distributed_requests, completed_req_idx
185198

186199
async def _assign_results(
187-
non_assigned_docs,
188-
non_assigned_docs_reqs_idx,
189-
sum_from_previous_mini_batch_in_first_req_idx,
190-
requests_lens_in_batch,
191-
requests_in_batch,
192-
requests_completed_in_batch,
200+
non_assigned_docs,
201+
non_assigned_docs_reqs_idx,
202+
sum_from_previous_mini_batch_in_first_req_idx,
203+
requests_lens_in_batch,
204+
requests_in_batch,
205+
requests_completed_in_batch,
193206
):
194207
"""
195208
This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
@@ -220,7 +233,7 @@ async def _assign_results(
220233
request = requests_in_batch[request_idx]
221234
request_completed = requests_completed_in_batch[request_idx]
222235
if http is False or self._output_array_type is not None:
223-
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
236+
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
224237
request.data.set_docs_convert_arrays(
225238
docs_group, ndarray_type=self._output_array_type
226239
)
@@ -230,22 +243,39 @@ async def _assign_results(
230243

231244
return num_assigned_docs
232245

233-
def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Optional[List] = None):
    """Yield aligned slices of ``iterable_1`` and ``iterable_2``.

    Three modes:
      * ``n is None``: everything in one batch (flush-all behavior).
      * no ``iterable_metrics``: fixed-size chunks of ``n`` documents.
      * with ``iterable_metrics``: weighted chunks — documents are accumulated
        until their summed metric weight reaches ``n``, then flushed.

    :param iterable_1: the documents to slice
    :param iterable_2: parallel sequence (request indices), sliced identically
    :param n: target batch size (doc count or metric weight), or ``None``
    :param iterable_metrics: per-document weights; ``None`` disables weighting
    """
    if n is None:
        yield iterable_1, iterable_2
        return
    elif iterable_metrics is None:
        items = len(iterable_1)
        for ndx in range(0, items, n):
            yield iterable_1[ndx: min(ndx + n, items)], iterable_2[
                ndx: min(ndx + n, items)
            ]
    else:
        batch_idx = 0
        batch_weight = 0

        for i, weight in enumerate(iterable_metrics):
            batch_weight += weight

            if batch_weight >= n:
                yield iterable_1[batch_idx: i + 1], iterable_2[batch_idx: i + 1]
                batch_idx = i + 1
                batch_weight = 0

        # Yield any remaining items. Check the index, not the accumulated
        # weight: leftover documents whose weights sum to 0 must still be
        # dispatched, otherwise they would be silently dropped.
        if batch_idx < len(iterable_1):
            yield iterable_1[batch_idx: len(iterable_1)], iterable_2[batch_idx: len(iterable_1)]
242271

243272
await self._flush_trigger.wait()
244273
# writes to shared data between tasks need to be mutually exclusive
245274
async with self._data_lock:
246275
big_doc_in_batch = copy.copy(self._big_doc)
247276
requests_idxs_in_batch = copy.copy(self._request_idxs)
248277
requests_lens_in_batch = copy.copy(self._request_lens)
278+
docs_metrics_in_batch = copy.copy(self._docs_metrics)
249279
requests_in_batch = copy.copy(self._requests)
250280
requests_completed_in_batch = copy.copy(self._requests_completed)
251281

@@ -263,7 +293,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
263293
non_assigned_to_response_request_idxs = []
264294
sum_from_previous_first_req_idx = 0
265295
for docs_inner_batch, req_idxs in batch(
266-
big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
296+
big_doc_in_batch, requests_idxs_in_batch,
297+
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
267298
):
268299
involved_requests_min_indx = req_idxs[0]
269300
involved_requests_max_indx = req_idxs[-1]
@@ -278,8 +309,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
278309
)
279310
# Output validation
280311
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
281-
not docarray_v2
282-
and isinstance(batch_res_docs, DocumentArray)
312+
not docarray_v2
313+
and isinstance(batch_res_docs, DocumentArray)
283314
):
284315
if not len(batch_res_docs) == input_len_before_call:
285316
raise ValueError(
@@ -301,8 +332,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
301332
except Exception as exc:
302333
# All the requests containing docs in this Exception should be raising it
303334
for request_full in requests_completed_in_batch[
304-
involved_requests_min_indx : involved_requests_max_indx + 1
305-
]:
335+
involved_requests_min_indx: involved_requests_max_indx + 1
336+
]:
306337
await request_full.put(exc)
307338
else:
308339
# We need to attribute the docs to their requests
@@ -320,11 +351,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
320351
)
321352

322353
sum_from_previous_first_req_idx = (
323-
len(non_assigned_to_response_docs) - num_assigned_docs
354+
len(non_assigned_to_response_docs) - num_assigned_docs
324355
)
325356
non_assigned_to_response_docs = non_assigned_to_response_docs[
326-
num_assigned_docs:
327-
]
357+
num_assigned_docs:
358+
]
328359
non_assigned_to_response_request_idxs = (
329360
non_assigned_to_response_request_idxs[num_assigned_docs:]
330361
)

tests/integration/dynamic_batching/test_dynamic_batching.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,67 @@ def foo(self, docs, **kwargs):
736736

737737
assert smaller_than_5 == (1 if allow_concurrent else 0)
738738
assert larger_than_5 > 0
739+
740+
741+
@pytest.mark.asyncio
742+
@pytest.mark.parametrize('use_custom_metric', [True, False])
743+
@pytest.mark.parametrize('flush_all', [False, True])
744+
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
745+
class DynCustomBatchProcessor(Executor):
746+
747+
@dynamic_batching(preferred_batch_size=10, custom_metric=lambda x: len(x.text))
748+
@requests(on='/foo')
749+
def foo(self, docs, **kwargs):
750+
time.sleep(0.5)
751+
total_len = sum([len(doc.text) for doc in docs])
752+
for doc in docs:
753+
doc.text = f"{total_len}"
754+
755+
depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}})
756+
da = DocumentArray([Document(text='aaaaa') for i in range(50)])
757+
with depl:
758+
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
759+
res = []
760+
async for r in cl.post(
761+
on='/foo',
762+
inputs=da,
763+
request_size=1,
764+
continue_on_error=True,
765+
results_in_order=True,
766+
):
767+
res.extend(r)
768+
assert len(res) == 50 # 1 request per input
769+
770+
# If custom_metric and flush all
771+
if use_custom_metric and not flush_all:
772+
for doc in res:
773+
assert doc.text == "10"
774+
775+
elif not use_custom_metric and not flush_all:
776+
for doc in res:
777+
assert doc.text == "50"
778+
779+
elif use_custom_metric and flush_all:
780+
# There will be 2 "10" and the rest will be "240"
781+
num_10 = 0
782+
num_240 = 0
783+
for doc in res:
784+
if doc.text == "10":
785+
num_10 += 1
786+
elif doc.text == "240":
787+
num_240 += 1
788+
789+
assert num_10 == 2
790+
assert num_240 == 48
791+
elif not use_custom_metric and flush_all:
792+
# There will be 10 "50" and the rest will be "200"
793+
num_50 = 0
794+
num_200 = 0
795+
for doc in res:
796+
if doc.text == "50":
797+
num_50 += 1
798+
elif doc.text == "200":
799+
num_200 += 1
800+
801+
assert num_50 == 10
802+
assert num_200 == 40

0 commit comments

Comments (0)