Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,60 @@
DEFAULT_LATENCY_BUCKET_MS,
)

#: Histogram bucket boundaries (ms) for the batch execution time histogram.
#: Spans ~1 ms to 60 s so both fast and very slow batch functions are resolved.
BATCH_EXECUTION_TIME_BUCKETS_MS = [
    1, 2, 5,
    10, 20, 50,
    100, 200, 500,
    1000, 2000, 5000,
    10000, 30000, 60000,
]

#: Histogram bucket boundaries (ms) for the batch wait time histogram.
#: Starts at sub-millisecond values since requests often wait very briefly
#: before a batch fills.
BATCH_WAIT_TIME_BUCKETS_MS = [
    0.1, 0.5,
    1, 2, 5,
    10, 20, 50,
    100, 200, 500,
    1000, 2000, 5000,
]

#: Histogram bucket boundaries for batch utilization, expressed in percent
#: (actual batch size relative to the configured max batch size).
BATCH_UTILIZATION_BUCKETS_PERCENT = [
    5, 10, 20, 30, 40,
    50, 60, 70, 80, 90,
    95, 99, 100,
]

#: Name of the deployment health check method implemented by the user.
#: NOTE(review): presumably looked up by attribute name on the deployment
#: class elsewhere in Serve — confirm against the replica health-check logic.
HEALTH_CHECK_METHOD = "check_health"

Expand Down
76 changes: 75 additions & 1 deletion python/ray/serve/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@
from ray import serve
from ray._common.signature import extract_signature, flatten_args, recover_args
from ray._common.utils import get_or_create_event_loop
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.constants import (
BATCH_EXECUTION_TIME_BUCKETS_MS,
BATCH_UTILIZATION_BUCKETS_PERCENT,
BATCH_WAIT_TIME_BUCKETS_MS,
SERVE_LOGGER_NAME,
)
from ray.serve._private.utils import extract_self_if_method_call
from ray.serve.exceptions import RayServeException
from ray.serve.metrics import Counter, Gauge, Histogram
from ray.util.annotations import PublicAPI

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -140,6 +146,40 @@ def __init__(
# Used for observability.
self.curr_iteration_start_times: Dict[asyncio.Task, float] = {}

# Initialize batching metrics.
self._batch_wait_time_histogram = Histogram(
"serve_batch_wait_time_ms",
description="Time requests waited for batch to fill (in milliseconds).",
boundaries=BATCH_WAIT_TIME_BUCKETS_MS,
tag_keys=("function_name",),
)
self._batch_execution_time_histogram = Histogram(
"serve_batch_execution_time_ms",
description="Time to execute the batch function (in milliseconds).",
boundaries=BATCH_EXECUTION_TIME_BUCKETS_MS,
tag_keys=("function_name",),
)
self._batch_queue_length_gauge = Gauge(
"serve_batch_queue_length",
description="Number of requests waiting in the batch queue.",
tag_keys=("function_name",),
)
self._batch_utilization_histogram = Histogram(
"serve_batch_utilization_percent",
description="Batch utilization as percentage (actual_batch_size / max_batch_size * 100).",
boundaries=BATCH_UTILIZATION_BUCKETS_PERCENT,
tag_keys=("function_name",),
)
self._batches_processed_counter = Counter(
"serve_batches_processed",
description="Counter of batches executed.",
tag_keys=("function_name",),
)

self._function_name = (
handle_batch_func.__name__ if handle_batch_func is not None else None
)

self._handle_batch_task = None
self._loop = get_or_create_event_loop()
if handle_batch_func is not None:
Expand Down Expand Up @@ -195,6 +235,11 @@ async def wait_for_batch(self) -> List[_SingleRequest]:
# Wait self.timeout_s seconds for new queue arrivals.
batch_start_time = time.time()
while True:
# Record queue length metric.
self._batch_queue_length_gauge.set(
self.queue.qsize(), tags={"function_name": self._function_name}
)

remaining_batch_time_s = max(
batch_wait_timeout_s - (time.time() - batch_start_time), 0
)
Expand Down Expand Up @@ -225,6 +270,12 @@ async def wait_for_batch(self) -> List[_SingleRequest]:
):
break

# Record batch wait time metric (time spent waiting for batch to fill).
batch_wait_time_ms = (time.time() - batch_start_time) * 1000
self._batch_wait_time_histogram.observe(
batch_wait_time_ms, tags={"function_name": self._function_name}
)

return batch

def _validate_results(
Expand Down Expand Up @@ -329,11 +380,26 @@ async def _process_batch(self, func: Callable, batch: List[_SingleRequest]) -> N
if len(batch) == 0:
return

# Record batch utilization metric.
batch_size = len(batch)

# Calculate and record batch utilization percentage.
batch_utilization_percent = (batch_size / self.max_batch_size) * 100
self._batch_utilization_histogram.observe(
batch_utilization_percent, tags={"function_name": self._function_name}
)

# Increment batches processed counter.
self._batches_processed_counter.inc(
tags={"function_name": self._function_name}
)

futures = [item.future for item in batch]

# Most of the logic in the function should be wrapped in this try-
# except block, so the futures' exceptions can be set if an exception
# occurs. Otherwise, the futures' requests may hang indefinitely.
batch_execution_start_time = time.time()
try:
self_arg = batch[0].self_arg
args, kwargs = _batch_args_kwargs(
Expand Down Expand Up @@ -368,6 +434,14 @@ async def _process_batch(self, func: Callable, batch: List[_SingleRequest]) -> N

for future in futures:
_set_exception_if_not_done(future, e)
finally:
# Record batch execution time.
batch_execution_time_ms = (
time.time() - batch_execution_start_time
) * 1000
self._batch_execution_time_histogram.observe(
batch_execution_time_ms, tags={"function_name": self._function_name}
)

def _handle_completed_task(self, task: asyncio.Task) -> None:
self.tasks.remove(task)
Expand Down
97 changes: 97 additions & 0 deletions python/ray/serve/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,5 +1247,102 @@ def f():
)


def test_batching_metrics(metrics_start_shutdown):
    """End-to-end check that @serve.batch emits its batching metrics.

    Sends 8 concurrent requests through a deployment whose batch handler has
    max_batch_size=4, so exactly 2 batches are expected, then waits until every
    batching metric shows up with the expected value and tags.
    """

    @serve.deployment
    class BatchedDeployment:
        @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.5)
        async def batch_handler(self, requests: List[str]) -> List[str]:
            # Simulate some processing time
            import asyncio

            await asyncio.sleep(0.05)
            return [f"processed:{r}" for r in requests]

        async def __call__(self, request: Request):
            data = await request.body()
            return await self.batch_handler(data.decode())

    app_name = "batched_app"
    serve.run(BatchedDeployment.bind(), name=app_name, route_prefix="/batch")

    http_url = "http://localhost:8000/batch"

    # Fire concurrent requests so the handler actually forms batches.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        pending = [
            pool.submit(lambda i=i: httpx.post(http_url, content=f"req{i}"))
            for i in range(8)
        ]
        responses = [p.result() for p in pending]

    # Every request must have completed successfully.
    assert all(r.status_code == 200 for r in responses)

    timeseries = PrometheusTimeseries()
    expected_tags = {
        "deployment": "BatchedDeployment",
        "application": app_name,
        "function_name": "batch_handler",
    }

    # (metric name, expected value) pairs: 8 requests / max_batch_size 4
    # -> 2 batches for the counter and each histogram's _count; the queue
    # gauge should read 0 once everything has been processed.
    metric_expectations = [
        ("ray_serve_batches_processed_total", 2),
        ("ray_serve_batch_wait_time_ms_count", 2),
        ("ray_serve_batch_execution_time_ms_count", 2),
        ("ray_serve_batch_utilization_percent_count", 2),
        ("ray_serve_batch_queue_length", 0),
    ]
    for metric_name, expected_value in metric_expectations:
        wait_for_condition(
            lambda m=metric_name, v=expected_value: check_metric_float_eq(
                m,
                expected=v,
                expected_tags=expected_tags,
                timeseries=timeseries,
            ),
            timeout=10,
        )


if __name__ == "__main__":
    # Allow running this test module directly; exit with pytest's status code.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)