44from typing import Any , Callable , Dict , List , Optional , Set , Tuple
55
66from ray .serve ._private .common import (
7- ONGOING_REQUESTS_KEY ,
87 RUNNING_REQUESTS_KEY ,
98 ApplicationName ,
109 DeploymentID ,
1110 HandleMetricReport ,
1211 ReplicaID ,
1312 ReplicaMetricReport ,
1413 TargetCapacityDirection ,
15- TimeStampedValue ,
14+ TimeSeries ,
1615)
1716from ray .serve ._private .constants import (
1817 RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER ,
2221from ray .serve ._private .deployment_info import DeploymentInfo
2322from ray .serve ._private .metrics_utils import (
2423 aggregate_timeseries ,
25- merge_timeseries_dicts ,
24+ merge_instantaneous_total ,
2625)
2726from ray .serve ._private .utils import get_capacity_adjusted_num_replicas
2827from ray .serve .config import AutoscalingContext , AutoscalingPolicy
@@ -226,36 +225,42 @@ def get_decision_num_replicas(
226225 return self .apply_bounds (decision_num_replicas )
227226
228227 def get_autoscaling_context (self , curr_target_num_replicas ):
228+ total_num_requests = self .get_total_num_requests ()
229+ total_queued_requests = self ._get_queued_requests ()
230+ # NOTE: for non-additive aggregation functions, total_running_requests is not
231+ # exact; treat it as an approximation.
232+ total_running_requests = total_num_requests - total_queued_requests
233+
229234 autoscaling_context : AutoscalingContext = AutoscalingContext (
230235 deployment_id = self ._deployment_id ,
231236 deployment_name = self ._deployment_id .name ,
232237 app_name = self ._deployment_id .app_name ,
233238 current_num_replicas = len (self ._running_replicas ),
234239 target_num_replicas = curr_target_num_replicas ,
235240 running_replicas = self ._running_replicas ,
236- total_num_requests = self . get_total_num_requests () ,
241+ total_num_requests = total_num_requests ,
237242 capacity_adjusted_min_replicas = self .get_num_replicas_lower_bound (),
238243 capacity_adjusted_max_replicas = self .get_num_replicas_upper_bound (),
239244 policy_state = (
240245 self ._policy_state .copy () if self ._policy_state is not None else {}
241246 ),
242247 current_time = time .time (),
243248 config = self ._config ,
244- queued_requests = None ,
245- requests_per_replica = None ,
246- aggregated_metrics = None ,
247- raw_metrics = None ,
249+ total_queued_requests = total_queued_requests ,
250+ total_running_requests = total_running_requests ,
251+ aggregated_metrics = self . _get_aggregated_custom_metrics () ,
252+ raw_metrics = self . _get_raw_custom_metrics () ,
248253 last_scale_up_time = None ,
249254 last_scale_down_time = None ,
250255 )
251256
252257 return autoscaling_context
253258
254- def _collect_replica_running_requests (self ) -> List [List [ TimeStampedValue ] ]:
259+ def _collect_replica_running_requests (self ) -> List [TimeSeries ]:
255260 """Collect running requests timeseries from replicas for aggregation.
256261
257262 Returns:
258- List of timeseries data (List[TimeStampedValue]) .
263+ List of timeseries data.
259264 """
260265 timeseries_list = []
261266
@@ -271,22 +276,22 @@ def _collect_replica_running_requests(self) -> List[List[TimeStampedValue]]:
271276
272277 return timeseries_list
273278
274- def _collect_handle_queued_requests (self ) -> List [List [ TimeStampedValue ] ]:
279+ def _collect_handle_queued_requests (self ) -> List [TimeSeries ]:
275280 """Collect queued requests timeseries from all handles.
276281
277282 Returns:
278- List of timeseries data (List[TimeStampedValue]) .
283+ List of timeseries data.
279284 """
280285 timeseries_list = []
281286 for handle_metric_report in self ._handle_requests .values ():
282287 timeseries_list .append (handle_metric_report .queued_requests )
283288 return timeseries_list
284289
285- def _collect_handle_running_requests (self ) -> List [List [ TimeStampedValue ] ]:
290+ def _collect_handle_running_requests (self ) -> List [TimeSeries ]:
286291 """Collect running requests timeseries from handles when not collected on replicas.
287292
288293 Returns:
289- List of timeseries data (List[TimeStampedValue]) .
294+ List of timeseries data.
290295
291296 Example:
292297 If there are 2 handles, each managing 2 replicas, and the running requests metrics are:
@@ -316,49 +321,44 @@ def _collect_handle_running_requests(self) -> List[List[TimeStampedValue]]:
316321
317322 return timeseries_list
318323
319- def _aggregate_ongoing_requests (
320- self , metrics_timeseries_dicts : List [Dict [str , List [TimeStampedValue ]]]
324+ def _merge_and_aggregate_timeseries (
325+ self ,
326+ timeseries_list : List [TimeSeries ],
321327 ) -> float :
322- """Aggregate and average ongoing requests from timeseries data using instantaneous merge.
328+ """Aggregate and average a metric from timeseries data using instantaneous merge.
323329
324330 Args:
325- metrics_timeseries_dicts: A list of dictionaries, each containing a key-value pair:
326- - The key is the name of the metric (ONGOING_REQUESTS_KEY)
327- - The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
328- this list is sorted by timestamp ascending
331+ timeseries_list: A list of TimeSeries, where each
332+ TimeSeries represents measurements from a single source (replica, handle, etc.).
333+ Each list is sorted by timestamp ascending.
329334
330335 Returns:
331- The time-weighted average of the ongoing requests
336+ The time-weighted average of the metric
332337
333338 Example:
334- If the metrics_timeseries_dicts is:
339+ If the timeseries_list is:
335340 [
336- {
337- "ongoing_requests": [
338- TimeStampedValue(timestamp=0.1, value=5.0),
339- TimeStampedValue(timestamp=0.2, value=7.0),
340- ]
341- },
342- {
343- "ongoing_requests": [
344- TimeStampedValue(timestamp=0.2, value=3.0),
345- TimeStampedValue(timestamp=0.3, value=1.0),
346- ]
347- }
341+ [
342+ TimeStampedValue(timestamp=0.1, value=5.0),
343+ TimeStampedValue(timestamp=0.2, value=7.0),
344+ ],
345+ [
346+ TimeStampedValue(timestamp=0.2, value=3.0),
347+ TimeStampedValue(timestamp=0.3, value=1.0),
348+ ]
348349 ]
349350 Then the returned value will be:
350351 (5.0*0.1 + 7.0*0.2 + 3.0*0.2 + 1.0*0.3) / (0.1 + 0.2 + 0.2 + 0.3) = 2.8 / 0.8 = 3.5
351352 """
352353
353- if not metrics_timeseries_dicts :
354+ if not timeseries_list :
354355 return 0.0
355356
356357 # Use instantaneous merge approach - no arbitrary windowing needed
357- aggregated_metrics = merge_timeseries_dicts (* metrics_timeseries_dicts )
358- ongoing_requests_timeseries = aggregated_metrics .get (ONGOING_REQUESTS_KEY , [])
359- if ongoing_requests_timeseries :
358+ merged_timeseries = merge_instantaneous_total (timeseries_list )
359+ if merged_timeseries :
360360 # assume that the last recorded metric is valid for last_window_s seconds
361- last_metric_time = ongoing_requests_timeseries [- 1 ].timestamp
361+ last_metric_time = merged_timeseries [- 1 ].timestamp
362362 # we don't want to make any assumption about how long the last metric will be valid;
363363 # only conclude that the last metric is valid for last_window_s seconds, i.e. the
364364 # difference between the current time and the time the last metric was recorded
@@ -367,9 +367,9 @@ def _aggregate_ongoing_requests(
367367 # between replicas and controller. Also add a small epsilon to avoid division by zero
368368 if last_window_s <= 0 :
369369 last_window_s = 1e-3
370- # Calculate the aggregated running requests
370+ # Calculate the aggregated metric value
371371 value = aggregate_timeseries (
372- ongoing_requests_timeseries ,
372+ merged_timeseries ,
373373 aggregation_function = self ._config .aggregation_function ,
374374 last_window_s = last_window_s ,
375375 )
@@ -439,11 +439,11 @@ def _calculate_total_requests_aggregate_mode(self) -> float:
439439 Total number of requests (average running + queued) calculated from
440440 timeseries data aggregation.
441441 """
442- # Collect replica-based running requests (returns List[List[TimeStampedValue] ])
442+ # Collect replica-based running requests (returns List[TimeSeries ])
443443 replica_timeseries = self ._collect_replica_running_requests ()
444444 metrics_collected_on_replicas = len (replica_timeseries ) > 0
445445
446- # Collect queued requests from handles (returns List[List[TimeStampedValue] ])
446+ # Collect queued requests from handles (returns List[TimeSeries ])
447447 queued_timeseries = self ._collect_handle_queued_requests ()
448448
449449 if not metrics_collected_on_replicas :
@@ -452,23 +452,23 @@ def _calculate_total_requests_aggregate_mode(self) -> float:
452452 else :
453453 handle_timeseries = []
454454
455- # Create minimal dictionary objects only when needed
456- ongoing_requests_metrics = []
455+ # Collect all timeseries for ongoing requests
456+ ongoing_requests_timeseries = []
457457
458- # Add replica timeseries with minimal dict wrapping
459- for timeseries in replica_timeseries :
460- ongoing_requests_metrics .append ({ONGOING_REQUESTS_KEY : timeseries })
458+ # Add replica timeseries
459+ ongoing_requests_timeseries .extend (replica_timeseries )
461460
462461 # Add handle timeseries if replica metrics weren't collected
463462 if not metrics_collected_on_replicas :
464- for timeseries in handle_timeseries :
465- ongoing_requests_metrics .append ({ONGOING_REQUESTS_KEY : timeseries })
463+ ongoing_requests_timeseries .extend (handle_timeseries )
464+
465+ # Add queued timeseries
466+ ongoing_requests_timeseries .extend (queued_timeseries )
466467
467- # Add queued timeseries with minimal dict wrapping
468- for timeseries in queued_timeseries :
469- ongoing_requests_metrics .append ({ONGOING_REQUESTS_KEY : timeseries })
470468 # Aggregate and add running requests to total
471- ongoing_requests = self ._aggregate_ongoing_requests (ongoing_requests_metrics )
469+ ongoing_requests = self ._merge_and_aggregate_timeseries (
470+ ongoing_requests_timeseries
471+ )
472472
473473 return ongoing_requests
474474
@@ -557,11 +557,8 @@ def get_total_num_requests(self) -> float:
557557 else :
558558 return self ._calculate_total_requests_simple_mode ()
559559
560- def get_replica_metrics (self , agg_func : str ) -> Dict [ReplicaID , List [Any ]]:
560+ def get_replica_metrics (self ) -> Dict [ReplicaID , List [TimeSeries ]]:
561561 """Get the raw replica metrics dict."""
562- # arcyleung TODO: pass agg_func from autoscaling policy https://github.com/ray-project/ray/pull/51905
563- # Dummy implementation of mean agg_func across all values of the same metrics key
564-
565562 metric_values = defaultdict (list )
566563 for id in self ._running_replicas :
567564 if id in self ._replica_metrics and self ._replica_metrics [id ].metrics :
@@ -570,6 +567,71 @@ def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
570567
571568 return metric_values
572569
570+ def _get_queued_requests (self ) -> float :
571+ """Calculate the total number of queued requests across all handles.
572+
573+ Returns:
574+ Sum of queued requests at all handles. Uses aggregated values in simple mode,
575+ or aggregates timeseries data in aggregate mode.
576+ """
577+ if RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER :
578+ # Aggregate mode: collect and aggregate timeseries
579+ queued_timeseries = self ._collect_handle_queued_requests ()
580+ if not queued_timeseries :
581+ return 0.0
582+
583+ return self ._merge_and_aggregate_timeseries (queued_timeseries )
584+ else :
585+ # Simple mode: sum pre-aggregated values
586+ return sum (
587+ handle_metric .aggregated_queued_requests
588+ for handle_metric in self ._handle_requests .values ()
589+ )
590+
591+ def _get_aggregated_custom_metrics (self ) -> Dict [str , Dict [ReplicaID , float ]]:
592+ """Aggregate custom metrics from replica metric reports.
593+
594+ This method aggregates raw timeseries data from replicas on the controller,
595+ similar to how ongoing requests are aggregated.
596+
597+ Returns:
598+ Dict mapping metric name to dict of replica ID to aggregated metric value.
599+ """
600+ aggregated_metrics = defaultdict (dict )
601+
602+ for replica_id in self ._running_replicas :
603+ replica_metric_report = self ._replica_metrics .get (replica_id )
604+ if replica_metric_report is None :
605+ continue
606+
607+ for metric_name , timeseries in replica_metric_report .metrics .items ():
608+ # Aggregate the timeseries for this custom metric
609+ aggregated_value = self ._merge_and_aggregate_timeseries ([timeseries ])
610+ aggregated_metrics [metric_name ][replica_id ] = aggregated_value
611+
612+ return dict (aggregated_metrics )
613+
614+ def _get_raw_custom_metrics (
615+ self ,
616+ ) -> Dict [str , Dict [ReplicaID , TimeSeries ]]:
617+ """Extract raw custom metric values from replica metric reports.
618+
619+ Returns:
620+ Dict mapping metric name to dict of replica ID to raw metric timeseries.
621+ """
622+ raw_metrics = defaultdict (dict )
623+
624+ for replica_id in self ._running_replicas :
625+ replica_metric_report = self ._replica_metrics .get (replica_id )
626+ if replica_metric_report is None :
627+ continue
628+
629+ for metric_name , timeseries in replica_metric_report .metrics .items ():
630+ # Store the raw timeseries for this metric unchanged
631+ raw_metrics [metric_name ][replica_id ] = timeseries
632+
633+ return dict (raw_metrics )
634+
573635
574636class ApplicationAutoscalingState :
575637 """Manages autoscaling for a single application."""
@@ -732,12 +794,8 @@ def get_total_num_requests_for_deployment(
732794 deployment_id
733795 ].get_total_num_requests ()
734796
735- def get_replica_metrics_by_deployment_id (
736- self , deployment_id : DeploymentID , agg_func = "mean"
737- ):
738- return self ._deployment_autoscaling_states [deployment_id ].get_replica_metrics (
739- agg_func
740- )
797+ def get_replica_metrics_by_deployment_id (self , deployment_id : DeploymentID ):
798+ return self ._deployment_autoscaling_states [deployment_id ].get_replica_metrics ()
741799
742800 def is_within_bounds (
743801 self , deployment_id : DeploymentID , num_replicas_running_at_target_version : int
@@ -891,12 +949,12 @@ def on_replica_stopped(self, replica_id: ReplicaID):
891949 )
892950
893951 def get_metrics_for_deployment (
894- self , deployment_id : DeploymentID , agg_func = "mean"
895- ) -> Dict [ReplicaID , List [Any ]]:
952+ self , deployment_id : DeploymentID
953+ ) -> Dict [ReplicaID , List [TimeSeries ]]:
896954 if deployment_id .app_name in self ._app_autoscaling_states :
897955 return self ._app_autoscaling_states [
898956 deployment_id .app_name
899- ].get_replica_metrics_by_deployment_id (deployment_id , agg_func )
957+ ].get_replica_metrics_by_deployment_id (deployment_id )
900958 else :
901959 logger .warning (
902960 f"Cannot get metrics for deployment "
0 commit comments