Skip to content

Commit 3fbea82

Browse files
committed
enrich autoscaling context
Signed-off-by: abrar <abrar@anyscale.com>
1 parent c5e6647 commit 3fbea82

File tree

6 files changed

+254
-49
lines changed

6 files changed

+254
-49
lines changed

python/ray/serve/_private/autoscaling_state.py

Lines changed: 123 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import time
33
from collections import defaultdict
4-
from typing import Any, Dict, List, Optional, Set
4+
from typing import Dict, List, Optional, Set
55

66
from ray.serve._private.common import (
77
ONGOING_REQUESTS_KEY,
@@ -207,6 +207,9 @@ def get_decision_num_replicas(
207207
`_skip_bound_check` is True, then the bounds are not applied.
208208
"""
209209

210+
total_num_requests = self.get_total_num_requests()
211+
total_queued_requests = self._get_queued_requests()
212+
total_running_requests = total_num_requests - total_queued_requests
210213
autoscaling_context: AutoscalingContext = AutoscalingContext(
211214
deployment_id=self._deployment_id,
212215
deployment_name=self._deployment_id.name,
@@ -220,10 +223,10 @@ def get_decision_num_replicas(
220223
policy_state=self._policy_state.copy(),
221224
current_time=time.time(),
222225
config=self._config,
223-
queued_requests=None,
224-
requests_per_replica=None,
225-
aggregated_metrics=None,
226-
raw_metrics=None,
226+
total_queued_requests=self._get_queued_requests(),
227+
total_running_requests=total_running_requests,
228+
aggregated_metrics=self._get_aggregated_custom_metrics(),
229+
raw_metrics=self._get_raw_custom_metrics(),
227230
last_scale_up_time=None,
228231
last_scale_down_time=None,
229232
)
@@ -300,19 +303,22 @@ def _collect_handle_running_requests(self) -> List[List[TimeStampedValue]]:
300303

301304
return timeseries_list
302305

303-
def _aggregate_ongoing_requests(
304-
self, metrics_timeseries_dicts: List[Dict[str, List[TimeStampedValue]]]
306+
def _aggregate_timeseries_metric(
307+
self,
308+
metrics_timeseries_dicts: List[Dict[str, List[TimeStampedValue]]],
309+
metric_key: str,
305310
) -> float:
306-
"""Aggregate and average ongoing requests from timeseries data using instantaneous merge.
311+
"""Aggregate and average a metric from timeseries data using instantaneous merge.
307312
308313
Args:
309314
metrics_timeseries_dicts: A list of dictionaries, each containing a key-value pair:
310-
- The key is the name of the metric (ONGOING_REQUESTS_KEY)
315+
- The key is the name of the metric (e.g., ONGOING_REQUESTS_KEY or custom metric name)
311316
- The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
312317
this list is sorted by timestamp ascending
318+
metric_key: The key to use when extracting the metric from the dictionaries
313319
314320
Returns:
315-
The time-weighted average of the ongoing requests
321+
The time-weighted average of the metric
316322
317323
Example:
318324
If the metrics_timeseries_dicts is:
@@ -339,10 +345,10 @@ def _aggregate_ongoing_requests(
339345

340346
# Use instantaneous merge approach - no arbitrary windowing needed
341347
aggregated_metrics = merge_timeseries_dicts(*metrics_timeseries_dicts)
342-
ongoing_requests_timeseries = aggregated_metrics.get(ONGOING_REQUESTS_KEY, [])
343-
if ongoing_requests_timeseries:
348+
metric_timeseries = aggregated_metrics.get(metric_key, [])
349+
if metric_timeseries:
344350
# assume that the last recorded metric is valid for last_window_s seconds
345-
last_metric_time = ongoing_requests_timeseries[-1].timestamp
351+
last_metric_time = metric_timeseries[-1].timestamp
346352
# we dont want to make any assumption about how long the last metric will be valid
347353
# only conclude that the last metric is valid for last_window_s seconds that is the
348354
# difference between the current time and the last metric recorded time
@@ -351,16 +357,34 @@ def _aggregate_ongoing_requests(
351357
# between replicas and controller. Also add a small epsilon to avoid division by zero
352358
if last_window_s <= 0:
353359
last_window_s = 1e-3
354-
# Calculate the aggregated running requests
360+
# Calculate the aggregated metric value
355361
value = aggregate_timeseries(
356-
ongoing_requests_timeseries,
362+
metric_timeseries,
357363
aggregation_function=self._config.aggregation_function,
358364
last_window_s=last_window_s,
359365
)
360366
return value if value is not None else 0.0
361367

362368
return 0.0
363369

370+
def _aggregate_ongoing_requests(
371+
self, metrics_timeseries_dicts: List[Dict[str, List[TimeStampedValue]]]
372+
) -> float:
373+
"""Aggregate and average ongoing requests from timeseries data using instantaneous merge.
374+
375+
This is a convenience wrapper around _aggregate_timeseries_metric for ongoing requests.
376+
377+
Args:
378+
metrics_timeseries_dicts: A list of dictionaries containing ONGOING_REQUESTS_KEY
379+
mapped to timeseries data.
380+
381+
Returns:
382+
The time-weighted average of the ongoing requests
383+
"""
384+
return self._aggregate_timeseries_metric(
385+
metrics_timeseries_dicts, ONGOING_REQUESTS_KEY
386+
)
387+
364388
def _calculate_total_requests_aggregate_mode(self) -> float:
365389
"""Calculate total requests using aggregate metrics mode with timeseries data.
366390
@@ -541,11 +565,8 @@ def get_total_num_requests(self) -> float:
541565
else:
542566
return self._calculate_total_requests_simple_mode()
543567

544-
def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
568+
def get_replica_metrics(self) -> Dict[ReplicaID, List[TimeStampedValue]]:
545569
"""Get the raw replica metrics dict."""
546-
# arcyleung TODO: pass agg_func from autoscaling policy https://github.com/ray-project/ray/pull/51905
547-
# Dummy implementation of mean agg_func across all values of the same metrics key
548-
549570
metric_values = defaultdict(list)
550571
for id in self._running_replicas:
551572
if id in self._replica_metrics and self._replica_metrics[id].metrics:
@@ -554,6 +575,86 @@ def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
554575

555576
return metric_values
556577

578+
def _get_queued_requests(self) -> float:
579+
"""Calculate the total number of queued requests across all handles.
580+
581+
Returns:
582+
Sum of queued requests at all handles. Uses aggregated values in simple mode,
583+
or aggregates timeseries data in aggregate mode.
584+
"""
585+
if RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER:
586+
# Aggregate mode: collect and aggregate timeseries
587+
queued_timeseries = self._collect_handle_queued_requests()
588+
if not queued_timeseries:
589+
return 0.0
590+
591+
queued_metrics = [
592+
{ONGOING_REQUESTS_KEY: timeseries} for timeseries in queued_timeseries
593+
]
594+
return self._aggregate_ongoing_requests(queued_metrics)
595+
else:
596+
# Simple mode: sum pre-aggregated values
597+
return sum(
598+
handle_metric.aggregated_queued_requests
599+
for handle_metric in self._handle_requests.values()
600+
)
601+
602+
def _get_aggregated_custom_metrics(self) -> Dict[str, Dict[ReplicaID, float]]:
603+
"""Aggregate custom metrics from replica metric reports.
604+
605+
Custom metrics are all metrics except RUNNING_REQUESTS_KEY. These are metrics
606+
emitted by the deployment using the `record_autoscaling_stats` method.
607+
608+
This method aggregates raw timeseries data from replicas on the controller,
609+
similar to how ongoing requests are aggregated.
610+
611+
Returns:
612+
Dict mapping metric name to dict of replica ID to aggregated metric value.
613+
"""
614+
aggregated_metrics = defaultdict(dict)
615+
616+
for replica_id in self._running_replicas:
617+
replica_metric_report = self._replica_metrics.get(replica_id)
618+
if replica_metric_report is None:
619+
continue
620+
621+
for metric_name, timeseries in replica_metric_report.metrics.items():
622+
if metric_name != RUNNING_REQUESTS_KEY:
623+
# Aggregate the timeseries for this custom metric
624+
# Use the actual metric name as the key
625+
metrics = [{metric_name: timeseries}]
626+
aggregated_value = self._aggregate_timeseries_metric(
627+
metrics, metric_name
628+
)
629+
aggregated_metrics[metric_name][replica_id] = aggregated_value
630+
631+
return dict(aggregated_metrics)
632+
633+
def _get_raw_custom_metrics(
634+
self,
635+
) -> Dict[str, Dict[ReplicaID, List[TimeStampedValue]]]:
636+
"""Extract raw custom metric values from replica metric reports.
637+
638+
Custom metrics are all metrics except RUNNING_REQUESTS_KEY. These are metrics
639+
emitted by the deployment using the `record_autoscaling_stats` method.
640+
641+
Returns:
642+
Dict mapping metric name to dict of replica ID to list of raw metric values.
643+
"""
644+
raw_metrics = defaultdict(dict)
645+
646+
for replica_id in self._running_replicas:
647+
replica_metric_report = self._replica_metrics.get(replica_id)
648+
if replica_metric_report is None:
649+
continue
650+
651+
for metric_name, timeseries in replica_metric_report.metrics.items():
652+
if metric_name != RUNNING_REQUESTS_KEY:
653+
# Extract values from TimeStampedValue list
654+
raw_metrics[metric_name][replica_id] = timeseries
655+
656+
return dict(raw_metrics)
657+
557658

558659
class AutoscalingStateManager:
559660
"""Manages all things autoscaling related.
@@ -602,12 +703,10 @@ def get_metrics(self) -> Dict[DeploymentID, float]:
602703
}
603704

604705
def get_all_metrics(
605-
self, agg_func="mean"
606-
) -> Dict[DeploymentID, Dict[ReplicaID, List[Any]]]:
706+
self,
707+
) -> Dict[DeploymentID, Dict[ReplicaID, List[TimeStampedValue]]]:
607708
return {
608-
deployment_id: self._autoscaling_states[deployment_id].get_replica_metrics(
609-
agg_func
610-
)
709+
deployment_id: self._autoscaling_states[deployment_id].get_replica_metrics()
611710
for deployment_id in self._autoscaling_states
612711
}
613712

python/ray/serve/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ray._common.utils import import_attr
2020

2121
# Import types needed for AutoscalingContext
22-
from ray.serve._private.common import DeploymentID, ReplicaID
22+
from ray.serve._private.common import DeploymentID, ReplicaID, TimeStampedValue
2323
from ray.serve._private.constants import (
2424
DEFAULT_AUTOSCALING_POLICY_NAME,
2525
DEFAULT_GRPC_PORT,
@@ -63,18 +63,18 @@ class AutoscalingContext:
6363

6464
# Built-in metrics
6565
total_num_requests: float #: Total number of requests across all replicas.
66-
queued_requests: Optional[float] #: Number of requests currently queued.
67-
requests_per_replica: Dict[
68-
ReplicaID, float
69-
] #: Mapping of replica ID to number of requests.
66+
total_queued_requests: Optional[float] #: Number of requests currently queued.
67+
total_running_requests: Optional[
68+
float
69+
] #: Total number of requests currently running.
7070

7171
# Custom metrics
7272
aggregated_metrics: Dict[
7373
str, Dict[ReplicaID, float]
7474
] #: Time-weighted averages of custom metrics per replica.
7575
raw_metrics: Dict[
76-
str, Dict[ReplicaID, List[float]]
77-
] #: Raw custom metric values per replica.
76+
str, Dict[ReplicaID, List[TimeStampedValue]]
77+
] #: Raw custom metric timeseries per replica.
7878

7979
# Capacity and bounds
8080
capacity_adjusted_min_replicas: int #: Minimum replicas adjusted for cluster capacity.

python/ray/serve/tests/BUILD.bazel

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,27 @@ py_test_module_list(
3131
)
3232

3333
# Custom metrics tests.
34-
py_test_module_list(
34+
py_test_module_list_with_env_variants(
3535
size = "small",
36-
env = {
37-
"RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE": "0",
38-
"RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S": "2",
39-
"RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S": "3",
36+
env_variants = {
37+
"metr_agg_at_controller": {
38+
"env": {
39+
"RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER": "1",
40+
"RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE": "0",
41+
"RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S": "0.1",
42+
"RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S": "3",
43+
},
44+
"name_suffix": "_metr_agg_at_controller",
45+
},
46+
"metr_agg_at_replicas": {
47+
"env": {
48+
"RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER": "0",
49+
"RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE": "0",
50+
"RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S": "0.1",
51+
"RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S": "3",
52+
},
53+
"name_suffix": "_metr_agg_at_replicas",
54+
},
4055
},
4156
files = [
4257
"test_custom_autoscaling_metrics.py",

python/ray/serve/tests/test_autoscaling_policy.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,46 @@ async def __call__(self):
15921592
wait_for_condition(lambda: ray.get(signal.cur_num_waiters.remote()) == 0)
15931593

15941594

1595+
def custom_autoscaling_policy_with_queued_requests(ctx: AutoscalingContext):
1596+
total_queued_requests = ctx.total_queued_requests
1597+
if total_queued_requests > 20:
1598+
return 3, {}
1599+
else:
1600+
return 1, {}
1601+
1602+
1603+
def test_policy_using_custom_metrics_with_queued_requests(serve_instance):
1604+
signal = SignalActor.remote()
1605+
1606+
@serve.deployment(
1607+
autoscaling_config={
1608+
"min_replicas": 1,
1609+
"max_replicas": 5,
1610+
"upscale_delay_s": 2,
1611+
"downscale_delay_s": 1,
1612+
"metrics_interval_s": 0.1,
1613+
"look_back_period_s": 1,
1614+
"target_ongoing_requests": 10,
1615+
"policy": AutoscalingPolicy(
1616+
policy_function=custom_autoscaling_policy_with_queued_requests
1617+
),
1618+
},
1619+
max_ongoing_requests=10,
1620+
)
1621+
class CustomMetricsDeployment:
1622+
async def __call__(self) -> str:
1623+
await signal.wait.remote()
1624+
return "Hello, world"
1625+
1626+
handle = serve.run(CustomMetricsDeployment.bind())
1627+
[handle.remote() for _ in range(40)]
1628+
wait_for_condition(lambda: ray.get(signal.cur_num_waiters.remote()) == 10)
1629+
wait_for_condition(check_num_replicas_eq, name="CustomMetricsDeployment", target=3)
1630+
ray.get(signal.send.remote())
1631+
wait_for_condition(lambda: ray.get(signal.cur_num_waiters.remote()) == 0)
1632+
wait_for_condition(check_num_replicas_eq, name="CustomMetricsDeployment", target=1)
1633+
1634+
15951635
if __name__ == "__main__":
15961636
import sys
15971637

0 commit comments

Comments
 (0)