11import logging
22import time
33from collections import defaultdict
4- from typing import Any , Dict , List , Optional , Set
4+ from typing import Dict , List , Optional , Set
55
66from ray .serve ._private .common import (
77 ONGOING_REQUESTS_KEY ,
@@ -207,6 +207,9 @@ def get_decision_num_replicas(
207207 `_skip_bound_check` is True, then the bounds are not applied.
208208 """
209209
210+ total_num_requests = self .get_total_num_requests ()
211+ total_queued_requests = self ._get_queued_requests ()
212+ total_running_requests = total_num_requests - total_queued_requests
210213 autoscaling_context : AutoscalingContext = AutoscalingContext (
211214 deployment_id = self ._deployment_id ,
212215 deployment_name = self ._deployment_id .name ,
@@ -220,10 +223,10 @@ def get_decision_num_replicas(
220223 policy_state = self ._policy_state .copy (),
221224 current_time = time .time (),
222225 config = self ._config ,
223- queued_requests = None ,
224- requests_per_replica = None ,
225- aggregated_metrics = None ,
226- raw_metrics = None ,
226+ total_queued_requests = self . _get_queued_requests () ,
227+ total_running_requests = total_running_requests ,
228+ aggregated_metrics = self . _get_aggregated_custom_metrics () ,
229+ raw_metrics = self . _get_raw_custom_metrics () ,
227230 last_scale_up_time = None ,
228231 last_scale_down_time = None ,
229232 )
@@ -300,19 +303,22 @@ def _collect_handle_running_requests(self) -> List[List[TimeStampedValue]]:
300303
301304 return timeseries_list
302305
303- def _aggregate_ongoing_requests (
304- self , metrics_timeseries_dicts : List [Dict [str , List [TimeStampedValue ]]]
306+ def _aggregate_timeseries_metric (
307+ self ,
308+ metrics_timeseries_dicts : List [Dict [str , List [TimeStampedValue ]]],
309+ metric_key : str ,
305310 ) -> float :
306- """Aggregate and average ongoing requests from timeseries data using instantaneous merge.
311+ """Aggregate and average a metric from timeseries data using instantaneous merge.
307312
308313 Args:
309314 metrics_timeseries_dicts: A list of dictionaries, each containing a key-value pair:
310- - The key is the name of the metric (ONGOING_REQUESTS_KEY)
315+ - The key is the name of the metric (e.g., ONGOING_REQUESTS_KEY or custom metric name )
311316 - The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
312317 this list is sorted by timestamp ascending
318+ metric_key: The key to use when extracting the metric from the dictionaries
313319
314320 Returns:
315- The time-weighted average of the ongoing requests
321+ The time-weighted average of the metric
316322
317323 Example:
318324 If the metrics_timeseries_dicts is:
@@ -339,10 +345,10 @@ def _aggregate_ongoing_requests(
339345
340346 # Use instantaneous merge approach - no arbitrary windowing needed
341347 aggregated_metrics = merge_timeseries_dicts (* metrics_timeseries_dicts )
342- ongoing_requests_timeseries = aggregated_metrics .get (ONGOING_REQUESTS_KEY , [])
343- if ongoing_requests_timeseries :
348+ metric_timeseries = aggregated_metrics .get (metric_key , [])
349+ if metric_timeseries :
344350 # assume that the last recorded metric is valid for last_window_s seconds
345- last_metric_time = ongoing_requests_timeseries [- 1 ].timestamp
351+ last_metric_time = metric_timeseries [- 1 ].timestamp
346352 # we dont want to make any assumption about how long the last metric will be valid
347353 # only conclude that the last metric is valid for last_window_s seconds that is the
348354 # difference between the current time and the last metric recorded time
@@ -351,16 +357,34 @@ def _aggregate_ongoing_requests(
351357 # between replicas and controller. Also add a small epsilon to avoid division by zero
352358 if last_window_s <= 0 :
353359 last_window_s = 1e-3
354- # Calculate the aggregated running requests
360+ # Calculate the aggregated metric value
355361 value = aggregate_timeseries (
356- ongoing_requests_timeseries ,
362+ metric_timeseries ,
357363 aggregation_function = self ._config .aggregation_function ,
358364 last_window_s = last_window_s ,
359365 )
360366 return value if value is not None else 0.0
361367
362368 return 0.0
363369
370+ def _aggregate_ongoing_requests (
371+ self , metrics_timeseries_dicts : List [Dict [str , List [TimeStampedValue ]]]
372+ ) -> float :
373+ """Aggregate and average ongoing requests from timeseries data using instantaneous merge.
374+
375+ This is a convenience wrapper around _aggregate_timeseries_metric for ongoing requests.
376+
377+ Args:
378+ metrics_timeseries_dicts: A list of dictionaries containing ONGOING_REQUESTS_KEY
379+ mapped to timeseries data.
380+
381+ Returns:
382+ The time-weighted average of the ongoing requests
383+ """
384+ return self ._aggregate_timeseries_metric (
385+ metrics_timeseries_dicts , ONGOING_REQUESTS_KEY
386+ )
387+
364388 def _calculate_total_requests_aggregate_mode (self ) -> float :
365389 """Calculate total requests using aggregate metrics mode with timeseries data.
366390
@@ -541,11 +565,8 @@ def get_total_num_requests(self) -> float:
541565 else :
542566 return self ._calculate_total_requests_simple_mode ()
543567
544- def get_replica_metrics (self , agg_func : str ) -> Dict [ReplicaID , List [Any ]]:
568+ def get_replica_metrics (self ) -> Dict [ReplicaID , List [TimeStampedValue ]]:
545569 """Get the raw replica metrics dict."""
546- # arcyleung TODO: pass agg_func from autoscaling policy https://github.com/ray-project/ray/pull/51905
547- # Dummy implementation of mean agg_func across all values of the same metrics key
548-
549570 metric_values = defaultdict (list )
550571 for id in self ._running_replicas :
551572 if id in self ._replica_metrics and self ._replica_metrics [id ].metrics :
@@ -554,6 +575,86 @@ def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
554575
555576 return metric_values
556577
578+ def _get_queued_requests (self ) -> float :
579+ """Calculate the total number of queued requests across all handles.
580+
581+ Returns:
582+ Sum of queued requests at all handles. Uses aggregated values in simple mode,
583+ or aggregates timeseries data in aggregate mode.
584+ """
585+ if RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER :
586+ # Aggregate mode: collect and aggregate timeseries
587+ queued_timeseries = self ._collect_handle_queued_requests ()
588+ if not queued_timeseries :
589+ return 0.0
590+
591+ queued_metrics = [
592+ {ONGOING_REQUESTS_KEY : timeseries } for timeseries in queued_timeseries
593+ ]
594+ return self ._aggregate_ongoing_requests (queued_metrics )
595+ else :
596+ # Simple mode: sum pre-aggregated values
597+ return sum (
598+ handle_metric .aggregated_queued_requests
599+ for handle_metric in self ._handle_requests .values ()
600+ )
601+
602+ def _get_aggregated_custom_metrics (self ) -> Dict [str , Dict [ReplicaID , float ]]:
603+ """Aggregate custom metrics from replica metric reports.
604+
605+ Custom metrics are all metrics except RUNNING_REQUESTS_KEY. These are metrics
606+ emitted by the deployment using the `record_autoscaling_stats` method.
607+
608+ This method aggregates raw timeseries data from replicas on the controller,
609+ similar to how ongoing requests are aggregated.
610+
611+ Returns:
612+ Dict mapping metric name to dict of replica ID to aggregated metric value.
613+ """
614+ aggregated_metrics = defaultdict (dict )
615+
616+ for replica_id in self ._running_replicas :
617+ replica_metric_report = self ._replica_metrics .get (replica_id )
618+ if replica_metric_report is None :
619+ continue
620+
621+ for metric_name , timeseries in replica_metric_report .metrics .items ():
622+ if metric_name != RUNNING_REQUESTS_KEY :
623+ # Aggregate the timeseries for this custom metric
624+ # Use the actual metric name as the key
625+ metrics = [{metric_name : timeseries }]
626+ aggregated_value = self ._aggregate_timeseries_metric (
627+ metrics , metric_name
628+ )
629+ aggregated_metrics [metric_name ][replica_id ] = aggregated_value
630+
631+ return dict (aggregated_metrics )
632+
633+ def _get_raw_custom_metrics (
634+ self ,
635+ ) -> Dict [str , Dict [ReplicaID , List [TimeStampedValue ]]]:
636+ """Extract raw custom metric values from replica metric reports.
637+
638+ Custom metrics are all metrics except RUNNING_REQUESTS_KEY. These are metrics
639+ emitted by the deployment using the `record_autoscaling_stats` method.
640+
641+ Returns:
642+ Dict mapping metric name to dict of replica ID to list of raw metric values.
643+ """
644+ raw_metrics = defaultdict (dict )
645+
646+ for replica_id in self ._running_replicas :
647+ replica_metric_report = self ._replica_metrics .get (replica_id )
648+ if replica_metric_report is None :
649+ continue
650+
651+ for metric_name , timeseries in replica_metric_report .metrics .items ():
652+ if metric_name != RUNNING_REQUESTS_KEY :
653+ # Extract values from TimeStampedValue list
654+ raw_metrics [metric_name ][replica_id ] = timeseries
655+
656+ return dict (raw_metrics )
657+
557658
558659class AutoscalingStateManager :
559660 """Manages all things autoscaling related.
@@ -602,12 +703,10 @@ def get_metrics(self) -> Dict[DeploymentID, float]:
602703 }
603704
604705 def get_all_metrics (
605- self , agg_func = "mean"
606- ) -> Dict [DeploymentID , Dict [ReplicaID , List [Any ]]]:
706+ self ,
707+ ) -> Dict [DeploymentID , Dict [ReplicaID , List [TimeStampedValue ]]]:
607708 return {
608- deployment_id : self ._autoscaling_states [deployment_id ].get_replica_metrics (
609- agg_func
610- )
709+ deployment_id : self ._autoscaling_states [deployment_id ].get_replica_metrics ()
611710 for deployment_id in self ._autoscaling_states
612711 }
613712
0 commit comments