Skip to content

Commit 2f14b1e

Browse files
author
Shallow Copy Bot
committed
[serve] Transform replica level metrics to AutoScalingContext constructor args
Original PR #57202 by arcyleung Original: ray-project/ray#57202
1 parent cc97fad commit 2f14b1e

File tree

7 files changed

+367
-110
lines changed

7 files changed

+367
-110
lines changed

python/ray/serve/_private/autoscaling_state.py

Lines changed: 126 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
55

66
from ray.serve._private.common import (
7-
ONGOING_REQUESTS_KEY,
87
RUNNING_REQUESTS_KEY,
98
ApplicationName,
109
DeploymentID,
1110
HandleMetricReport,
1211
ReplicaID,
1312
ReplicaMetricReport,
1413
TargetCapacityDirection,
15-
TimeStampedValue,
14+
TimeSeries,
1615
)
1716
from ray.serve._private.constants import (
1817
RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER,
@@ -22,7 +21,7 @@
2221
from ray.serve._private.deployment_info import DeploymentInfo
2322
from ray.serve._private.metrics_utils import (
2423
aggregate_timeseries,
25-
merge_timeseries_dicts,
24+
merge_instantaneous_total,
2625
)
2726
from ray.serve._private.utils import get_capacity_adjusted_num_replicas
2827
from ray.serve.config import AutoscalingContext, AutoscalingPolicy
@@ -226,36 +225,42 @@ def get_decision_num_replicas(
226225
return self.apply_bounds(decision_num_replicas)
227226

228227
def get_autoscaling_context(self, curr_target_num_replicas):
228+
total_num_requests = self.get_total_num_requests()
229+
total_queued_requests = self._get_queued_requests()
230+
# NOTE: for non additive aggregation functions, total_running_requests is not
231+
# accurate, consider this is a approximation.
232+
total_running_requests = total_num_requests - total_queued_requests
233+
229234
autoscaling_context: AutoscalingContext = AutoscalingContext(
230235
deployment_id=self._deployment_id,
231236
deployment_name=self._deployment_id.name,
232237
app_name=self._deployment_id.app_name,
233238
current_num_replicas=len(self._running_replicas),
234239
target_num_replicas=curr_target_num_replicas,
235240
running_replicas=self._running_replicas,
236-
total_num_requests=self.get_total_num_requests(),
241+
total_num_requests=total_num_requests,
237242
capacity_adjusted_min_replicas=self.get_num_replicas_lower_bound(),
238243
capacity_adjusted_max_replicas=self.get_num_replicas_upper_bound(),
239244
policy_state=(
240245
self._policy_state.copy() if self._policy_state is not None else {}
241246
),
242247
current_time=time.time(),
243248
config=self._config,
244-
queued_requests=None,
245-
requests_per_replica=None,
246-
aggregated_metrics=None,
247-
raw_metrics=None,
249+
total_queued_requests=total_queued_requests,
250+
total_running_requests=total_running_requests,
251+
aggregated_metrics=self._get_aggregated_custom_metrics(),
252+
raw_metrics=self._get_raw_custom_metrics(),
248253
last_scale_up_time=None,
249254
last_scale_down_time=None,
250255
)
251256

252257
return autoscaling_context
253258

254-
def _collect_replica_running_requests(self) -> List[List[TimeStampedValue]]:
259+
def _collect_replica_running_requests(self) -> List[TimeSeries]:
255260
"""Collect running requests timeseries from replicas for aggregation.
256261
257262
Returns:
258-
List of timeseries data (List[TimeStampedValue]).
263+
List of timeseries data.
259264
"""
260265
timeseries_list = []
261266

@@ -271,22 +276,22 @@ def _collect_replica_running_requests(self) -> List[List[TimeStampedValue]]:
271276

272277
return timeseries_list
273278

274-
def _collect_handle_queued_requests(self) -> List[List[TimeStampedValue]]:
279+
def _collect_handle_queued_requests(self) -> List[TimeSeries]:
275280
"""Collect queued requests timeseries from all handles.
276281
277282
Returns:
278-
List of timeseries data (List[TimeStampedValue]).
283+
List of timeseries data.
279284
"""
280285
timeseries_list = []
281286
for handle_metric_report in self._handle_requests.values():
282287
timeseries_list.append(handle_metric_report.queued_requests)
283288
return timeseries_list
284289

285-
def _collect_handle_running_requests(self) -> List[List[TimeStampedValue]]:
290+
def _collect_handle_running_requests(self) -> List[TimeSeries]:
286291
"""Collect running requests timeseries from handles when not collected on replicas.
287292
288293
Returns:
289-
List of timeseries data (List[TimeStampedValue]).
294+
List of timeseries data.
290295
291296
Example:
292297
If there are 2 handles, each managing 2 replicas, and the running requests metrics are:
@@ -316,49 +321,44 @@ def _collect_handle_running_requests(self) -> List[List[TimeStampedValue]]:
316321

317322
return timeseries_list
318323

319-
def _aggregate_ongoing_requests(
320-
self, metrics_timeseries_dicts: List[Dict[str, List[TimeStampedValue]]]
324+
def _merge_and_aggregate_timeseries(
325+
self,
326+
timeseries_list: List[TimeSeries],
321327
) -> float:
322-
"""Aggregate and average ongoing requests from timeseries data using instantaneous merge.
328+
"""Aggregate and average a metric from timeseries data using instantaneous merge.
323329
324330
Args:
325-
metrics_timeseries_dicts: A list of dictionaries, each containing a key-value pair:
326-
- The key is the name of the metric (ONGOING_REQUESTS_KEY)
327-
- The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
328-
this list is sorted by timestamp ascending
331+
timeseries_list: A list of TimeSeries (TimeSeries), where each
332+
TimeSeries represents measurements from a single source (replica, handle, etc.).
333+
Each list is sorted by timestamp ascending.
329334
330335
Returns:
331-
The time-weighted average of the ongoing requests
336+
The time-weighted average of the metric
332337
333338
Example:
334-
If the metrics_timeseries_dicts is:
339+
If the timeseries_list is:
335340
[
336-
{
337-
"ongoing_requests": [
338-
TimeStampedValue(timestamp=0.1, value=5.0),
339-
TimeStampedValue(timestamp=0.2, value=7.0),
340-
]
341-
},
342-
{
343-
"ongoing_requests": [
344-
TimeStampedValue(timestamp=0.2, value=3.0),
345-
TimeStampedValue(timestamp=0.3, value=1.0),
346-
]
347-
}
341+
[
342+
TimeStampedValue(timestamp=0.1, value=5.0),
343+
TimeStampedValue(timestamp=0.2, value=7.0),
344+
],
345+
[
346+
TimeStampedValue(timestamp=0.2, value=3.0),
347+
TimeStampedValue(timestamp=0.3, value=1.0),
348+
]
348349
]
349350
Then the returned value will be:
350351
(5.0*0.1 + 7.0*0.2 + 3.0*0.2 + 1.0*0.3) / (0.1 + 0.2 + 0.2 + 0.3) = 4.5 / 0.8 = 5.625
351352
"""
352353

353-
if not metrics_timeseries_dicts:
354+
if not timeseries_list:
354355
return 0.0
355356

356357
# Use instantaneous merge approach - no arbitrary windowing needed
357-
aggregated_metrics = merge_timeseries_dicts(*metrics_timeseries_dicts)
358-
ongoing_requests_timeseries = aggregated_metrics.get(ONGOING_REQUESTS_KEY, [])
359-
if ongoing_requests_timeseries:
358+
merged_timeseries = merge_instantaneous_total(timeseries_list)
359+
if merged_timeseries:
360360
# assume that the last recorded metric is valid for last_window_s seconds
361-
last_metric_time = ongoing_requests_timeseries[-1].timestamp
361+
last_metric_time = merged_timeseries[-1].timestamp
362362
# we dont want to make any assumption about how long the last metric will be valid
363363
# only conclude that the last metric is valid for last_window_s seconds that is the
364364
# difference between the current time and the last metric recorded time
@@ -367,9 +367,9 @@ def _aggregate_ongoing_requests(
367367
# between replicas and controller. Also add a small epsilon to avoid division by zero
368368
if last_window_s <= 0:
369369
last_window_s = 1e-3
370-
# Calculate the aggregated running requests
370+
# Calculate the aggregated metric value
371371
value = aggregate_timeseries(
372-
ongoing_requests_timeseries,
372+
merged_timeseries,
373373
aggregation_function=self._config.aggregation_function,
374374
last_window_s=last_window_s,
375375
)
@@ -439,11 +439,11 @@ def _calculate_total_requests_aggregate_mode(self) -> float:
439439
Total number of requests (average running + queued) calculated from
440440
timeseries data aggregation.
441441
"""
442-
# Collect replica-based running requests (returns List[List[TimeStampedValue]])
442+
# Collect replica-based running requests (returns List[TimeSeries])
443443
replica_timeseries = self._collect_replica_running_requests()
444444
metrics_collected_on_replicas = len(replica_timeseries) > 0
445445

446-
# Collect queued requests from handles (returns List[List[TimeStampedValue]])
446+
# Collect queued requests from handles (returns List[TimeSeries])
447447
queued_timeseries = self._collect_handle_queued_requests()
448448

449449
if not metrics_collected_on_replicas:
@@ -452,23 +452,23 @@ def _calculate_total_requests_aggregate_mode(self) -> float:
452452
else:
453453
handle_timeseries = []
454454

455-
# Create minimal dictionary objects only when needed
456-
ongoing_requests_metrics = []
455+
# Collect all timeseries for ongoing requests
456+
ongoing_requests_timeseries = []
457457

458-
# Add replica timeseries with minimal dict wrapping
459-
for timeseries in replica_timeseries:
460-
ongoing_requests_metrics.append({ONGOING_REQUESTS_KEY: timeseries})
458+
# Add replica timeseries
459+
ongoing_requests_timeseries.extend(replica_timeseries)
461460

462461
# Add handle timeseries if replica metrics weren't collected
463462
if not metrics_collected_on_replicas:
464-
for timeseries in handle_timeseries:
465-
ongoing_requests_metrics.append({ONGOING_REQUESTS_KEY: timeseries})
463+
ongoing_requests_timeseries.extend(handle_timeseries)
464+
465+
# Add queued timeseries
466+
ongoing_requests_timeseries.extend(queued_timeseries)
466467

467-
# Add queued timeseries with minimal dict wrapping
468-
for timeseries in queued_timeseries:
469-
ongoing_requests_metrics.append({ONGOING_REQUESTS_KEY: timeseries})
470468
# Aggregate and add running requests to total
471-
ongoing_requests = self._aggregate_ongoing_requests(ongoing_requests_metrics)
469+
ongoing_requests = self._merge_and_aggregate_timeseries(
470+
ongoing_requests_timeseries
471+
)
472472

473473
return ongoing_requests
474474

@@ -557,11 +557,8 @@ def get_total_num_requests(self) -> float:
557557
else:
558558
return self._calculate_total_requests_simple_mode()
559559

560-
def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
560+
def get_replica_metrics(self) -> Dict[ReplicaID, List[TimeSeries]]:
561561
"""Get the raw replica metrics dict."""
562-
# arcyleung TODO: pass agg_func from autoscaling policy https://github.com/ray-project/ray/pull/51905
563-
# Dummy implementation of mean agg_func across all values of the same metrics key
564-
565562
metric_values = defaultdict(list)
566563
for id in self._running_replicas:
567564
if id in self._replica_metrics and self._replica_metrics[id].metrics:
@@ -570,6 +567,71 @@ def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
570567

571568
return metric_values
572569

570+
def _get_queued_requests(self) -> float:
571+
"""Calculate the total number of queued requests across all handles.
572+
573+
Returns:
574+
Sum of queued requests at all handles. Uses aggregated values in simple mode,
575+
or aggregates timeseries data in aggregate mode.
576+
"""
577+
if RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER:
578+
# Aggregate mode: collect and aggregate timeseries
579+
queued_timeseries = self._collect_handle_queued_requests()
580+
if not queued_timeseries:
581+
return 0.0
582+
583+
return self._merge_and_aggregate_timeseries(queued_timeseries)
584+
else:
585+
# Simple mode: sum pre-aggregated values
586+
return sum(
587+
handle_metric.aggregated_queued_requests
588+
for handle_metric in self._handle_requests.values()
589+
)
590+
591+
def _get_aggregated_custom_metrics(self) -> Dict[str, Dict[ReplicaID, float]]:
592+
"""Aggregate custom metrics from replica metric reports.
593+
594+
This method aggregates raw timeseries data from replicas on the controller,
595+
similar to how ongoing requests are aggregated.
596+
597+
Returns:
598+
Dict mapping metric name to dict of replica ID to aggregated metric value.
599+
"""
600+
aggregated_metrics = defaultdict(dict)
601+
602+
for replica_id in self._running_replicas:
603+
replica_metric_report = self._replica_metrics.get(replica_id)
604+
if replica_metric_report is None:
605+
continue
606+
607+
for metric_name, timeseries in replica_metric_report.metrics.items():
608+
# Aggregate the timeseries for this custom metric
609+
aggregated_value = self._merge_and_aggregate_timeseries([timeseries])
610+
aggregated_metrics[metric_name][replica_id] = aggregated_value
611+
612+
return dict(aggregated_metrics)
613+
614+
def _get_raw_custom_metrics(
615+
self,
616+
) -> Dict[str, Dict[ReplicaID, TimeSeries]]:
617+
"""Extract raw custom metric values from replica metric reports.
618+
619+
Returns:
620+
Dict mapping metric name to dict of replica ID to raw metric timeseries.
621+
"""
622+
raw_metrics = defaultdict(dict)
623+
624+
for replica_id in self._running_replicas:
625+
replica_metric_report = self._replica_metrics.get(replica_id)
626+
if replica_metric_report is None:
627+
continue
628+
629+
for metric_name, timeseries in replica_metric_report.metrics.items():
630+
# Extract values from TimeStampedValue list
631+
raw_metrics[metric_name][replica_id] = timeseries
632+
633+
return dict(raw_metrics)
634+
573635

574636
class ApplicationAutoscalingState:
575637
"""Manages autoscaling for a single application."""
@@ -732,12 +794,8 @@ def get_total_num_requests_for_deployment(
732794
deployment_id
733795
].get_total_num_requests()
734796

735-
def get_replica_metrics_by_deployment_id(
736-
self, deployment_id: DeploymentID, agg_func="mean"
737-
):
738-
return self._deployment_autoscaling_states[deployment_id].get_replica_metrics(
739-
agg_func
740-
)
797+
def get_replica_metrics_by_deployment_id(self, deployment_id: DeploymentID):
798+
return self._deployment_autoscaling_states[deployment_id].get_replica_metrics()
741799

742800
def is_within_bounds(
743801
self, deployment_id: DeploymentID, num_replicas_running_at_target_version: int
@@ -891,12 +949,12 @@ def on_replica_stopped(self, replica_id: ReplicaID):
891949
)
892950

893951
def get_metrics_for_deployment(
894-
self, deployment_id: DeploymentID, agg_func="mean"
895-
) -> Dict[ReplicaID, List[Any]]:
952+
self, deployment_id: DeploymentID
953+
) -> Dict[ReplicaID, List[TimeSeries]]:
896954
if deployment_id.app_name in self._app_autoscaling_states:
897955
return self._app_autoscaling_states[
898956
deployment_id.app_name
899-
].get_replica_metrics_by_deployment_id(deployment_id, agg_func)
957+
].get_replica_metrics_by_deployment_id(deployment_id)
900958
else:
901959
logger.warning(
902960
f"Cannot get metrics for deployment "

python/ray/serve/_private/common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,7 @@ class CreatePlacementGroupRequest:
757757

758758
RUNNING_REQUESTS_KEY = "running_requests"
759759
ONGOING_REQUESTS_KEY = "ongoing_requests"
760+
QUEUED_REQUESTS_KEY = "queued_requests"
760761

761762

762763
@dataclass(order=True)
@@ -765,6 +766,10 @@ class TimeStampedValue:
765766
value: float = field(compare=False)
766767

767768

769+
# Type alias for time series data
770+
TimeSeries = List[TimeStampedValue]
771+
772+
768773
@dataclass
769774
class HandleMetricReport:
770775
"""Report from a deployment handle on queued and ongoing requests.
@@ -795,9 +800,9 @@ class HandleMetricReport:
795800
actor_id: str
796801
handle_source: DeploymentHandleSource
797802
aggregated_queued_requests: float
798-
queued_requests: List[TimeStampedValue]
803+
queued_requests: TimeSeries
799804
aggregated_metrics: Dict[str, Dict[ReplicaID, float]]
800-
metrics: Dict[str, Dict[ReplicaID, List[TimeStampedValue]]]
805+
metrics: Dict[str, Dict[ReplicaID, TimeSeries]]
801806
timestamp: float
802807

803808
@property
@@ -838,5 +843,5 @@ class ReplicaMetricReport:
838843

839844
replica_id: ReplicaID
840845
aggregated_metrics: Dict[str, float]
841-
metrics: Dict[str, List[TimeStampedValue]]
846+
metrics: Dict[str, TimeSeries]
842847
timestamp: float

0 commit comments

Comments
 (0)