Merged
47 changes: 6 additions & 41 deletions python/ray/serve/batching.py
@@ -15,7 +15,6 @@
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     overload,
 )
@@ -320,22 +319,20 @@ def __init__(
         max_batch_size: int = 10,
         batch_wait_timeout_s: float = 0.0,
         handle_batch_func: Optional[Callable] = None,
-        batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
     ):
-        self._queue: Type[_BatchQueue] = None
+        self._queue: Optional[_BatchQueue] = None
         self.max_batch_size = max_batch_size
         self.batch_wait_timeout_s = batch_wait_timeout_s
         self.handle_batch_func = handle_batch_func
-        self.batch_queue_cls = batch_queue_cls

     @property
-    def queue(self) -> Type[_BatchQueue]:
+    def queue(self) -> _BatchQueue:
         """Returns _BatchQueue.

         Initializes queue when called for the first time.
         """
         if self._queue is None:
-            self._queue = self.batch_queue_cls(
+            self._queue = _BatchQueue(
                 self.max_batch_size,
                 self.batch_wait_timeout_s,
                 self.handle_batch_func,
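
Note that the `queue` property keeps its lazy-construction behavior; only the configurable queue class is removed. A minimal, self-contained sketch of the same lazy-initialization pattern (names here are illustrative, not from the Ray source):

    import asyncio
    from typing import Optional

    class LazyQueue:
        """Builds its queue on first access, mirroring _LazyBatchQueueWrapper."""

        def __init__(self) -> None:
            self._queue: Optional[asyncio.Queue] = None

        @property
        def queue(self) -> asyncio.Queue:
            # Constructed lazily so it binds to whichever event loop is
            # running when the first request arrives.
            if self._queue is None:
                self._queue = asyncio.Queue()
            return self._queue

Deferring construction this way matters for asyncio objects, which attach to the event loop that is current when they are created.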
@@ -451,8 +448,6 @@ def batch(
     _func: Optional[Callable] = None,
     max_batch_size: int = 10,
     batch_wait_timeout_s: float = 0.0,
-    *,
-    batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
 ):
     """Converts a function to asynchronously handle batches.

Expand Down Expand Up @@ -500,7 +495,6 @@ async def __call__(self, request: Request):
one call to the underlying function.
batch_wait_timeout_s: the maximum duration to wait for
`max_batch_size` elements before running the current batch.
batch_queue_cls: the class to use for the underlying batch queue.
"""
# `_func` will be None in the case when the decorator is parametrized.
# See the comment at the end of this function for a detailed explanation.
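
After this change, the docstring's two remaining knobs are the supported way to configure batching. A brief usage sketch of the decorator as documented (the handler body is illustrative):

    from typing import List
    from ray import serve

    @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
    async def handle_batch(requests: List[int]) -> List[int]:
        # Invoked with up to 10 pending requests at once; must return
        # exactly one result per request, in order.
        return [r * 2 for r in requests]

Callers pass a single element and await the per-element result, e.g. `result = await handle_batch(2)` yields 4.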
@@ -521,7 +515,6 @@ def _batch_decorator(_func):
             max_batch_size,
             batch_wait_timeout_s,
             _func,
-            batch_queue_cls,
         )

     async def batch_handler_generator(
@@ -539,47 +532,19 @@ async def batch_handler_generator(
                 break

     def enqueue_request(args, kwargs) -> asyncio.Future:
-        self = extract_self_if_method_call(args, _func)
         flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)

-        if self is None:
-            # For functions, inject the batch queue as an
-            # attribute of the function.
-            batch_queue_object = _func
-        else:
-            # For methods, inject the batch queue as an
-            # attribute of the object.
-            batch_queue_object = self
-            # Trim the self argument from methods
+        # If the function is a method, remove self as an argument.
+        self = extract_self_if_method_call(args, _func)
+        if self is not None:
             flattened_args = flattened_args[2:]

         batch_queue = lazy_batch_queue_wrapper.queue

-        # Magic batch_queue_object attributes that can be used to change the
-        # batch queue attributes on the fly.
-        # This is purposefully undocumented for now while we figure out
-        # the best API.
-        if hasattr(batch_queue_object, "_ray_serve_max_batch_size"):
-            new_max_batch_size = getattr(
-                batch_queue_object, "_ray_serve_max_batch_size"
-            )
-            _validate_max_batch_size(new_max_batch_size)
-            batch_queue.max_batch_size = new_max_batch_size
-
-        if hasattr(batch_queue_object, "_ray_serve_batch_wait_timeout_s"):
-            new_batch_wait_timeout_s = getattr(
-                batch_queue_object, "_ray_serve_batch_wait_timeout_s"
-            )
-            _validate_batch_wait_timeout_s(new_batch_wait_timeout_s)
-            batch_queue.batch_wait_timeout_s = new_batch_wait_timeout_s
-
         future = get_or_create_event_loop().create_future()
         batch_queue.put(_SingleRequest(self, flattened_args, future))
         return future

-    # TODO (shrekris-anyscale): deprecate batch_queue_cls argument and
-    # convert batch_wrapper into a class once `self` argument is no
-    # longer needed in `enqueue_request`.
     @wraps(_func)
     def generator_batch_wrapper(*args, **kwargs):
         first_future = enqueue_request(args, kwargs)
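
With the undocumented magic attributes gone, `enqueue_request` reduces to: flatten the arguments, strip `self` for methods, create a future on the running loop, and enqueue a `_SingleRequest` that the batch handler later resolves. A self-contained sketch of that future-based handoff, with a stand-in batch handler (all names are illustrative):

    import asyncio

    async def run_one_batch(queue: asyncio.Queue) -> None:
        # Drain the queued (arg, future) pairs, run the batch, and
        # resolve each caller's future with its own result.
        items = [await queue.get()]
        while not queue.empty():
            items.append(queue.get_nowait())
        results = [arg * 2 for arg, _ in items]  # stand-in batch handler
        for (_, fut), res in zip(items, results):
            fut.set_result(res)

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        fut = asyncio.get_running_loop().create_future()
        queue.put_nowait((3, fut))      # the enqueue_request analog
        await run_one_batch(queue)
        assert await fut == 6           # the caller just awaits its future

    asyncio.run(main())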
35 changes: 0 additions & 35 deletions python/ray/serve/tests/unit/test_batching.py
@@ -18,41 +18,6 @@ def event_loop():
     loop.close()


-@pytest.mark.asyncio
-async def test_batching_magic_attributes():
-    class BatchingExample:
-        def __init__(self):
-            self.count = 0
-            self.batch_sizes = set()
-
-        @property
-        def _ray_serve_max_batch_size(self):
-            return self.count + 1
-
-        @property
-        def _ray_serve_batch_wait_timeout_s(self):
-            return 0.1
-
-        @serve.batch
-        async def handle_batch(self, requests):
-            self.count += 1
-            batch_size = len(requests)
-            self.batch_sizes.add(batch_size)
-            return [batch_size] * batch_size
-
-    batching_example = BatchingExample()
-
-    for batch_size in range(1, 7):
-        tasks = [
-            get_or_create_event_loop().create_task(batching_example.handle_batch(1))
-            for _ in range(batch_size)
-        ]
-
-        done, _ = await asyncio.wait(tasks, return_when="ALL_COMPLETED")
-        assert set({task.result() for task in done}) == {batch_size}
-        time.sleep(0.05)
-
-
 @pytest.mark.asyncio
 async def test_decorator_validation():
     @serve.batch
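
The removed test relied on the per-object override attributes; equivalent batch-size behavior can still be exercised through the documented decorator arguments. A hedged sketch of such a test (handler and values are illustrative, and the single-batch outcome assumes all calls land within the wait timeout):

    import asyncio
    from typing import List

    from ray import serve

    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
    async def handle_batch(requests: List[int]) -> List[int]:
        # Report the observed batch size back to every caller.
        return [len(requests)] * len(requests)

    async def main() -> None:
        # Four concurrent calls arriving within the timeout should be
        # served as one batch of size 4.
        results = await asyncio.gather(*(handle_batch(i) for i in range(4)))
        assert results == [4, 4, 4, 4]

    asyncio.run(main())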