From 776ea4c3e8a48d7867a26c3daaa9420cf60c245d Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Tue, 11 Nov 2025 20:01:59 +0000 Subject: [PATCH 1/6] Add all slo aware routing plugins, no integration changes --- .../plugins/multi/slo_aware_router/config.go | 191 ++++ .../plugins/multi/slo_aware_router/headers.go | 70 ++ .../plugins/multi/slo_aware_router/helpers.go | 145 +++ .../latencypredictor_helper.go | 439 ++++++++ .../multi/slo_aware_router/prediction.go | 138 +++ .../slo_aware_router/requestcontrol_hooks.go | 262 +++++ .../requestcontrol_hooks_test.go | 945 ++++++++++++++++++ .../slo_aware_router/running_request_queue.go | 243 +++++ .../running_request_queue_test.go | 391 ++++++++ .../plugins/multi/slo_aware_router/sampler.go | 136 +++ .../plugins/multi/slo_aware_router/scorer.go | 325 ++++++ .../multi/slo_aware_router/scorer_test.go | 527 ++++++++++ .../multi/slo_aware_router/selection.go | 385 +++++++ .../plugins/multi/slo_aware_router/types.go | 57 ++ .../profile/slo_aware_profile_handler.go | 154 +++ 15 files changed, 4408 insertions(+) create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go create mode 100644 pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go new file mode 100644 index 0000000000..fcb4b72236 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go @@ -0,0 +1,191 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. 
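+//
+// Every tunable below is read once at process start from an environment
+// variable and falls back to the listed default when the variable is unset
+// or fails validation. As an illustrative (not prescriptive) configuration
+// using the documented defaults:
+//
+//	SAMPLING_MEAN=100
+//	MAX_SAMPLED_TOKENS=20
+//	SLO_BUFFER_FACTOR=1.0
+//	HEADROOM_SELECTION_STRATEGY=least
+//	POD_SELECTION_MODE=linear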
+package slo_aware_router
+
+import (
+	"os"
+	"strconv"
+	"strings"
+)
+
+var DefaultSamplingMean = func() float64 {
+	if value, exists := os.LookupEnv("SAMPLING_MEAN"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue > 0 {
+			return parsedValue
+		}
+	}
+	return 100.0 // default value
+}()
+
+var MaxSampledTokens = func() int {
+	if value, exists := os.LookupEnv("MAX_SAMPLED_TOKENS"); exists {
+		if parsedValue, err := strconv.Atoi(value); err == nil && parsedValue > 0 {
+			return parsedValue
+		}
+	}
+	return 20 // default value
+}()
+
+var SLOBufferFactor = func() float64 {
+	if value, exists := os.LookupEnv("SLO_BUFFER_FACTOR"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil {
+			return parsedValue
+		}
+	}
+	return 1.0 // default value
+}()
+
+var NegHeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default: TTFT dominates when violating SLOs
+}()
+
+var NegHeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("NEG_HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default: TPOT weighted lower than TTFT when violating SLOs
+}()
+
+var HeadroomTTFTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TTFT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.8 // default
+}()
+
+var HeadroomTPOTWeight = func() float64 {
+	if value, exists := os.LookupEnv("HEADROOM_TPOT_WEIGHT"); exists {
+		if parsedValue, err := strconv.ParseFloat(value, 64); err == nil && parsedValue >= 0 {
+			return parsedValue
+		}
+	}
+	return 0.2 // default
+}()
+
+var HeadroomSelectionStrategy = func() HeadroomStrategy {
+	if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists {
+		switch strings.ToLower(value) {
+		case "least":
+			return HeadroomStrategyLeast
+		case "most":
+			return HeadroomStrategyMost
+		case "composite-least":
+			return HeadroomStrategyCompositeLeast
+		case "composite-most":
+			return HeadroomStrategyCompositeMost
+		case "composite-only":
+			return HeadroomStrategyCompositeOnly
+		}
+	}
+	return HeadroomStrategyLeast // default to least (better packing)
+}()
+
+// If a composite headroom strategy is used, these are the weights for each
+// component. Not used by the default strategy.
+var CompositeKVWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_KV_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositeQueueWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_QUEUE_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+var CompositePrefixWeight = func() float64 {
+	if v, ok := os.LookupEnv("COMPOSITE_PREFIX_WEIGHT"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 {
+			return f
+		}
+	}
+	return 1
+}()
+
+// With probability ε, explore (ignore affinity gate); otherwise exploit.
+var EpsilonExploreSticky = func() float64 {
+	// Read once from the environment; must be within [0, 1].
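+	// For example, STICKY_EPSILON="0.05" would make roughly 5% of selections
+	// ignore the affinity gate (explore) instead of exploiting it.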
+	if v, ok := os.LookupEnv("STICKY_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+var EpsilonExploreNeg = func() float64 {
+	// Read once from the environment; must be within [0, 1].
+	if v, ok := os.LookupEnv("NEG_HEADROOM_EPSILON"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.01 // default 1% exploration
+}()
+
+// τ for per-path affinity gate (aka "stickiness" threshold).
+var AffinityGateTau = func() float64 {
+	// Read once from the environment; must be within [0, 1].
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.80
+}()
+
+// Global τ for the overall candidate set (previously "overall stickiness").
+var AffinityGateTauGlobal = func() float64 {
+	// Read once from the environment; must be within [0, 1].
+	if v, ok := os.LookupEnv("AFFINITY_GATE_TAU_GLOBAL"); ok {
+		if f, err := strconv.ParseFloat(v, 64); err == nil && f >= 0 && f <= 1 {
+			return f
+		}
+	}
+	return 0.99
+}()
+
+// Read once at init. Values: "linear" (default) or "max".
+var SelectionMode = func() PodSelectionMode {
+	if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok {
+		switch strings.ToLower(v) {
+		case "max":
+			return PodSelectionMax
+		case "linear":
+			fallthrough
+		default:
+			return PodSelectionLinear
+		}
+	}
+	return PodSelectionLinear
+}()
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
new file mode 100644
index 0000000000..8574ec41b0
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
@@ -0,0 +1,70 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"fmt"
+	"strconv"
+
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+)
+
+// parseFloatHeader retrieves a header by name, parses it as a float64,
+// and returns the value, whether the header was present, and an error if
+// the header is present but invalid.
+func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return 0, false, nil // Header not found, return 0 and false
+	}
+
+	// 2. Parse the header value to a float64
+	parsedFloat, err := strconv.ParseFloat(headerValue, 64)
+	if err != nil {
+		return 0, false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a float", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedFloat, true, nil
+}
+
+// parseBoolHeader retrieves a header by name, parses it as a bool,
+// and returns the value or an error if the header is present but invalid.
+func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (bool, error) {
+	// 1. Get header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // Header not found, return false with no error
+	}
+
+	// 2. Parse the header value to a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, errutil.Error{
+			Code: errutil.BadRequest,
+			Msg:  fmt.Sprintf("%s must be a bool", headerName),
+		}
+	}
+
+	// 3. Return the successfully parsed value
+	return parsedBool, nil
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
new file mode 100644
index 0000000000..1d55682435
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go
@@ -0,0 +1,145 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"context"
+	"math"
+	"math/rand"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod {
+	total := 0
+	choices := s.buildCompositeChoices(
+		ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total,
+	)
+	if strategy == HeadroomStrategyCompositeLeast {
+		// Invert weights for the "least" strategy
+		for i := range choices {
+			choices[i].Weight = minWeight + Wmax - choices[i].Weight
+		}
+	}
+	selectedPod := s.performWeightedRandomSelection(choices, total, allPreds, r)
+	return selectedPod
+}
+
+func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if total == 0 {
+		return nil
+	}
+	logger := log.FromContext(context.Background())
+	// SelectionMode is read once at init from the POD_SELECTION_MODE env variable.
+	if SelectionMode == PodSelectionMax {
+		logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight")
+		maxWeight := 0
+		var selectedPod schedulingtypes.Pod
+		for _, c := range weightedChoices {
+			if c.Weight > maxWeight {
+				maxWeight = c.Weight
+				selectedPod = c.PodName
+			}
+		}
+		if selectedPod != nil {
+			return selectedPod
+		}
+		// Fall back to the first pod if no selection was made
+		return candidates[0].Pod
+	}
+
+	// Weighted random selection
+	logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection")
+	idx := r.Intn(total)
+	var selectedPod schedulingtypes.Pod
+
+	for _, c := range weightedChoices {
+		if idx < c.Weight {
+			selectedPod = c.PodName
+			break
+		}
+		idx -= c.Weight
+	}
+
+	// If no pod was selected (shouldn't happen), fall back to the first pod
+	if selectedPod == nil {
+		selectedPod = candidates[0].Pod
+	}
+
+	return selectedPod
+}
+
+func (s *SLOAwareRouter) buildCompositeChoices(
+	ctx context.Context,
+	candidates []PodPredictionResult,
+	wkv, wq, wpref float64,
+	total *int,
+) []Choice {
+
+	// Normalize weights
+	sumw := wkv + wq + wpref
+	if sumw <= 0 {
+		wkv, wq, wpref = 1, 0, 0
+	} else {
+		wkv /= sumw
+		wq /= sumw
+		wpref /= sumw
+	}
+
+	// Precompute queue stats
+	minQ, maxQ := math.MaxInt32, -1
+	queueCounts := make(map[string]int, len(candidates))
+	for _, p := range candidates {
+		q := p.Pod.GetMetrics().WaitingQueueSize
+		queueCounts[p.Pod.GetPod().String()] = q
+		if q < minQ {
+			minQ = q
+		}
+		if q > maxQ {
+			maxQ = q
+		}
+	}
+	den := float64(maxQ - minQ)
+
+	choices := make([]Choice, 0, len(candidates))
+	for _, p := range candidates {
+		q := queueCounts[p.Pod.GetPod().String()]
+		relQueue := 1.0
+		if den > 0 {
+			relQueue = (float64(maxQ-q) / den)
+		}
+
+		kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent
+		kvFree := (1.0 - kvUsage)
+		prefix := (p.PrefixCacheScore)
+
+		composite := wkv*kvFree + wq*relQueue + wpref*prefix
+		w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite)))
+		*total += w
+		choices = append(choices, Choice{PodName: p.Pod, Weight: w})
+
+		log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score",
+			"pod", p.Pod.GetPod().String(),
+			"kvUsage", kvUsage, "kvFree", kvFree,
+			"queue", q, "relQueue", relQueue,
+			"prefix", prefix,
+			"wkv", wkv, "wq", wq, "wprefix", wpref,
+			"composite", composite, "weight", w)
+	}
+	return choices
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
new file mode 100644
index 0000000000..aa47f93c9d
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
@@ -0,0 +1,439 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
+
+// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result.
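+// Each profile's entry is a clone of its first target pod's metrics, so the
+// header/token hooks below read a stable snapshot (via
+// GetLatestMetricsForProfile) rather than live gauges. A typical call
+// pattern in those hooks, sketched:
+//
+//	RefreshLastSeenMetrics(ctx, sloCtx)
+//	m, err := GetLatestMetricsForProfile(ctx, sloCtx) // snapshot for the primary profile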
+func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) {
+	if sr := sloCtx.SchedulingResult; sr != nil {
+		if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil {
+			for profileName, profileResult := range sr.ProfileResults {
+				if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 {
+					sloCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone()
+				}
+			}
+		}
+	} else {
+		log.FromContext(ctx).V(logutil.DEBUG).Info("No scheduling result found, skipping metrics refresh")
+	}
+}
+
+// GetLatestMetricsForProfile retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics.
+func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext) (*backendmetrics.MetricsState, error) {
+	if len(sloCtx.LastSeenMetrics) == 0 {
+		return nil, fmt.Errorf("no last seen metrics available for prediction")
+	}
+
+	primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName
+	if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists {
+		return metrics, nil
+	}
+
+	return nil, fmt.Errorf("no metrics found for primary profile %s", primaryProfileName)
+}
+
+// ProcessHeaderForLatencyPrediction refreshes metrics, applies TTFT prediction, and updates
+// sloCtx.PredictedTTFT and the token timestamp.
+func ProcessHeaderForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+) error {
+	logger := log.FromContext(ctx)
+
+	// Build prediction request
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx)
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
+		return err
+	}
+
+	targetPod := sloCtx.TargetPod
+	prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()]
+
+	in := latencypredictor.PredictionRequest{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: 0,
+		PrefixCacheScore:   prefix_cache_score,
+	}
+
+	// Predict TTFT
+	start := time.Now()
+	p, err := predictor.Predict(ctx, in)
+	dur := time.Since(start)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else if p == nil {
+		logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds())
+		sloCtx.PredictedTTFT = 0
+	} else {
+		logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds())
+		metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+
+		sloCtx.PredictedTTFT = p.TTFT
+	}
+
+	// Advance timestamp for first token reference
+	sloCtx.LastTokenTimestamp = time.Now()
+	RefreshLastSeenMetrics(ctx, sloCtx)
+	return err
+}
+
+// ProcessFirstTokenForLatencyPrediction records the actual TTFT, trains, predicts the first TPOT,
+// updates sloCtx, and advances the timestamp.
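+// The measured value is simply the wall-clock gap from request receipt,
+// in milliseconds:
+//
+//	TTFT = now - sloCtx.RequestReceivedTimestamp
+//
+// That observation is sent to the predictor as a training sample before the
+// first TPOT prediction is made.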
+func ProcessFirstTokenForLatencyPrediction( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *SLORequestContext, + now time.Time, +) { + logger := log.FromContext(ctx) + + // Initialize sampler + if sloCtx.TokenSampler == nil { + requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) + } + + // Actual TTFT + sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds()) + sloCtx.GeneratedTokenCount = 1 + m, err := GetLatestMetricsForProfile(ctx, sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) + return + } + targetPod := sloCtx.TargetPod + prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()] + + // Train TTFT + entry := latencypredictor.TrainingEntry{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + ActualTTFT: sloCtx.TTFT, + ActualTPOT: 0, + Timestamp: now, + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: 0, + PrefixCacheScore: prefix_cache_score, + } + if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { + logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") + } + m, err = GetLatestMetricsForProfile(ctx, sloCtx) + if err != nil { + logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", + "error", err) + return + } + + // Predict first TPOT + in := latencypredictor.PredictionRequest{ + KVCachePercentage: m.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + NumRequestWaiting: m.WaitingQueueSize, + NumRequestRunning: m.RunningQueueSize, + NumTokensGenerated: sloCtx.GeneratedTokenCount, + PrefixCacheScore: 0, + } + start := time.Now() + p, err := predictor.Predict(ctx, in) + dur := time.Since(start) + if err != nil || p == nil { + logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations)) + } else { + logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) + sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT) + sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations)) + } + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + + // Advance timestamp + sloCtx.LastTokenTimestamp = now + // Refresh metrics + RefreshLastSeenMetrics(ctx, sloCtx) +} + +// ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. 
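+// Every inter-token gap is fed back as a training sample, but predictions
+// are only issued for tokens chosen by the TokenSampler (capped via
+// MAX_SAMPLED_TOKENS), which keeps per-token predictor traffic bounded for
+// long generations.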
+func ProcessTokenForLatencyPrediction(
+	ctx context.Context,
+	predictor latencypredictor.PredictorInterface,
+	sloCtx *SLORequestContext,
+	now time.Time,
+) {
+	logger := log.FromContext(ctx)
+
+	// Initialize the sampler if it has not been initialized yet
+	if sloCtx.TokenSampler == nil {
+		requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
+		sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens)
+		logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
+	}
+
+	// Inter-token latency
+	latencyMs := float64(now.Sub(sloCtx.LastTokenTimestamp).Milliseconds())
+	sloCtx.GeneratedTokenCount++
+
+	// Record the inter-token latency for sampled tokens. Token 2 is always
+	// recorded; after that the sampler decides (the sampler's next sample
+	// token is always one ahead of the current token).
+	if sloCtx.GeneratedTokenCount == 2 || sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		sloCtx.TPOTObservations = append(sloCtx.TPOTObservations, latencyMs)
+		sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations))
+	}
+
+	m, err := GetLatestMetricsForProfile(ctx, sloCtx)
+	if err != nil {
+		logger.V(logutil.DEBUG).Info("Skipping TPOT training and prediction due to missing metrics",
+			"error", err)
+		return
+	}
+	// Record actual TPOT
+	entry := latencypredictor.TrainingEntry{
+		KVCachePercentage:  m.KVCacheUsagePercent,
+		InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+		ActualTTFT:         0,
+		ActualTPOT:         latencyMs,
+		Timestamp:          now,
+		NumRequestWaiting:  m.WaitingQueueSize,
+		NumRequestRunning:  m.RunningQueueSize,
+		NumTokensGenerated: sloCtx.GeneratedTokenCount - 1,
+		PrefixCacheScore:   0, // TPOT does not use prefix cache score
+	}
+	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
+		logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
+	}
+
+	// Sampled predict
+	if sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
+		in := latencypredictor.PredictionRequest{
+			KVCachePercentage:  m.KVCacheUsagePercent,
+			InputTokenLength:   len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+			NumRequestWaiting:  m.WaitingQueueSize,
+			NumRequestRunning:  m.RunningQueueSize,
+			NumTokensGenerated: sloCtx.GeneratedTokenCount,
+			PrefixCacheScore:   0, // TPOT does not use prefix cache score
+		}
+		start := time.Now()
+		p, err := predictor.Predict(ctx, in)
+		dur := time.Since(start)
+		if err != nil || p == nil {
+			logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations))
+		} else {
+			logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds())
+			sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT)
+			sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations))
+		}
+		metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
+
+		sloCtx.TokenSampler.RecordPrediction(sloCtx.GeneratedTokenCount)
+	}
+
+	// Advance timestamp
+	sloCtx.LastTokenTimestamp = now
+	// Refresh metrics
+	RefreshLastSeenMetrics(ctx, sloCtx)
+}
+
+// PredictWithMetrics predicts 
TTFT or TPOT based on provided metrics state and token count. +func PredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsState *backendmetrics.MetricsState, + prompt string, + generatedTokenCount int, + prefixcachescore float64, +) (*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + if metricsState == nil { + return nil, fmt.Errorf("metrics state cannot be nil") + } + + // Build prediction request + in := latencypredictor.PredictionRequest{ + KVCachePercentage: metricsState.KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompt)), + NumRequestWaiting: metricsState.WaitingQueueSize, + NumRequestRunning: metricsState.RunningQueueSize, + NumTokensGenerated: generatedTokenCount, + PrefixCacheScore: prefixcachescore, + } + + // Perform prediction + start := time.Now() + result, err := predictor.Predict(ctx, in) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "prediction failed", + "duration_ms", duration.Milliseconds(), + "input_tokens", in.InputTokenLength, + "generated_tokens", generatedTokenCount, + "kv_cache_percent", in.KVCachePercentage, + "waiting_queue", in.NumRequestWaiting, + "running_queue", in.NumRequestRunning, + "prefix_cache_score", in.PrefixCacheScore) + return nil, err + } + + if result == nil { + logger.V(logutil.DEBUG).Info("prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("prediction returned nil result") + } + + logger.V(logutil.DEBUG).Info("prediction succeeded", + "tpot_ms", result.TPOT, + "ttft_ms", result.TTFT, + "duration_ms", duration.Milliseconds(), + "input_tokens", in.InputTokenLength, + "generated_tokens", generatedTokenCount, + "kv_cache_percent", in.KVCachePercentage, + "waiting_queue", in.NumRequestWaiting, + "running_queue", in.NumRequestRunning, + "prefix_cache_score", in.PrefixCacheScore) + + return result, nil +} + +// BulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. +// Returns predictions in the same order as the input slices. 
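+//
+// A minimal caller sketch (podA, podB, and prompt are hypothetical):
+//
+//	states := []*backendmetrics.MetricsState{podA.GetMetrics(), podB.GetMetrics()}
+//	preds, err := BulkPredictWithMetrics(ctx, predictor, states,
+//		[]string{prompt, prompt}, []int{1, 1}, []float64{0.9, 0.1})
+//	if err == nil {
+//		first := preds[0] // same order as the inputs
+//		_ = first.TTFT
+//	}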
+func BulkPredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsStates []*backendmetrics.MetricsState, + prompts []string, + generatedTokenCounts []int, + prefixCacheScores []float64, +) ([]*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + // Validate input lengths + if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) { + return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d", + len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores)) + } + + if len(metricsStates) == 0 { + return []*latencypredictor.PredictionResponse{}, nil + } + + // Validate that no metrics state is nil + for i, metricsState := range metricsStates { + if metricsState == nil { + return nil, fmt.Errorf("metrics state at index %d cannot be nil", i) + } + } + + // Build bulk prediction requests + bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates)) + for i := range metricsStates { + bulkRequests[i] = latencypredictor.PredictionRequest{ + KVCachePercentage: metricsStates[i].KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompts[i])), + NumRequestWaiting: metricsStates[i].WaitingQueueSize, + NumRequestRunning: metricsStates[i].RunningQueueSize, + NumTokensGenerated: generatedTokenCounts[i], + PrefixCacheScore: prefixCacheScores[i], + } + } + + // Perform bulk prediction + start := time.Now() + bulkResponse, err := predictor.PredictBulkStrict(ctx, bulkRequests) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "bulk prediction failed", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests)) + return nil, err + } + + if bulkResponse == nil { + logger.V(logutil.DEBUG).Info("bulk prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("bulk prediction returned nil result") + } + + // Convert to pointer slice for consistency with single prediction + results := make([]*latencypredictor.PredictionResponse, len(bulkResponse.Predictions)) + for i := range bulkResponse.Predictions { + results[i] = &bulkResponse.Predictions[i] + } + + logger.V(logutil.DEBUG).Info("bulk prediction succeeded", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests), + "successful_predictions", bulkResponse.SuccessfulPredictions, + "failed_predictions", bulkResponse.FailedPredictions, + "processing_time_ms", bulkResponse.ProcessingTimeMs) + + // Log detailed results if at trace level + if logger.V(logutil.TRACE).Enabled() { + for i, result := range results { + logger.V(logutil.TRACE).Info("bulk prediction result", + "index", i, + "ttft_ms", result.TTFT, + "tpot_ms", result.TPOT, + "input_tokens", bulkRequests[i].InputTokenLength, + "generated_tokens", bulkRequests[i].NumTokensGenerated, + "kv_cache_percent", bulkRequests[i].KVCachePercentage, + "waiting_queue", bulkRequests[i].NumRequestWaiting, + "running_queue", bulkRequests[i].NumRequestRunning, + "prefix_cache_score", bulkRequests[i].PrefixCacheScore) + } + } + + return results, nil +} + +// calculateRunningAverage calculates the running average efficiently +func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 { + if count == 0 { + return 0 + } + if count == 1 { + return newValue + } + return currentAvg + (newValue-currentAvg)/float64(count) +} diff --git 
a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
new file mode 100644
index 0000000000..0c2cfa0a95
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go
@@ -0,0 +1,138 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
+package slo_aware_router
+
+import (
+	"context"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
+)
+
+type PodPredictionResult struct {
+	Pod              schedulingtypes.Pod
+	TTFT             float64
+	TPOT             float64
+	TTFTValid        bool
+	TPOTValid        bool
+	IsValid          bool
+	Error            error
+	Headroom         float64 // TPOT headroom for the pod, if applicable
+	TTFTHeadroom     float64 // TTFT headroom for the pod
+	PrefixCacheScore float64 // Prefix cache score for the pod
+}
+
+// generatePredictions creates prediction results for all candidate pods.
+// Note that a prediction error for any single pod aborts the whole batch.
+func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) ([]PodPredictionResult, error) {
+	logger := log.FromContext(ctx)
+	predictions := make([]PodPredictionResult, 0, len(candidatePods))
+
+	for _, pod := range candidatePods {
+		predResult := PodPredictionResult{Pod: pod}
+
+		logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String())
+
+		// Get prefix cache score for the pod
+		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+
+		sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
+
+		logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore)
+
+		// Generate prediction
+		prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
+		if err != nil {
+			logger.V(logutil.DEBUG).Error(err, "Aborting prediction generation due to prediction error", "pod", pod.GetPod().String())
+			predResult.Error = err
+			return nil, err
+		}
+		predResult.PrefixCacheScore = prefixCacheScore
+		predResult.TTFT = prediction.TTFT
+		predResult.TPOT = prediction.TPOT
+		// The smallest TPOT SLO among requests already running on this pod
+		// (0 when there are none).
+		podMinTPOTSLO := s.getPodMinTPOTSLO(pod)
+		predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO)
+
+		logger.V(logutil.DEBUG).Info("Prediction for scheduling",
+			"pod", pod.GetPod().String(),
+			"prefixCacheScore", prefixCacheScore,
+			"TTFT", prediction.TTFT,
+			"TPOT", prediction.TPOT,
+			"buffer", SLOBufferFactor,
+			"podMinTPOTSLO", podMinTPOTSLO,
+			"ttftSLO", sloCtx.TTFTSLO,
+			"requestTPOTSLO", sloCtx.AvgTPOTSLO,
+			"tpotHeadroom", predResult.Headroom,
+			"ttftHeadroom", predResult.TTFTHeadroom,
+			"tpotValid", predResult.TPOTValid,
+			"ttftValid", predResult.TTFTValid,
+			"headroomStrategy", s.headroomStrategy)
+
+		predictions = append(predictions, predResult)
+	}
+
+	return predictions, nil
+}
+
+// updateRequestContextWithPredictions updates the request context with prediction data
+func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) {
+	for _, pred := range predictions {
+		if pred.Error == nil {
+			podKey := pred.Pod.GetPod().String()
+			if sloCtx.PredictedTTFTForScheduling == nil {
+				sloCtx.PredictedTTFTForScheduling = make(map[string]float64)
+			}
+			if sloCtx.PredictedTPOTForScheduling == nil {
+				sloCtx.PredictedTPOTForScheduling = make(map[string]float64)
+			}
+			sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT
+			sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT
+		}
+	}
+}
+
+func (s *SLOAwareRouter) validatePrediction(
+	pred *latencypredictor.PredictionResponse,
+	sloCtx *SLORequestContext,
+	podMinTPOTSLO float64,
+) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) {
+
+	bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor
+	// A podMinTPOTSLO of 0 means either no running requests, or no TPOT SLOs
+	// specified on the running requests.
+	if podMinTPOTSLO > 0 {
+		if podMinTPOTSLO < sloCtx.AvgTPOTSLO {
+			log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the request SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "requestTPOTSLO", sloCtx.AvgTPOTSLO)
+		}
+		bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor)
+	}
+
+	tpotOk = pred.TPOT < bufferedTPOT
+	ttftOk = pred.TTFT < sloCtx.TTFTSLO
+
+	isValid = ttftOk && tpotOk
+	headroom = bufferedTPOT - pred.TPOT
+	ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT
+	return
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
new file mode 100644
index 0000000000..f865bbeb37
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
@@ -0,0 +1,262 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "time" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +var _ requestcontrol.PreRequest = &SLOAwareRouter{} +var _ requestcontrol.ResponseReceived = &SLOAwareRouter{} +var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{} +var _ requestcontrol.ResponseComplete = &SLOAwareRouter{} + +type SLORequestContext struct { + SchedulingRequest schedulingtypes.LLMRequest + TargetPod *backend.Pod + SchedulingResult *schedulingtypes.SchedulingResult + LastSeenMetrics map[string]*backendmetrics.MetricsState + LastTokenTimestamp time.Time + RequestReceivedTimestamp time.Time + GeneratedTokenCount int + IncomingModelName string + TTFT float64 + PredictedTTFT float64 + AvgTPOT float64 + AvgPredictedTPOT float64 + TokenSampler *TokenSampler + TPOTObservations []float64 + PredictedTPOTObservations []float64 + + PrefixCacheScoresForPods map[string]float64 + + // TTFTSLO is the target time to first token SLO for the request. + TTFTSLO float64 + // TPOTSLO is the target time per output token SLO for the request. + AvgTPOTSLO float64 + + // PredictorBasedScheduling indicates whether to use predictor based scheduling. + PredictorBasedScheduling bool + //PredictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. + PredictedTTFTForScheduling map[string]float64 + // PredictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. 
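+	// Both scheduling maps (and PrefixCacheScoresForPods) are keyed by the
+	// candidate pod's GetPod().String() identifier.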
+ PredictedTPOTForScheduling map[string]float64 + + // boolean set if request has valid pod based on predictions + HasValidPod bool +} + +func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { + return &SLORequestContext{ + SchedulingRequest: *request, + LastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + PrefixCacheScoresForPods: make(map[string]float64), + PredictedTTFTForScheduling: make(map[string]float64), + PredictedTPOTForScheduling: make(map[string]float64), + } +} + +func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) { + id := request.Headers[requtil.RequestIdHeaderKey] + if ctx, exists := s.sloContextStore.Load(id); exists { + return ctx.(*SLORequestContext), nil + } + return nil, fmt.Errorf("SLO context not found for request ID: %s", id) +} + +func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Store(id, ctx) +} + +func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) { + id := request.Headers[requtil.RequestIdHeaderKey] + s.sloContextStore.Delete(id) +} + +// --- RequestControl Hooks --- + +func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) { + logger := log.FromContext(ctx) + + if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.") + return + } + + targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() + if !t.CheckPredictor(logger, targetPod) { + return + } + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) + if request.Headers[requtil.RequestIdHeaderKey] == "" { + logger.V(logutil.DEBUG).Error(fmt.Errorf("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") + } + + id := request.Headers[requtil.RequestIdHeaderKey] + podRequestList, ok := t.runningRequestLists[podName] + if !ok { + podRequestList = NewRequestPriorityQueue() + t.runningRequestLists[podName] = podRequestList + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.PreRequest: Failed to get SLO context for request", "requestID", id) + return + } + + added := podRequestList.Add(id, sloCtx.AvgTPOTSLO) + if !added { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) + } + + // Set up SLO request context + sloCtx.TargetPod = targetPod + sloCtx.SchedulingResult = schedulingResult + sloCtx.RequestReceivedTimestamp = time.Now() + RefreshLastSeenMetrics(ctx, sloCtx) + t.setSLOContextForRequest(request, sloCtx) +} + +func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.CheckPredictor(logger, targetPod) { + return + } + + id := request.Headers[requtil.RequestIdHeaderKey] + + sloCtx, err := 
t.getSLOContextForRequest(request) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to get SLO context for request", "requestID", id) + return + } + + if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { + logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") + } + +} + +func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + if !t.CheckPredictor(logger, pod) || response.EndOfStream { + return + } + + now := time.Now() + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.TRACE).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.TTFT == 0 { + ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } else { + ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + } + +} + +func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { + logger := log.FromContext(ctx) + targetPod := pod + if !t.CheckPredictor(logger, targetPod) { + return + } + + sloCtx, err := t.getSLOContextForRequest(request) + if err != nil { + id := request.Headers[requtil.RequestIdHeaderKey] + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseComplete: Failed to get SLO context for request", "requestID", id) + return + } + + if sloCtx.TTFT > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT) + metrics.RecordRequestTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000) + metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000) + if sloCtx.TTFTSLO > 0 { + metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT, sloCtx.TTFTSLO) + } + } + + if sloCtx.AvgTPOT > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT) + metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000) + metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000) + if sloCtx.AvgTPOTSLO > 0 { + metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT, sloCtx.AvgTPOTSLO) + } + } + + logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling) + + podName := types.NamespacedName{ + Name: targetPod.NamespacedName.Name, + Namespace: targetPod.NamespacedName.Namespace, + } + + id := request.Headers[requtil.RequestIdHeaderKey] + podRequestList, ok := t.runningRequestLists[podName] + if !ok { + err := fmt.Errorf("no running request list found for pod %s", podName.String()) + logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id) + } + + _, removed := podRequestList.Remove(id) + if !removed { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id) + } + t.deleteSLOContextForRequest(request) +} + +func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod 
*backend.Pod) bool { + if targetPod == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") + return false + } + if t.latencypredictor == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because predictor missing") + return false + } + return true +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go new file mode 100644 index 0000000000..96999af2f3 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -0,0 +1,945 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" +) + +// Helper functions + +func createTestSchedulingResult(pod *backend.Pod, kvUsage float64, runningQueue int, waitingQueue int) *schedulingtypes.SchedulingResult { + + mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningQueue, waitingQueue) + + return &schedulingtypes.SchedulingResult{ + PrimaryProfileName: "default", + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{ + "default": { + TargetPods: []schedulingtypes.Pod{mockPod}, + }, + }, + } +} + +func createTestRouter() *SLOAwareRouter { + return &SLOAwareRouter{ + sloContextStore: sync.Map{}, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + latencypredictor: nil, + } +} + +// Test cases + +func TestNewSLORequestContext(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + + ctx := NewSLORequestContext(request) + + assert.NotNil(t, ctx) + assert.Equal(t, *request, ctx.SchedulingRequest) + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.LastSeenMetrics) + assert.Empty(t, ctx.PrefixCacheScoresForPods) +} + +func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + + 
require.NoError(t, err) + assert.Equal(t, sloCtx, retrievedCtx) +} + +func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + + // Try to get context that doesn't exist + ctx, err := router.getSLOContextForRequest(request) + + assert.Error(t, err) + assert.Nil(t, ctx) + assert.Contains(t, err.Error(), "SLO context not found") +} + +func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set and then delete context + router.setSLOContextForRequest(request, sloCtx) + router.deleteSLOContextForRequest(request) + + // Verify it's deleted + ctx, err := router.getSLOContextForRequest(request) + assert.Error(t, err) + assert.Nil(t, ctx) +} + +func TestSLOAwareRouter_PreRequest_NoSchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + // Call PreRequest with nil scheduling result + router.PreRequest(ctx, request, nil) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_EmptySchedulingResult(t *testing.T) { + router := createTestRouter() + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + + schedulingResult := &schedulingtypes.SchedulingResult{ + ProfileResults: map[string]*schedulingtypes.ProfileRunResult{}, + } + + // Call PreRequest with empty scheduling result + router.PreRequest(ctx, request, schedulingResult) + + // Should not create SLO context + _, err := router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // Initialize the request priority queue + router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + + beforeTime := time.Now() + router.PreRequest(ctx, request, schedulingResult) + afterTime := time.Now() + + // Verify SLO context was updated + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.Equal(t, pod.GetPod(), retrievedCtx.TargetPod) + assert.Equal(t, schedulingResult, retrievedCtx.SchedulingResult) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.After(beforeTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.RequestReceivedTimestamp.Before(afterTime) || + retrievedCtx.RequestReceivedTimestamp.Equal(afterTime)) +} + +func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO context + sloCtx := 
NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + + // PreRequest should create the queue + router.PreRequest(ctx, request, schedulingResult) + + // Verify queue was created and request was added + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists, "Queue should be created for pod") + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set initial SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add first request + router.PreRequest(ctx, request1, schedulingResult) + + // Add second request to same pod + router.PreRequest(ctx, request2, schedulingResult) + + // Verify both are in the same queue + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic and should return early + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Context should still exist + _, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) +} + +func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + router.setSLOContextForRequest(request, sloCtx) + + // Should not panic with nil pod + router.ResponseReceived(ctx, request, response, nil) + + // Predictor should not be called + +} + +func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + // Don't set SLO context + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // Should handle missing context gracefully + +} + +func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { + router := createTestRouter() + router.latencypredictor = nil + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + + sloCtx := NewSLORequestContext(request) + 
router.setSLOContextForRequest(request, sloCtx)
+
+	// Should not panic and should return early
+	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+
+	// Context should still exist
+	_, err := router.getSLOContextForRequest(request)
+	assert.NoError(t, err)
+}
+
+func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+	schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.RequestReceivedTimestamp = time.Now()
+	sloCtx.SchedulingResult = schedulingResult
+	sloCtx.SchedulingRequest = *request
+	sloCtx.TTFTSLO = 100
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "test-model"
+	sloCtx.PredictedTTFT = 80.0
+	sloCtx.AvgPredictedTPOT = 30.0
+	// Populate last-seen metrics for both profiles so the streaming hook has state to update
+	sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{
+		KVCacheUsagePercent: 0.5,
+		WaitingQueueSize:    1,
+		RunningQueueSize:    1,
+	}
+	sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{
+		KVCacheUsagePercent: 0.5,
+		WaitingQueueSize:    1,
+		RunningQueueSize:    1,
+	}
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Initialize the queue and add the request
+	queue := NewRequestPriorityQueue()
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+
+	beforeTime := time.Now()
+	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+	afterTime := time.Now()
+
+	// Verify first token timestamp was set
+	retrievedCtx, err := router.getSLOContextForRequest(request)
+	require.NoError(t, err)
+	assert.True(t, retrievedCtx.LastTokenTimestamp.After(beforeTime) ||
+		retrievedCtx.LastTokenTimestamp.Equal(beforeTime))
+	assert.True(t, retrievedCtx.LastTokenTimestamp.Before(afterTime) ||
+		retrievedCtx.LastTokenTimestamp.Equal(afterTime))
+}
+
+func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+	schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.RequestReceivedTimestamp = time.Now()
+	sloCtx.SchedulingResult = schedulingResult
+	sloCtx.SchedulingRequest = *request
+	sloCtx.TTFTSLO = 100
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "test-model"
+	sloCtx.PredictedTTFT = 80.0
+	sloCtx.AvgPredictedTPOT = 30.0
+	// Populate last-seen metrics for both profiles so the streaming hook has state to update
+	sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{
+		KVCacheUsagePercent: 0.5,
+		WaitingQueueSize:    1,
+		RunningQueueSize:    1,
+	}
+	sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{
+		KVCacheUsagePercent: 0.5,
+		WaitingQueueSize:    1,
+		RunningQueueSize:    1,
+	}
+	firstTokenTime := time.Now().Add(-100 * time.Millisecond)
+
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Initialize the queue and add the request
+	queue := NewRequestPriorityQueue()
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+
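+	// The streaming hook should treat this as a subsequent token and move the
+	// last-token timestamp forward from the stored first-token time.
+	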
router.ResponseStreaming(ctx, request, response, pod.GetPod())
+
+	// Verify token timestamp was updated
+	retrievedCtx, err := router.getSLOContextForRequest(request)
+	require.NoError(t, err)
+	assert.True(t, retrievedCtx.LastTokenTimestamp.After(firstTokenTime))
+}
+
+func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.IncomingModelName = "test-model"
+	sloCtx.TargetPod = pod.GetPod() // set the target pod so only the missing queue entry is exercised
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior
+	router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue()
+
+	// Should handle gracefully when request is not in queue
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Context should be deleted
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+}
+
+func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Don't set SLO context; the hook should handle it gracefully and not panic
+	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+}
+
+func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Create queue and add request
+	queue := NewRequestPriorityQueue()
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.TTFT = 80
+	sloCtx.AvgTPOT = 30
+	sloCtx.PredictedTTFT = 85
+	sloCtx.AvgPredictedTPOT = 32
+	sloCtx.TTFTSLO = 100
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "incoming-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Verify context was deleted
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+
+	// Verify request was removed from queue
+	assert.Equal(t, 0, queue.Len())
+}
+
+func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) {
+	router := createTestRouter()
+	router.latencypredictor = nil
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	sloCtx := NewSLORequestContext(request)
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should not panic
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Context should still exist (deletion happens only with predictor)
+	_, err := router.getSLOContextForRequest(request)
+	assert.NoError(t, err)
+}
+
+func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) {
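+	// With a nil target pod there is nothing to clean up, so the completion
+	// hook should return early and keep the stored context.
+	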
router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	sloCtx := NewSLORequestContext(request)
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should not panic with nil pod
+	router.ResponseComplete(ctx, request, response, nil)
+
+	// Context should still exist (deletion happens only when a valid pod is provided)
+	_, err := router.getSLOContextForRequest(request)
+	assert.NoError(t, err)
+}
+
+func TestSLOAwareRouter_ResponseComplete_NoContext(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Don't set SLO context; the hook should handle it gracefully and not panic
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+}
+
+func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test", 100, 50, true)
+	response := &requestcontrol.Response{}
+
+	// Create queue
+	queue := NewRequestPriorityQueue()
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.TTFT = 80
+	sloCtx.AvgTPOT = 30
+	sloCtx.PredictedTTFT = 85
+	sloCtx.AvgPredictedTPOT = 32
+	sloCtx.TTFTSLO = 100
+	sloCtx.AvgTPOTSLO = 50
+	sloCtx.IncomingModelName = "incoming-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should record metrics without panicking
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Verify cleanup
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+}
+
+func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) {
+	router := createTestRouter()
+	mockPredictor := new(mockPredictor)
+	router.latencypredictor = mockPredictor
+
+	ctx := context.Background()
+	pod := createTestPod("test-pod", 1, 1, 1)
+	request := createTestLLMRequest("test-id", 0, 0, true) // No SLOs
+	response := &requestcontrol.Response{}
+
+	// Create queue
+	queue := NewRequestPriorityQueue()
+	router.runningRequestLists[pod.GetPod().NamespacedName] = queue
+	queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0)
+
+	sloCtx := NewSLORequestContext(request)
+	sloCtx.TTFT = 80
+	sloCtx.AvgTPOT = 30
+	sloCtx.IncomingModelName = "test-model"
+	router.setSLOContextForRequest(request, sloCtx)
+
+	// Should handle missing SLOs gracefully
+	router.ResponseComplete(ctx, request, response, pod.GetPod())
+
+	// Verify cleanup
+	_, err := router.getSLOContextForRequest(request)
+	assert.Error(t, err)
+}
+
+func TestSLOAwareRouter_CheckPredictor_NilPod(t *testing.T) {
+	router := createTestRouter()
+	logger := logr.Discard()
+
+	result := router.CheckPredictor(logger, nil)
+
+	assert.False(t, result)
+}
+
+func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) {
+	router := createTestRouter()
+	router.latencypredictor = nil
+	logger := logr.Discard()
+	pod := createTestPod("test-pod", 1, 1, 1)
+
+	result := router.CheckPredictor(logger, pod.GetPod())
+
+	assert.False(t, 
result) +} + +func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + logger := logr.Discard() + pod := createTestPod("test-pod", 1, 1, 1) + + result := router.CheckPredictor(logger, pod.GetPod()) + + assert.True(t, result) +} + +func TestSLORequestContext_Fields(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Test all field initialization + assert.NotNil(t, ctx.LastSeenMetrics) + assert.NotNil(t, ctx.PrefixCacheScoresForPods) + assert.NotNil(t, ctx.PredictedTTFTForScheduling) + assert.NotNil(t, ctx.PredictedTPOTForScheduling) + assert.Empty(t, ctx.TPOTObservations) + assert.Empty(t, ctx.PredictedTPOTObservations) + assert.Zero(t, ctx.GeneratedTokenCount) + assert.Zero(t, ctx.TTFT) + assert.Zero(t, ctx.AvgTPOT) + assert.Nil(t, ctx.TargetPod) + assert.Nil(t, ctx.SchedulingResult) + assert.Nil(t, ctx.TokenSampler) +} + +func TestSLORequestContext_UpdateMetrics(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Add some metrics + metricsState := &backendmetrics.MetricsState{ + KVCacheUsagePercent: 0.5, + WaitingQueueSize: 3, + } + ctx.LastSeenMetrics["test-pod"] = metricsState + + assert.Len(t, ctx.LastSeenMetrics, 1) + assert.Equal(t, 0.5, ctx.LastSeenMetrics["test-pod"].KVCacheUsagePercent) + assert.Equal(t, 3, ctx.LastSeenMetrics["test-pod"].WaitingQueueSize) +} + +func TestSLORequestContext_PredictionData(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prediction data + ctx.PredictedTTFTForScheduling["pod1"] = 80.0 + ctx.PredictedTPOTForScheduling["pod1"] = 30.0 + ctx.PredictedTTFTForScheduling["pod2"] = 90.0 + ctx.PredictedTPOTForScheduling["pod2"] = 35.0 + + assert.Len(t, ctx.PredictedTTFTForScheduling, 2) + assert.Len(t, ctx.PredictedTPOTForScheduling, 2) + assert.Equal(t, 80.0, ctx.PredictedTTFTForScheduling["pod1"]) + assert.Equal(t, 30.0, ctx.PredictedTPOTForScheduling["pod1"]) +} + +func TestSLORequestContext_PrefixCacheScores(t *testing.T) { + request := createTestLLMRequest("test", 100, 50, true) + ctx := NewSLORequestContext(request) + + // Set prefix cache scores + ctx.PrefixCacheScoresForPods["pod1"] = 0.8 + ctx.PrefixCacheScoresForPods["pod2"] = 0.6 + ctx.PrefixCacheScoresForPods["pod3"] = 0.9 + + assert.Len(t, ctx.PrefixCacheScoresForPods, 3) + assert.Equal(t, 0.8, ctx.PrefixCacheScoresForPods["pod1"]) + assert.Equal(t, 0.9, ctx.PrefixCacheScoresForPods["pod3"]) +} + +func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { + router := createTestRouter() + + // Test concurrent access to context store + var wg sync.WaitGroup + numGoroutines := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + + // Set context + router.setSLOContextForRequest(request, sloCtx) + + // Get context + retrievedCtx, err := router.getSLOContextForRequest(request) + assert.NoError(t, err) + assert.NotNil(t, retrievedCtx) + + // Delete context + router.deleteSLOContextForRequest(request) + }(i) + } + + wg.Wait() +} + +func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = 
mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + request3 := createTestLLMRequest("test-id-3", 100, 50, true) + + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create and set SLO contexts + for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} { + sloCtx := NewSLORequestContext(req) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(req, sloCtx) + } + + // Add all requests + router.PreRequest(ctx, request1, schedulingResult) + router.PreRequest(ctx, request2, schedulingResult) + router.PreRequest(ctx, request3, schedulingResult) + + // Verify queue has all requests + queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + assert.True(t, exists) + assert.NotNil(t, queue) +} + +func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + request := createTestLLMRequest("test", 100, 50, true) + response := &requestcontrol.Response{} + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + // Create initial context + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + sloCtx.IncomingModelName = "test-model" + router.setSLOContextForRequest(request, sloCtx) + + // 1. PreRequest + router.PreRequest(ctx, request, schedulingResult) + + // Verify context exists + retrievedCtx, err := router.getSLOContextForRequest(request) + require.NoError(t, err) + assert.NotNil(t, retrievedCtx.TargetPod) + + // 2. ResponseReceived + router.ResponseReceived(ctx, request, response, pod.GetPod()) + + // 3. ResponseStreaming (first token) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 4. ResponseStreaming (subsequent tokens) + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.TTFT = 100 // Mark first token received + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseStreaming(ctx, request, response, pod.GetPod()) + + // 5. 
ResponseComplete + retrievedCtx, _ = router.getSLOContextForRequest(request) + retrievedCtx.TTFT = 80 + retrievedCtx.AvgTPOT = 30 + router.setSLOContextForRequest(request, retrievedCtx) + router.ResponseComplete(ctx, request, response, pod.GetPod()) + + // Verify context was cleaned up + _, err = router.getSLOContextForRequest(request) + assert.Error(t, err) +} + +func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { + router := createTestRouter() + mockPredictor := new(mockPredictor) + router.latencypredictor = mockPredictor + + ctx := context.Background() + + pod1 := createTestPod("test-pod-1", 1, 1, 1) + pod2 := createTestPod("test-pod-2", 1, 1, 1) + + request1 := createTestLLMRequest("test-id-1", 100, 50, true) + request2 := createTestLLMRequest("test-id-2", 100, 50, true) + + schedulingResult1 := createTestSchedulingResult(pod1.GetPod(), 1, 1, 1) + schedulingResult2 := createTestSchedulingResult(pod2.GetPod(), 1, 1, 1) + + // Create and set SLO contexts + sloCtx1 := NewSLORequestContext(request1) + sloCtx1.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request1, sloCtx1) + + sloCtx2 := NewSLORequestContext(request2) + sloCtx2.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request2, sloCtx2) + + // Add requests to different pods + router.PreRequest(ctx, request1, schedulingResult1) + router.PreRequest(ctx, request2, schedulingResult2) + + // Verify separate queues were created + queue1, exists1 := router.runningRequestLists[pod1.GetPod().NamespacedName] + queue2, exists2 := router.runningRequestLists[pod2.GetPod().NamespacedName] + + assert.True(t, exists1) + assert.True(t, exists2) + assert.NotNil(t, queue1) + assert.NotNil(t, queue2) + assert.NotEqual(t, queue1, queue2) +} + +func TestSLORequestContext_SLOValidation(t *testing.T) { + tests := []struct { + name string + ttftSLO float64 + tpotSLO float64 + expectSLOs bool + }{ + { + name: "Both SLOs set", + ttftSLO: 100, + tpotSLO: 50, + expectSLOs: true, + }, + { + name: "No SLOs", + ttftSLO: 0, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TTFT SLO", + ttftSLO: 100, + tpotSLO: 0, + expectSLOs: false, + }, + { + name: "Only TPOT SLO", + ttftSLO: 0, + tpotSLO: 50, + expectSLOs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) + ctx := NewSLORequestContext(request) + ctx.TTFTSLO = tt.ttftSLO + ctx.AvgTPOTSLO = tt.tpotSLO + + hasBothSLOs := ctx.TTFTSLO > 0 && ctx.AvgTPOTSLO > 0 + assert.Equal(t, tt.expectSLOs, hasBothSLOs) + }) + } +} + +// Benchmark tests + +func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { + router := createTestRouter() + ctx := context.Background() + pod := createTestPod("test-pod", 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + requestID := uuid.New().String() + request := createTestLLMRequest(requestID, 100, 50, true) + sloCtx := NewSLORequestContext(request) + sloCtx.AvgTPOTSLO = 50 + router.setSLOContextForRequest(request, sloCtx) + router.PreRequest(ctx, request, schedulingResult) + } +} + +func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { + router := createTestRouter() + request := createTestLLMRequest("test", 100, 50, true) + sloCtx := NewSLORequestContext(request) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.setSLOContextForRequest(request, sloCtx) + _, _ = router.getSLOContextForRequest(request) + router.deleteSLOContextForRequest(request) + } +} + +func 
BenchmarkSLORequestContext_Creation(b *testing.B) { + request := createTestLLMRequest("test", 100, 50, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewSLORequestContext(request) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go new file mode 100644 index 0000000000..ce1e997b07 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go @@ -0,0 +1,243 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "container/heap" + "fmt" + "sort" + "strings" + "sync" +) + +// Request represents an element in the priority queue. +// The index is needed by heap.Remove and is maintained by the heap.Interface methods. +type Request struct { + ID string // Unique identifier + TPOT float64 // The priority value (lower is higher priority) + index int +} + +// RequestPriorityQueue implements a priority queue with item removal by ID. +type RequestPriorityQueue struct { + items []*Request + lookup map[string]*Request + mutex sync.RWMutex +} + +// NewRequestPriorityQueue initializes and returns a new PriorityQueue. +func NewRequestPriorityQueue() *RequestPriorityQueue { + return &RequestPriorityQueue{ + lookup: make(map[string]*Request), + items: []*Request{}, + } +} + +// Clone creates a deep copy of the priority queue. +// The new queue is completely independent of the original. +func (pq *RequestPriorityQueue) Clone() *RequestPriorityQueue { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + // Initialize a new priority queue with pre-allocated capacity. + clonedPq := &RequestPriorityQueue{ + items: make([]*Request, len(pq.items)), + lookup: make(map[string]*Request, len(pq.lookup)), + } + + // Iterate through the original items to create deep copies. + for i, oldItem := range pq.items { + // Create a new Request struct, copying all values. + newItem := &Request{ + ID: oldItem.ID, + TPOT: oldItem.TPOT, + index: oldItem.index, + } + + // Assign the new item to the cloned queue's items slice. + clonedPq.items[i] = newItem + // Update the lookup map in the cloned queue to point to the new item. + clonedPq.lookup[newItem.ID] = newItem + } + + return clonedPq +} + +// Len is the number of items in the queue. +func (pq *RequestPriorityQueue) Len() int { return len(pq.items) } + +// Less reports whether the item with index i should sort before the item with index j. +func (pq *RequestPriorityQueue) Less(i, j int) bool { + return pq.items[i].TPOT < pq.items[j].TPOT +} + +// Swap swaps the items with indexes i and j. +func (pq *RequestPriorityQueue) Swap(i, j int) { + pq.items[i], pq.items[j] = pq.items[j], pq.items[i] + pq.items[i].index = i + pq.items[j].index = j +} + +// Push adds an item to the heap. 
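+// Push is part of the heap.Interface contract and is invoked by
+// container/heap while the queue's mutex is already held; external callers
+// should use Add instead.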
+func (pq *RequestPriorityQueue) Push(x any) {
+	item := x.(*Request)
+	item.index = len(pq.items)
+	pq.items = append(pq.items, item)
+}
+
+// Pop implements heap.Interface; container/heap moves the minimum to the end
+// before calling it, so it removes and returns the last element.
+func (pq *RequestPriorityQueue) Pop() any {
+	n := len(pq.items)
+	item := pq.items[n-1]
+	pq.items[n-1] = nil // avoid memory leak
+	item.index = -1     // for safety
+	pq.items = pq.items[0 : n-1]
+	return item
+}
+
+// Add adds a new item to the queue.
+// Returns true if the item was added, false if the input is invalid or an
+// item with the same ID already exists.
+func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	// Validate input
+	if id == "" {
+		return false
+	}
+	if tpot < 0 {
+		return false
+	}
+
+	// If item already exists, do not add
+	if _, exists := pq.lookup[id]; exists {
+		return false
+	}
+
+	item := &Request{
+		ID:   id,
+		TPOT: tpot,
+	}
+	pq.lookup[id] = item
+	heap.Push(pq, item)
+	return true
+}
+
+// Update modifies the TPOT value of an existing item in the queue.
+// Returns false if the item does not exist or the new TPOT value is negative.
+func (pq *RequestPriorityQueue) Update(id string, tpot float64) bool {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	// Validate input
+	if tpot < 0 {
+		return false
+	}
+
+	item, exists := pq.lookup[id]
+	if !exists {
+		return false
+	}
+
+	item.TPOT = tpot
+	heap.Fix(pq, item.index)
+	return true
+}
+
+// Remove removes an item from the queue by its ID.
+func (pq *RequestPriorityQueue) Remove(id string) (*Request, bool) {
+	pq.mutex.Lock()
+	defer pq.mutex.Unlock()
+
+	item, ok := pq.lookup[id]
+	if !ok {
+		return nil, false
+	}
+	removed := heap.Remove(pq, item.index).(*Request)
+	delete(pq.lookup, id)
+	return removed, true
+}
+
+// Peek returns the item with the lowest value without removing it.
+func (pq *RequestPriorityQueue) Peek() *Request {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+
+	if len(pq.items) == 0 {
+		return nil
+	}
+	return pq.items[0]
+}
+
+// GetSize returns the current number of items in the queue.
+func (pq *RequestPriorityQueue) GetSize() int {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+	return len(pq.items)
+}
+
+// Contains checks if an item with the given ID exists in the queue.
+func (pq *RequestPriorityQueue) Contains(id string) bool {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+	_, exists := pq.lookup[id]
+	return exists
+}
+
+// ToSlice returns a copy of all items in the queue, sorted by ID for stable comparison.
+// This is primarily intended for testing and validation.
+func (pq *RequestPriorityQueue) ToSlice() []*Request {
+	pq.mutex.RLock()
+	defer pq.mutex.RUnlock()
+
+	// Create a copy to avoid returning a reference to the internal slice.
+	itemsCopy := make([]*Request, len(pq.items))
+	copy(itemsCopy, pq.items)
+
+	// Sort by ID to have a deterministic order for comparison in tests.
+	sort.Slice(itemsCopy, func(i, j int) bool {
+		return itemsCopy[i].ID < itemsCopy[j].ID
+	})
+
+	return itemsCopy
+}
+
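+// A minimal usage sketch (illustrative only; the IDs and TPOT values are
+// made up):
+//
+//	pq := NewRequestPriorityQueue()
+//	pq.Add("req-a", 42.0)    // start tracking a running request
+//	pq.Update("req-a", 38.5) // re-prioritize as the observed TPOT changes
+//	head := pq.Peek()        // the request with the lowest TPOT sits at the head
+//	pq.Remove(head.ID)       // stop tracking once the request completes
+
+// String returns a string representation of the queue for debugging.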
+func (pq *RequestPriorityQueue) String() string { + pq.mutex.RLock() + defer pq.mutex.RUnlock() + + if len(pq.items) == 0 { + return "RequestPriorityQueue: []" + } + + var builder strings.Builder + builder.WriteString("RequestPriorityQueue: [") + + for i, item := range pq.items { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(item.ID) + builder.WriteString("(") + builder.WriteString(fmt.Sprintf("%.2f", item.TPOT)) + builder.WriteString(")") + } + + builder.WriteString("]") + return builder.String() +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go new file mode 100644 index 0000000000..a8eba5fe1c --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go @@ -0,0 +1,391 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestNewRequestPriorityQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + + if pq == nil { + t.Fatal("NewRequestPriorityQueue returned nil") + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue, got size %d", pq.GetSize()) + } + + if pq.Peek() != nil { + t.Error("Expected nil from Peek on empty queue") + } +} + +func TestAdd(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test successful add + if !pq.Add("req1", 2.5) { + t.Error("Expected Add to return true for new item") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1, got %d", pq.GetSize()) + } + + // Test duplicate add + if pq.Add("req1", 3.0) { + t.Error("Expected Add to return false for duplicate ID") + } + + if pq.GetSize() != 1 { + t.Errorf("Expected size 1 after duplicate add, got %d", pq.GetSize()) + } + + // Test validation + if pq.Add("", 1.0) { + t.Error("Expected Add to return false for empty ID") + } + + if pq.Add("req2", -1.0) { + t.Error("Expected Add to return false for negative TPOT") + } +} + +func TestPriorityOrdering(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Add items with different priorities + pq.Add("high", 1.0) // highest priority (lowest TPOT) + pq.Add("medium", 5.0) // medium priority + pq.Add("low", 10.0) // lowest priority (highest TPOT) + + // Check that highest priority item is at the top + peek := pq.Peek() + if peek == nil || peek.ID != "high" || peek.TPOT != 1.0 { + t.Errorf("Expected high priority item at top, got %+v", peek) + } + + // Test removal order + expected := []struct { + id string + tpot float64 + }{ + {"high", 1.0}, + {"medium", 5.0}, + {"low", 10.0}, + } + + for _, exp := range expected { + item := pq.Peek() + if item.ID != exp.id || item.TPOT != exp.tpot { + t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.ID, item.TPOT) + } + + removed, ok := pq.Remove(item.ID) + if !ok || removed.ID != exp.id { + t.Errorf("Failed to remove %s", exp.id) + } + } +} + +func TestRemove(t *testing.T) 
{ + pq := NewRequestPriorityQueue() + + // Test remove from empty queue + if _, ok := pq.Remove("nonexistent"); ok { + t.Error("Expected Remove to return false for empty queue") + } + + // Add some items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Test successful remove + removed, ok := pq.Remove("req2") + if !ok || removed.ID != "req2" || removed.TPOT != 2.0 { + t.Errorf("Expected to remove req2(2.0), got %+v, ok=%v", removed, ok) + } + + if pq.GetSize() != 2 { + t.Errorf("Expected size 2 after removal, got %d", pq.GetSize()) + } + + // Test remove nonexistent + if _, ok := pq.Remove("req2"); ok { + t.Error("Expected Remove to return false for already removed item") + } + + // Verify remaining items are still in correct order + if peek := pq.Peek(); peek.ID != "req1" { + t.Errorf("Expected req1 at top, got %s", peek.ID) + } +} + +func TestUpdate(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test update nonexistent item + if pq.Update("nonexistent", 1.0) { + t.Error("Expected Update to return false for nonexistent item") + } + + // Add items + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Update to make req3 highest priority + if !pq.Update("req3", 0.5) { + t.Error("Expected Update to return true for existing item") + } + + // Check that req3 is now at the top + if peek := pq.Peek(); peek.ID != "req3" || peek.TPOT != 0.5 { + t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.ID, peek.TPOT) + } + + // Test validation + if pq.Update("req1", -1.0) { + t.Error("Expected Update to return false for negative TPOT") + } +} + +func TestContains(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + if pq.Contains("req1") { + t.Error("Expected Contains to return false for empty queue") + } + + // Add item + pq.Add("req1", 1.0) + + // Test existing item + if !pq.Contains("req1") { + t.Error("Expected Contains to return true for existing item") + } + + // Test nonexistent item + if pq.Contains("req2") { + t.Error("Expected Contains to return false for nonexistent item") + } + + // Test after removal + pq.Remove("req1") + if pq.Contains("req1") { + t.Error("Expected Contains to return false after removal") + } +} + +func TestClone(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test clone of empty queue + clone := pq.Clone() + if clone.GetSize() != 0 { + t.Error("Expected cloned empty queue to be empty") + } + + // Add items to original + pq.Add("req1", 1.0) + pq.Add("req2", 2.0) + pq.Add("req3", 3.0) + + // Clone with items + clone = pq.Clone() + + // Verify clone has same items + if clone.GetSize() != pq.GetSize() { + t.Errorf("Expected clone size %d, got %d", pq.GetSize(), clone.GetSize()) + } + + // Verify independence - modify original + pq.Add("req4", 4.0) + if clone.GetSize() == pq.GetSize() { + t.Error("Clone should be independent of original") + } + + // Verify independence - modify clone + clone.Remove("req1") + if !pq.Contains("req1") { + t.Error("Original should not be affected by clone modifications") + } + + // Verify deep copy - items should be different instances + origPeek := pq.Peek() + clonePeek := clone.Peek() + if origPeek == clonePeek { + t.Error("Clone should create new Request instances, not share pointers") + } +} + +func TestString(t *testing.T) { + pq := NewRequestPriorityQueue() + + // Test empty queue + str := pq.String() + expected := "RequestPriorityQueue: []" + if str != expected { + t.Errorf("Expected %q, got %q", expected, str) + } + + // Test with items + 
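// (each entry should render as id(TPOT) with two decimal places)
+	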
pq.Add("req1", 1.5) + pq.Add("req2", 2.25) + + str = pq.String() + // Should contain both items in priority order + if !contains(str, "req1(1.50)") || !contains(str, "req2(2.25)") { + t.Errorf("String output missing expected items: %s", str) + } +} + +func TestConcurrency(t *testing.T) { + pq := NewRequestPriorityQueue() + const numWorkers = 10 + const itemsPerWorker = 100 + + var wg sync.WaitGroup + + // Launch workers that add items + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + for j := 0; j < itemsPerWorker; j++ { + id := fmt.Sprintf("worker%d-item%d", workerID, j) + tpot := float64(j) + float64(workerID)*0.1 + pq.Add(id, tpot) + } + }(i) + } + + // Launch workers that read from the queue + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < itemsPerWorker/2; j++ { + pq.Peek() + pq.GetSize() + time.Sleep(time.Microsecond) + } + }() + } + + wg.Wait() + + // Verify final state + expectedSize := numWorkers * itemsPerWorker + if pq.GetSize() != expectedSize { + t.Errorf("Expected final size %d, got %d", expectedSize, pq.GetSize()) + } +} + +func TestLargeQueue(t *testing.T) { + pq := NewRequestPriorityQueue() + const numItems = 10000 + + // Add many items + for i := 0; i < numItems; i++ { + id := fmt.Sprintf("item%d", i) + tpot := float64(numItems - i) // Reverse order so item0 has highest priority + pq.Add(id, tpot) + } + + if pq.GetSize() != numItems { + t.Errorf("Expected size %d, got %d", numItems, pq.GetSize()) + } + + // Verify priority ordering by removing items + lastTPOT := -1.0 + for i := 0; i < numItems; i++ { + item := pq.Peek() + if item.TPOT < lastTPOT { + t.Errorf("Priority order violated: %.1f < %.1f", item.TPOT, lastTPOT) + } + lastTPOT = item.TPOT + pq.Remove(item.ID) + } + + if pq.GetSize() != 0 { + t.Errorf("Expected empty queue after removing all items, got size %d", pq.GetSize()) + } +} + +func BenchmarkAdd(b *testing.B) { + pq := NewRequestPriorityQueue() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := fmt.Sprintf("item%d", i) + pq.Add(id, float64(i)) + } +} + +func BenchmarkPeek(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < 1000; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Peek() + } +} + +func BenchmarkRemove(b *testing.B) { + pq := NewRequestPriorityQueue() + + // Pre-populate queue + for i := 0; i < b.N; i++ { + pq.Add(fmt.Sprintf("item%d", i), float64(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pq.Remove(fmt.Sprintf("item%d", i)) + } +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go new file mode 100644 index 0000000000..bdeca30378 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -0,0 +1,136 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package slo_aware_router
+
+import (
+	"hash/fnv"
+	"math"
+	"math/rand"
+	"time"
+)
+
+// TokenSampler handles Poisson-distributed sampling for predictions only;
+// training happens on every token regardless of sampling.
+type TokenSampler struct {
+	rng             *rand.Rand
+	nextSampleToken int
+	samplingMean    float64
+	maxSamples      int
+	sampleCount     int
+}
+
+// SetSamplingMean sets the sampling mean (lambda) for the Poisson distribution
+func (ts *TokenSampler) SetSamplingMean(mean float64) {
+	ts.samplingMean = mean
+}
+
+// SetMaxSamples sets the maximum number of samples
+func (ts *TokenSampler) SetMaxSamples(max int) {
+	ts.maxSamples = max
+}
+
+// SetSampleCount sets the current number of predictions made
+func (ts *TokenSampler) SetSampleCount(count int) {
+	ts.sampleCount = count
+}
+
+func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *TokenSampler {
+	// Use request ID hash as seed for reproducibility
+	seed := int64(0)
+	if requestID != "" {
+		hash := fnv.New64a()
+		hash.Write([]byte(requestID))
+		seed = int64(hash.Sum64())
+	}
+	if seed == 0 {
+		seed = time.Now().UnixNano()
+	}
+
+	sampler := &TokenSampler{
+		rng:          rand.New(rand.NewSource(seed)),
+		samplingMean: samplingMean,
+		maxSamples:   maxSamples,
+	}
+
+	// The first prediction is sampled at token 2; token 1 is covered by TTFT.
+	sampler.nextSampleToken = 2
+
+	return sampler
+}
+
+// poissonNext generates the next sampling interval from a Poisson
+// distribution with mean samplingMean. The result is clamped to at least 1
+// so that the next sample token always advances.
+func (ts *TokenSampler) poissonNext() int {
+	lambda := ts.samplingMean
+	if lambda <= 0 {
+		return 1
+	}
+
+	// For small lambda, use Knuth's algorithm
+	if lambda < 30 {
+		l := math.Exp(-lambda)
+		k := 0
+		p := 1.0
+
+		for p > l {
+			k++
+			p *= ts.rng.Float64()
+		}
+		if k-1 < 1 {
+			return 1
+		}
+		return k - 1
+	}
+
+	// For larger lambda, use normal approximation
+	normal := ts.rng.NormFloat64()
+	interval := int(math.Round(lambda + math.Sqrt(lambda)*normal))
+	if interval < 1 {
+		return 1
+	}
+	return interval
+}
+
+// ShouldPredict determines if we should make a prediction for the current token
+func (ts *TokenSampler) ShouldPredict(currentToken int) bool {
+	return currentToken == ts.nextSampleToken && ts.sampleCount < ts.maxSamples
+}
+
+// RecordPrediction records that a prediction was made and calculates the next sample token
+func (ts *TokenSampler) RecordPrediction(currentToken int) {
+	if ts.sampleCount >= ts.maxSamples {
+		return
+	}
+
+	ts.sampleCount++
+
+	if ts.sampleCount < ts.maxSamples {
+		interval := ts.poissonNext()
+		ts.nextSampleToken = currentToken + interval
+	}
+}
+
+// GetNextSampleToken returns the next token to predict for
+func (ts *TokenSampler) GetNextSampleToken() int {
+	return ts.nextSampleToken
+}
+
+// SetNextSampleToken sets the next token to predict for
+func (ts *TokenSampler) SetNextSampleToken(token int) {
+	ts.nextSampleToken = token
+}
+
+// GetSampleCount returns the current number of predictions made
+func (ts *TokenSampler) GetSampleCount() int {
+	return ts.sampleCount
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
new file 
mode 100644 index 0000000000..b476579b5f --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -0,0 +1,325 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "sigs.k8s.io/controller-runtime/pkg/log" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +type SLOAwareRouter struct { + tn plugins.TypedName + latencypredictor latencypredictor.PredictorInterface + runningRequestLists map[types.NamespacedName]*RequestPriorityQueue + sloContextStore sync.Map // map[string]*SLORequestContext + headroomStrategy HeadroomStrategy +} + +var _ framework.Scorer = &SLOAwareRouter{} + +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { + return &SLOAwareRouter{ + tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, + latencypredictor: latencypredictor, + runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + sloContextStore: sync.Map{}, + headroomStrategy: strategy, + } +} + +func (s *SLOAwareRouter) TypedName() plugins.TypedName { + return s.tn +} + +func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { + s.tn.Name = name + return s +} + +// SetHeadroomStrategy allows runtime configuration of headroom selection strategy +func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { + s.headroomStrategy = strategy +} + +// GetHeadroomStrategy returns the current headroom selection strategy +func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { + return s.headroomStrategy +} + +func (s *SLOAwareRouter) epsilonGreedyAffinityGate( + ctx context.Context, + candidates []PodPredictionResult, + r *rand.Rand, + label string, // e.g. "positive" or "negative" + prefixStickyThreshold float64, +) ([]PodPredictionResult, bool) { + logger := log.FromContext(ctx) + + eligible := make([]PodPredictionResult, 0, len(candidates)) + for _, p := range candidates { + if p.PrefixCacheScore >= prefixStickyThreshold { + eligible = append(eligible, p) + } + } + + // No eligible sticky pods? Explore (no gating). 
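+	// With no pod above the prefix-affinity threshold there is nothing to
+	// exploit, so the full candidate set is returned ungated.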
+	if len(eligible) == 0 {
+		return candidates, false
+	}
+
+	// ε-exploration branch
+	if r.Float64() < EpsilonExploreSticky {
+		logger.V(logutil.DEBUG).Info("ε-greedy: exploring (ignoring affinity gate)",
+			"path", label, "epsilon", EpsilonExploreSticky, "eligibleCount", len(eligible))
+		return candidates, false
+	}
+
+	logger.V(logutil.DEBUG).Info("ε-greedy: exploiting (apply affinity gate)",
+		"path", label, "threshold", prefixStickyThreshold, "eligibleCount", len(eligible), "total", len(candidates))
+	return eligible, true
+}
+
+// scoreWithoutPredictions provides fallback scoring based only on prefix cache
+// scores when latency predictions are unavailable.
+func (s *SLOAwareRouter) scoreWithoutPredictions(
+	ctx context.Context,
+	state *schedulingtypes.CycleState,
+	pods []schedulingtypes.Pod,
+	r *rand.Rand,
+) map[schedulingtypes.Pod]float64 {
+	logger := log.FromContext(ctx)
+	logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions")
+
+	scores := make(map[schedulingtypes.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = 0
+	}
+
+	if len(pods) == 0 {
+		return scores
+	}
+
+	// Build prediction results with only prefix cache scores
+	podResults := make([]PodPredictionResult, 0, len(pods))
+	for _, pod := range pods {
+		prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
+		podResults = append(podResults, PodPredictionResult{
+			Pod:              pod,
+			PrefixCacheScore: prefixScore,
+			IsValid:          true, // All pods are valid when we don't check predictions
+		})
+	}
+
+	// Select based on composite scores (prefix cache + other non-prediction metrics)
+	selectedPod := s.selectFromCompositeScores(ctx, podResults, r, HeadroomStrategyCompositeOnly)
+
+	if selectedPod != nil {
+		scores[selectedPod] = 1
+		logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String())
+	}
+
+	return scores
+}
+
+func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
+	logger := log.FromContext(ctx)
+	if s.latencypredictor == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores")
+		return nil
+	}
+
+	sloCtx := s.getOrMakeSLORequestContext(request)
+
+	var err error
+	// Parse the request SLOs from the request headers.
+	sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header")
+	}
+
+	sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header")
+	}
+	sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling")
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header")
+	}
+
+	// Skip prediction-based filtering entirely when the request did not opt in.
+	if !sloCtx.PredictorBasedScheduling {
+		logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering")
+		s.setSLOContextForRequest(request, sloCtx)
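+		// Returning nil signals that this scorer abstains for the request,
+		// presumably leaving pod choice to the other configured plugins.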
+		return nil
+	}
+
+	// Initialize scores map with all pods having score 0
+	scores := make(map[schedulingtypes.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = 0
+	}
+
+	source := rand.NewSource(time.Now().UnixNano())
+	r := rand.New(source)
+	predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods)
+	if err != nil {
+		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring")
+		// Fall back to composite-only scoring using prefix cache scores
+		s.setSLOContextForRequest(request, sloCtx)
+		return s.scoreWithoutPredictions(ctx, state, pods, r)
+	}
+	s.updateRequestContextWithPredictions(sloCtx, predictions)
+
+	allPreds := append([]PodPredictionResult(nil), predictions...)
+	allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
+
+	// Check if all pods are invalid and all have running requests
+	allPodsInvalid := true
+	allPodsHaveRunningRequests := true
+
+	for _, pred := range allPreds {
+		if pred.IsValid {
+			allPodsInvalid = false
+		}
+
+		runningRequestCount := s.getPodRunningRequestCount(pred.Pod)
+		if runningRequestCount == 0 {
+			allPodsHaveRunningRequests = false
+		}
+	}
+
+	// Set HasValidPod to false if all pods are invalid and all have running requests
+	if allPodsInvalid && allPodsHaveRunningRequests && !sticky {
+		sloCtx.HasValidPod = false
+		logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false")
+	}
+
+	// Tiered selection: positive-headroom pods are strongly preferred; with
+	// probability EpsilonExploreNeg (1% by default) a negative-headroom pod
+	// is explored instead.
+	var posHeadroomPods, negHeadroomPods []PodPredictionResult
+	for _, p := range allPreds {
+		// A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom
+		if p.Headroom > 0 && p.TTFTHeadroom > 0 {
+			posHeadroomPods = append(posHeadroomPods, p)
+		} else {
+			// A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom
+			negHeadroomPods = append(negHeadroomPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Pod headroom distribution",
+		"positivePods", len(posHeadroomPods),
+		"negativePods", len(negHeadroomPods))
+
+	var selectedPod schedulingtypes.Pod
+
+	if s.headroomStrategy == HeadroomStrategyCompositeOnly {
+		logger.V(logutil.DEBUG).Info("Selecting from composite scores only")
+		selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, HeadroomStrategyCompositeOnly)
+	} else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 {
+		// Explore a negative-headroom pod with probability EpsilonExploreNeg; otherwise exploit positive headroom
+		if r.Float64() < EpsilonExploreNeg {
+			logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (exploration)")
+			selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+		} else {
+			logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (exploitation)")
+			selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+		}
+	} else if len(posHeadroomPods) > 0 {
+		// If only positive headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only positive headroom pods available")
+		selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+	} else if len(negHeadroomPods) > 0 {
+		// If only negative headroom pods exist, select from them
+		logger.V(logutil.DEBUG).Info("Only negative headroom pods available")
+		selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+	} else if len(allPreds) > 0 {
+		// Fallback: select randomly from the remaining candidates
+		
logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods")
+		selectedPod = allPreds[r.Intn(len(allPreds))].Pod
+	} else {
+		// No valid pods - return all zeros
+		logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores")
+		return scores
+	}
+
+	// Set score = 1 for the selected pod, 0 for all others
+	if selectedPod != nil {
+		scores[selectedPod] = 1
+		logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String())
+	}
+
+	s.setSLOContextForRequest(request, sloCtx)
+
+	return scores
+}
+
+func (s *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
+	sloCtx, err := s.getSLOContextForRequest(request)
+	if err != nil {
+		sloCtx = NewSLORequestContext(request)
+	}
+	return sloCtx
+}
+
+func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
+	logger := log.FromContext(ctx)
+	// The prefix cache plugin registers its state under its type name by default.
+	cycleStateKey := (plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}).String()
+	stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey))
+
+	logger.V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey, "pod", pod.GetPod().String())
+
+	if err != nil {
+		// The prefix cache plugin might not be enabled, which is a valid scenario.
+		logger.V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
+		return 0.0
+	}
+
+	prefixCacheState, ok := stateData.(*prefix.SchedulingContextState)
+	if !ok {
+		// This should not happen if the plugin is configured correctly.
+		logger.Error(fmt.Errorf("unexpected state type: %T", stateData), "failed to read prefix cache state")
+		return 0.0
+	}
+
+	total := len(prefixCacheState.PrefixHashes)
+	if total == 0 {
+		// If the request has no prefixes, return 0.0.
+		logger.V(logutil.DEBUG).Info("No prefixes found in request, returning prefix cache score of 0.0")
+		return 0.0
+	}
+
+	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)]
+	logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total)
+	return float64(matchLen) / float64(total)
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go
new file mode 100644
index 0000000000..da073ff65a
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go
@@ -0,0 +1,527 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +// mockPredictor implements PredictorInterface for testing +type mockPredictor struct { + predictions map[string]*latencypredictor.PredictionResponse + err error +} + +func (m *mockPredictor) Predict(ctx context.Context, request latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + return pred, nil + } + // Default prediction + return &latencypredictor.PredictionResponse{TTFT: 0.5, TPOT: 0.03}, nil +} + +func (m *mockPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + if m.err != nil { + return nil, m.err + } + // Generate a key based on KV cache percentage to return different predictions for different pods + responses := make([]latencypredictor.PredictionResponse, 0, len(requests)) + for _, request := range requests { + key := fmt.Sprintf("%.1f", request.KVCachePercentage) + if pred, ok := m.predictions[key]; ok { + responses = append(responses, *pred) + } else { + return nil, fmt.Errorf("no prediction for key %s", key) + } + } + return &latencypredictor.BulkPredictionResponse{Predictions: responses}, nil +} + +func (m *mockPredictor) AddTrainingDataBulk(data []latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) AddTrainingData(data latencypredictor.TrainingEntry) error { + return nil +} + +func (m *mockPredictor) HealthCheck() error { + return nil +} + +func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.ServerStatusResponse, error) { + return &latencypredictor.ServerStatusResponse{}, nil +} + +func createTestPod(name string, kvCacheUsage float64, runningQueueSize, waitingQueueSize int) schedulingtypes.Pod { + return &schedulingtypes.PodMetrics{ + Pod: &backend.Pod{ + NamespacedName: types.NamespacedName{ + Name: name, + Namespace: "default", + }, + }, + MetricsState: &backendmetrics.MetricsState{ + KVCacheUsagePercent: kvCacheUsage, + RunningQueueSize: runningQueueSize, + 
WaitingQueueSize: waitingQueueSize, + }, + } +} + +func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBased bool) *schedulingtypes.LLMRequest { + headers := make(map[string]string) + headers[requtil.RequestIdHeaderKey] = reqID + if ttftSLO > 0 { + headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO) + } + if tpotSLO > 0 { + headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO) + } + headers["x-prediction-based-scheduling"] = fmt.Sprintf("%t", predictionBased) + + return &schedulingtypes.LLMRequest{ + Headers: headers, + Body: &schedulingtypes.LLMRequestBody{ + Completions: &schedulingtypes.CompletionsRequest{ + Prompt: "test prompt", + }, + }, + } +} + +func TestSLOAwareRouter_Score(t *testing.T) { + tests := []struct { + name string + predictor *mockPredictor + strategy HeadroomStrategy + request *schedulingtypes.LLMRequest + pods []schedulingtypes.Pod + expectedScores map[string]float64 // Map of pod name to expected score + expectNil bool + }{ + { + name: "Prediction-based scheduling disabled", + predictor: &mockPredictor{}, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting + createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting + }, + expectNil: true, + }, + { + name: "No predictor configured", + predictor: nil, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + }, + expectNil: true, + }, + { + name: "All pods have positive headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, // 50% KV cache + "0.6": {TTFT: 0.6, TPOT: 0.04}, // 60% KV cache + "0.3": {TTFT: 0.4, TPOT: 0.02}, // 30% KV cache + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), // 50% KV cache + createTestPod("pod2", 0.6, 3, 2), // 60% KV cache + createTestPod("pod3", 0.3, 1, 0), // 30% KV cache + }, + // One pod should be selected with score 1, others 0 + expectedScores: map[string]float64{ + // We can't predict which one due to randomness, but exactly one should be 1 + }, + }, + { + name: "All pods have negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.8": {TTFT: 1.5, TPOT: 0.08}, // 80% KV cache - high load + "0.9": {TTFT: 1.8, TPOT: 0.09}, // 90% KV cache - very high load + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load + createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load + }, + // One pod should still be selected even with negative headroom + expectedScores: map[string]float64{}, + }, + { + name: "Mixed positive and negative headroom", + predictor: &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.3": {TTFT: 0.5, TPOT: 0.03}, // 30% KV cache - Positive headroom + "0.9": {TTFT: 1.5, TPOT: 0.08}, // 90% KV cache - Negative headroom + }, + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive 
headroom + createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom + }, + // With 99% probability, positive headroom pod should be selected + expectedScores: map[string]float64{}, + }, + { + name: "Prediction errors - fallback to composite scoring", + predictor: &mockPredictor{ + err: fmt.Errorf("prediction failed"), + }, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + }, + // Should fall back to composite-only scoring and select one pod + expectedScores: map[string]float64{ + // One pod should be selected with score 1, verified in general validation below + }, + }, + { + name: "Empty pod list", + predictor: &mockPredictor{}, + strategy: HeadroomStrategyLeast, + request: createTestLLMRequest("test", 1.0, 0.05, true), + pods: []schedulingtypes.Pod{}, + // Should return empty scores map + expectedScores: map[string]float64{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var router *SLOAwareRouter + if tt.predictor == nil { + router = NewSLOAwareRouter(nil, tt.strategy) + } else { + router = NewSLOAwareRouter(tt.predictor, tt.strategy) + } + + scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.pods) + + if tt.expectNil { + assert.Nil(t, scores, "Expected nil scores") + return + } + + assert.NotNil(t, scores, "Expected non-nil scores") + + // If we have specific expected scores, verify them + if len(tt.expectedScores) > 0 { + for _, pod := range tt.pods { + podName := pod.GetPod().NamespacedName.Name + if expectedScore, ok := tt.expectedScores[podName]; ok { + assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", podName, expectedScore) + } + } + } + + // General validation: exactly one pod should have score 1 (selected), others should have score 0 + // This applies even when predictions fail because we fall back to composite scoring + if !tt.expectNil && len(tt.pods) > 0 && tt.predictor != nil { + selectedCount := 0 + for _, score := range scores { + if score == 1.0 { + selectedCount++ + } else { + assert.InDelta(t, 0.0, score, 0.0001, "Non-selected pods should have score 0") + } + } + assert.Equal(t, 1, selectedCount, "Exactly one pod should be selected with score 1") + } + }) + } +} + +func TestSLOAwareRouter_Strategies(t *testing.T) { + tests := []struct { + name string + strategy HeadroomStrategy + }{ + { + name: "HeadroomStrategyLeast", + strategy: HeadroomStrategyLeast, + }, + { + name: "HeadroomStrategyMost", + strategy: HeadroomStrategyMost, + }, + { + name: "HeadroomStrategyCompositeMost", + strategy: HeadroomStrategyCompositeMost, + }, + { + name: "HeadroomStrategyCompositeLeast", + strategy: HeadroomStrategyCompositeLeast, + }, + { + name: "HeadroomStrategyCompositeOnly", + strategy: HeadroomStrategyCompositeOnly, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, + "0.6": {TTFT: 0.6, TPOT: 0.04}, + "0.3": {TTFT: 0.4, TPOT: 0.02}, + }, + } + router := NewSLOAwareRouter(predictor, tt.strategy) + + request := createTestLLMRequest("test", 1.0, 0.05, true) + pods := []schedulingtypes.Pod{ + createTestPod("pod1", 0.5, 2, 1), + createTestPod("pod2", 0.6, 3, 2), + createTestPod("pod3", 0.3, 1, 0), + } + + scores := router.Score(context.Background(), 
schedulingtypes.NewCycleState(), request, pods) + + assert.NotNil(t, scores, "Expected non-nil scores for strategy %s", tt.strategy) + + // Verify exactly one pod is selected + selectedCount := 0 + for _, score := range scores { + if score == 1.0 { + selectedCount++ + } + } + assert.Equal(t, 1, selectedCount, "Strategy %s should select exactly one pod", tt.strategy) + }) + } +} + +func TestSLOAwareRouter_SetHeadroomStrategy(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + assert.Equal(t, HeadroomStrategyLeast, router.GetHeadroomStrategy(), "Initial strategy should be Least") + + router.SetHeadroomStrategy(HeadroomStrategyMost) + assert.Equal(t, HeadroomStrategyMost, router.GetHeadroomStrategy(), "Strategy should be updated to Most") + + router.SetHeadroomStrategy(HeadroomStrategyCompositeOnly) + assert.Equal(t, HeadroomStrategyCompositeOnly, router.GetHeadroomStrategy(), "Strategy should be updated to CompositeOnly") +} + +func TestSLOAwareRouter_TypedName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should be slo-aware-routing") + assert.Equal(t, "slo-aware-routing", tn.Name, "Default name should be slo-aware-routing") +} + +func TestSLOAwareRouter_WithName(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + customName := "custom-router" + router = router.WithName(customName) + + tn := router.TypedName() + assert.Equal(t, "slo-aware-routing", tn.Type, "Type should remain slo-aware-routing") + assert.Equal(t, customName, tn.Name, "Name should be updated to custom name") +} + +func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedCount int + }{ + { + name: "No running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedCount: 0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedCount: 1, + }, + { + name: "Multiple running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + r.runningRequestLists[podName].Add("req2", 0.03) + r.runningRequestLists[podName].Add("req3", 0.05) + }, + expectedCount: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + count := router.getPodRunningRequestCount(pod) + assert.Equal(t, tt.expectedCount, count, "Running request count should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { + tests := []struct { + name string + setupRequests func(*SLOAwareRouter, schedulingtypes.Pod) + expectedSLO float64 + }{ + { + name: "No 
running requests", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {}, + expectedSLO: 0.0, + }, + { + name: "One running request", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName].Add("req1", 0.04) + }, + expectedSLO: 0.04, + }, + { + name: "Multiple running requests - should return minimum", + setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) { + podName := types.NamespacedName{ + Name: p.GetPod().NamespacedName.Name, + Namespace: p.GetPod().NamespacedName.Namespace, + } + r.runningRequestLists[podName] = NewRequestPriorityQueue() + // Add in any order - heap will maintain minimum at top + r.runningRequestLists[podName].Add("req1", 0.05) + r.runningRequestLists[podName].Add("req2", 0.03) // This is the minimum + r.runningRequestLists[podName].Add("req3", 0.04) + }, + expectedSLO: 0.03, // Minimum TPOT (heap guarantees this is at items[0]) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + pod := createTestPod("test-pod", 0.5, 2, 1) + + tt.setupRequests(router, pod) + + minSLO := router.getPodMinTPOTSLO(pod) + assert.InDelta(t, tt.expectedSLO, minSLO, 0.0001, "Min TPOT SLO should match expected") + }) + } +} + +func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) { + tests := []struct { + name string + setupState func(*schedulingtypes.CycleState) + expectedScore float64 + }{ + { + name: "No prefix cache state", + setupState: func(s *schedulingtypes.CycleState) {}, + expectedScore: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predictor := &mockPredictor{} + router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + + state := schedulingtypes.NewCycleState() + tt.setupState(state) + + pod := createTestPod("test-pod", 0.5, 2, 1) + + score := router.getPrefixCacheScoreForPod(context.Background(), state, pod) + assert.InDelta(t, tt.expectedScore, score, 0.0001, "Prefix cache score should match expected") + }) + } +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go new file mode 100644 index 0000000000..eeab50433f --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -0,0 +1,385 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. 
+package slo_aware_router + +import ( + "context" + "math" + "math/rand" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy +// Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom. +func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { + logger := log.FromContext(ctx) + + if len(posHeadroomPods) == 1 { + return posHeadroomPods[0].Pod + } + + // Apply perfect stickiness (with exploration) + candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", AffinityGateTau) + + // If perfect stickiness collapsed us to a single pod, short-circuit + if sticky && len(candidates) == 1 { + return candidates[0].Pod + } + switch s.headroomStrategy { + case HeadroomStrategyCompositeMost: + return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost) + case HeadroomStrategyCompositeLeast: + return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast) + } + + // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] + minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64 + minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64 + + for _, p := range candidates { + if p.Headroom < minTPOTH { + minTPOTH = p.Headroom + } + if p.Headroom > maxTPOTH { + maxTPOTH = p.Headroom + } + if p.TTFTHeadroom < minTTFTH { + minTTFTH = p.TTFTHeadroom + } + if p.TTFTHeadroom > maxTTFTH { + maxTTFTH = p.TTFTHeadroom + } + } + + tpotRange := maxTPOTH - minTPOTH + ttftRange := maxTTFTH - minTTFTH + + // Precompute blend weights (renormalize if user sets both to 0) + alpha := HeadroomTTFTWeight + beta := HeadroomTPOTWeight + if alpha+beta <= 0 { + alpha = 1.0 + beta = 0.0 + } + sum := alpha + beta + alpha /= sum + beta /= sum + + logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", + "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, + "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, + "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) + + // Calculate weights for weighted random selection + weightedChoices := make([]Choice, 0, len(candidates)) + total := 0 + + for _, p := range candidates { + // Normalize to [0,1] within the cohort + nTPOTH := 0.5 + if tpotRange > eps { + nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) + } + nTTFTH := 0.5 + if ttftRange > eps { + nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) + } + + // Blend: larger combined -> "safer"; smaller -> "tighter packing" + combined := alpha*nTTFTH + beta*nTPOTH + + // Map to integer weights + var w int + switch s.headroomStrategy { + case HeadroomStrategyLeast: + // prefer smaller combined headroom (pack closer to limits) + w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 + case HeadroomStrategyMost: + // prefer larger combined headroom (more conservative / spread) + w = int(combined*float64(Wmax-minWeight)) + minWeight + 1 + default: + // Fallback to least + w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 + } + + weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w}) + total += w + + 
logger.V(logutil.TRACE).Info("Positive headroom blended weight",
+			"pod", p.Pod.GetPod().String(),
+			"ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH,
+			"tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH,
+			"combined", combined, "weight", w)
+	}
+
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic.
+// Pods with 0 running requests are strictly preferred.
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	logger := log.FromContext(ctx)
+
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// First, separate pods by running request count
+	var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult
+
+	for _, p := range negHeadroomPods {
+		runningRequestCount := s.getPodRunningRequestCount(p.Pod)
+		if runningRequestCount == 0 {
+			zeroRunningRequestPods = append(zeroRunningRequestPods, p)
+		} else {
+			nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p)
+		}
+	}
+
+	logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count",
+		"zeroRunningRequests", len(zeroRunningRequestPods),
+		"nonZeroRunningRequests", len(nonZeroRunningRequestPods))
+
+	// If we have pods with 0 running requests, strictly prefer them
+	if len(zeroRunningRequestPods) > 0 {
+		logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests")
+		return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r)
+	}
+
+	// Otherwise, fall back to pods with running requests
+	logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests")
+	return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r)
+}
+
+// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+	if len(negHeadroomPods) == 1 {
+		return negHeadroomPods[0].Pod
+	}
+
+	// Apply perfect stickiness (with exploration)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", AffinityGateTau)
+
+	// If perfect stickiness collapsed us to a single pod, short-circuit
+	if sticky && len(candidates) == 1 {
+		return candidates[0].Pod
+	}
+
+	switch s.headroomStrategy {
+	case HeadroomStrategyCompositeMost:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost)
+	case HeadroomStrategyCompositeLeast:
+		return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast)
+	}
+
+	// Build weighted choices for selection
+	weightedChoices := make([]Choice, 0, len(candidates))
+	total := 0
+
+	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+
+	// Perform weighted random selection
+	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
+}
+
+// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// Lower blended deficit => higher weight.
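+//
+// alpha and beta are the relative TTFT/TPOT blend weights; they are renormalized
+// inside the function to sum to 1 (falling back to TTFT-only when both are zero),
+// so callers can pass the raw env-derived weights through unchanged.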
+func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( + ctx context.Context, + pods []PodPredictionResult, + choices *[]Choice, + total *int, + minWeight int, + alpha, beta float64, // weights for TTFT and TPOT deficits + category string, +) { + logger := log.FromContext(ctx) + if len(pods) == 0 { + return + } + + const Wrange = 80 + const eps = 1e-9 + + // Compute raw deficits (only when headroom is negative) + type deficits struct { + pod PodPredictionResult + ttftDef float64 + tpotDef float64 + } + defs := make([]deficits, 0, len(pods)) + + minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64 + minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64 + + for _, p := range pods { + ttftDef := 0.0 + if p.TTFTHeadroom < 0 { + ttftDef = -p.TTFTHeadroom + } + tpotDef := 0.0 + if p.Headroom < 0 { + tpotDef = -p.Headroom + } + defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef}) + + if ttftDef < minTTFT { + minTTFT = ttftDef + } + if ttftDef > maxTTFT { + maxTTFT = ttftDef + } + if tpotDef < minTPOT { + minTPOT = tpotDef + } + if tpotDef > maxTPOT { + maxTPOT = tpotDef + } + } + + ttftRange := maxTTFT - minTTFT + tpotRange := maxTPOT - minTPOT + + // Normalize alpha/beta + if alpha+beta <= 0 { + alpha, beta = 1.0, 0.0 + } else { + sum := alpha + beta + alpha /= sum + beta /= sum + } + + logger.V(logutil.DEBUG).Info("Negative headroom blended deficits", + "category", category, + "minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT, + "minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT, + "alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods)) + + for _, d := range defs { + // Normalize deficits to [0,1] within this bucket (0 = best / least violation) + nTTFT := 0.0 + if ttftRange > eps { + nTTFT = (d.ttftDef - minTTFT) / (ttftRange + eps) + } + nTPOT := 0.0 + if tpotRange > eps { + nTPOT = (d.tpotDef - minTPOT) / (tpotRange + eps) + } + + // Blended "badness": higher = worse violation + blended := alpha*nTTFT + beta*nTPOT + + // Convert to selection weight: lower badness -> higher weight + // Ensure a floor so no pod is completely excluded within the bucket. 
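+		// For example, with Wrange=80 and a minWeight of 1: the least-violating pod
+		// in the bucket (blended=0.0) gets w = int(1.0*80)+1+1 = 82, while the worst
+		// (blended=1.0) still gets w = int(0.0*80)+1+1 = 2 and remains selectable.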
+ w := int((1.0-blended)*float64(Wrange)) + minWeight + 1 + + *choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w}) + *total += w + + logger.V(logutil.TRACE).Info("Negative bucket blended weighting", + "pod", d.pod.Pod.GetPod().String(), + "ttftDef", d.ttftDef, "tpotDef", d.tpotDef, + "normTTFT", nTTFT, "normTPOT", nTPOT, + "blendedBadness", blended, "weight", w) + } +} + +func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( + ctx context.Context, + negHeadroomPods []PodPredictionResult, + choices *[]Choice, + total *int, + minWeightForNegative int, +) { + logger := log.FromContext(ctx) + + // Categorize pods by their headroom status + var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult + + for _, p := range negHeadroomPods { + if p.TTFTHeadroom < 0 && p.Headroom < 0 { + negTTFTNegTPOT = append(negTTFTNegTPOT, p) + } else if p.TTFTHeadroom < 0 && p.Headroom >= 0 { + negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p) + } else if p.TTFTHeadroom >= 0 && p.Headroom < 0 { + nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p) + } else { + nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p) + } + } + + logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution", + "totalNegative", len(negHeadroomPods), + "negTTFT_negTPOT", len(negTTFTNegTPOT), + "negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT), + "nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT), + "nonNegTTFT_nonNegTPOT", len(nonNegTTFTNonNegTPOT)) + + // Priority 1: both TTFT and TPOT negative -> blended deficits (both active) + if len(negTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "both_negative") + } + + // Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0) + if len(negTTFTNonNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "ttft_negative") + } + + // Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0) + if len(nonNegTTFTNegTPOT) > 0 { + s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative, + NegHeadroomTTFTWeight, NegHeadroomTPOTWeight, "tpot_negative") + } + + // Priority 4: edge-case bucket -> minimal weight + for _, p := range nonNegTTFTNonNegTPOT { + *choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative}) + *total += minWeightForNegative + } +} + +func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 { + if topReq := runningReqs.Peek(); topReq != nil { + return topReq.TPOT + } + } + return 0 // no running requests or no TPOT SLOs +} + +func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int { + podName := types.NamespacedName{ + Name: pod.GetPod().NamespacedName.Name, + Namespace: pod.GetPod().NamespacedName.Namespace, + } + if runningReqs, ok := s.runningRequestLists[podName]; ok { + return runningReqs.GetSize() + } + return 0 // no running requests +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go new file mode 100644 index 0000000000..8030866d80 --- 
/dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go @@ -0,0 +1,57 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package requestcontrol contains helpers to decouple latency-predictor logic. +package slo_aware_router + +import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + +type HeadroomStrategy string + +type Choice struct { + PodName schedulingtypes.Pod + Weight int +} + +const ( + // HeadroomStrategyLeast prioritizes pods with least positive headroom (better packing) + HeadroomStrategyLeast HeadroomStrategy = "least" + // HeadroomStrategyMost prioritizes pods with most positive headroom (more conservative) + HeadroomStrategyMost HeadroomStrategy = "most" + + HeadroomStrategyCompositeLeast HeadroomStrategy = "composite-least" + HeadroomStrategyCompositeMost HeadroomStrategy = "composite-most" + HeadroomStrategyCompositeOnly HeadroomStrategy = "composite-only" + + // TTFT header string + TTFTSLOHeaderKey = "x-slo-ttft-ms" + // TPOT header string + TPOTSLOHeaderKey = "x-slo-tpot-ms" +) + +const ( + SLOAwareRouterPluginType = "slo-aware-routing" + eps = 1e-9 + Wmax = 100 + minWeight = 1 +) + +type PodSelectionMode string + +const ( + PodSelectionLinear PodSelectionMode = "linear" // weighted-random (current behavior) + PodSelectionMax PodSelectionMode = "max" // pick argmax weight +) diff --git a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go new file mode 100644 index 0000000000..900335c9ef --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go @@ -0,0 +1,154 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package profile
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strconv"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+	SLOAwareProfileHandlerType = "slo-aware-profile-handler"
+	DefaultProfileName         = "default"
+	PrefixProfileName          = "prefix"
+	SLOProfileName             = "slo"
+
+	// Header key for the boolean that enables prediction-based scheduling
+	PredictionBasedSchedulingHeaderKey = "x-prediction-based-scheduling"
+)
+
+// compile-time type assertion
+var _ framework.ProfileHandler = &SLOAwareProfileHandler{}
+
+// SLOAwareProfileHandlerFactory defines the factory function for SLOAwareProfileHandler.
+func SLOAwareProfileHandlerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	return NewSLOAwareProfileHandler().WithName(name), nil
+}
+
+// NewSLOAwareProfileHandler initializes a new SLOAwareProfileHandler and returns its pointer.
+func NewSLOAwareProfileHandler() *SLOAwareProfileHandler {
+	return &SLOAwareProfileHandler{
+		typedName: plugins.TypedName{Type: SLOAwareProfileHandlerType, Name: SLOAwareProfileHandlerType},
+	}
+}
+
+// SLOAwareProfileHandler runs the prefix profile first and then the remaining profiles.
+// When the request enables prediction-based scheduling, it uses the SLO profile result to select
+// the destination pod. Otherwise, it uses the default profile result.
+type SLOAwareProfileHandler struct {
+	typedName     plugins.TypedName
+	prefixProfile string // the profile that should be executed first
+}
+
+// TypedName returns the type and name tuple of this plugin instance.
+func (h *SLOAwareProfileHandler) TypedName() plugins.TypedName {
+	return h.typedName
+}
+
+// WithName sets the name of the profile handler.
+func (h *SLOAwareProfileHandler) WithName(name string) *SLOAwareProfileHandler {
+	h.typedName.Name = name
+	return h
+}
+
+// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into
+// consideration the request properties and the previously executed cycles along with their results.
+func (h *SLOAwareProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
+	profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+	if len(profiles) == len(profileResults) { // all profiles have been executed already in a previous call
+		return map[string]*framework.SchedulerProfile{}
+	}
+
+	if _, executed := profileResults[PrefixProfileName]; !executed {
+		// if the prefix profile was not executed yet, let the scheduler run it first
+		return map[string]*framework.SchedulerProfile{
+			PrefixProfileName: profiles[PrefixProfileName],
+		}
+	}
+	// otherwise, the prefix profile was already executed.
+
+	// return all profiles except prefix.
+	profilesToRun := make(map[string]*framework.SchedulerProfile)
+	for name, profile := range profiles {
+		if name != PrefixProfileName {
+			profilesToRun[name] = profile
+		}
+	}
+	return profilesToRun
+}
+
+// ProcessResults handles the outcome of the profile runs after all profiles ran.
+// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
+// key of the primary profile that should be used to get the request's selected destination.
+// When a profile run fails, its result in the profileResults map is nil.
+func (h *SLOAwareProfileHandler) ProcessResults(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+	if len(profileResults) < 2 {
+		return nil, errors.New("SLOAwareProfileHandler requires at least two profiles to operate")
+	}
+
+	predictorBasedScheduling, err := parseBoolHeader(*request, PredictionBasedSchedulingHeaderKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to choose scheduling profile: header %q must be a bool: %v", PredictionBasedSchedulingHeaderKey, err)
+	}
+
+	if predictorBasedScheduling { // TODO grab header directly from request.Headers instead of request field
+		if profileResults[SLOProfileName] == nil { // there was an error while running the SLO profile
+			return nil, fmt.Errorf("failed to run scheduler profile '%s'", SLOProfileName)
+		}
+		return &types.SchedulingResult{
+			ProfileResults:     profileResults,
+			PrimaryProfileName: SLOProfileName,
+		}, nil
+	}
+
+	if profileResults[DefaultProfileName] == nil { // there was an error while running the default profile
+		return nil, fmt.Errorf("failed to run scheduler profile '%s'", DefaultProfileName)
+	}
+
+	return &types.SchedulingResult{
+		ProfileResults:     profileResults,
+		PrimaryProfileName: DefaultProfileName,
+	}, nil
+}
+
+// parseBoolHeader retrieves a header by name, parses it as a bool,
+// and returns the value, or an error if the header is present but invalid.
+func parseBoolHeader(request types.LLMRequest, headerName string) (bool, error) {
+	// 1. Get the header value from the map
+	headerValue, ok := request.Headers[headerName]
+	if !ok {
+		return false, nil // Header not found; treat as false
+	}
+
+	// 2. Parse the header value as a bool
+	parsedBool, err := strconv.ParseBool(headerValue)
+	if err != nil {
+		return false, fmt.Errorf("header %q must be a bool: %v", headerName, err)
+	}
+
+	// 3.
Return the successfully parsed value + return parsedBool, nil +} From 21a0b9f3ec8798e292985b79e49e39b317741ff4 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Tue, 11 Nov 2025 20:25:38 +0000 Subject: [PATCH 2/6] Add metrics required for plugins to compile --- pkg/epp/metrics/metrics.go | 366 ++++++++++++++++++++++++++++++++ pkg/epp/metrics/metrics_test.go | 2 + 2 files changed, 368 insertions(+) diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 59c8976cd8..af422f2b5c 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -63,6 +63,193 @@ var ( []string{"model_name", "target_model_name", "error_code"}, ) + requestTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3, + 4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTTFTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_ttft_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // New metrics for TTFT prediction duration + requestTTFTPredictionDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_prediction_duration_seconds", + Help: metricsutil.HelpMsgWithStability("Duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTPredictionDurationGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_prediction_duration_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + 
Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0005, 0.00205, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + requestPredictedTPOT = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_tpot_seconds", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0005, 0.00205, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.125, 0.15, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.8, 1, 1.5, 2, 3, 4.5, 6, 12, 18, 24, 30, 36, 48, 60, 90, 120, 180, 270, 360, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestPredictedTPOTGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_predicted_tpot_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // New metrics for TPOT prediction duration + requestTPOTPredictionDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_prediction_duration_seconds", + Help: metricsutil.HelpMsgWithStability("Duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, + }, + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTPredictionDurationGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_prediction_duration_seconds_gauge", + Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO Violation Metrics + requestTTFTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TTFT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTTFTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TTFT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + 
requestTPOTSLOViolation = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation", + Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TPOT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOViolationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of TPOT SLO violations for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + // SLO threshold gauges (for dynamic threshold management) + requestTTFTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_ttft_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TTFT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + + requestTPOTSLOThreshold = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: InferenceObjectiveComponent, + Name: "request_tpot_slo_threshold_seconds", + Help: metricsutil.HelpMsgWithStability("Current TPOT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), + }, + []string{"model_name", "target_model_name"}, + ) + requestLatencies = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -282,6 +469,32 @@ var registerMetrics sync.Once // Register all metrics. func Register(customCollectors ...prometheus.Collector) { registerMetrics.Do(func() { + metrics.Registry.MustRegister(requestTPOT) + metrics.Registry.MustRegister(requestTTFT) + + metrics.Registry.MustRegister(requestTPOTGauge) + metrics.Registry.MustRegister(requestTTFTGauge) + + metrics.Registry.MustRegister(requestPredictedTPOT) + metrics.Registry.MustRegister(requestPredictedTTFT) + + metrics.Registry.MustRegister(requestPredictedTPOTGauge) + metrics.Registry.MustRegister(requestPredictedTTFTGauge) + + // Register new prediction duration metrics + metrics.Registry.MustRegister(requestTPOTPredictionDuration) + metrics.Registry.MustRegister(requestTPOTPredictionDurationGauge) + metrics.Registry.MustRegister(requestTTFTPredictionDuration) + metrics.Registry.MustRegister(requestTTFTPredictionDurationGauge) + + // Register SLO violation metrics + metrics.Registry.MustRegister(requestTTFTSLOViolation) + metrics.Registry.MustRegister(requestTTFTSLOViolationCounter) + metrics.Registry.MustRegister(requestTPOTSLOViolation) + metrics.Registry.MustRegister(requestTPOTSLOViolationCounter) + metrics.Registry.MustRegister(requestTTFTSLOThreshold) + metrics.Registry.MustRegister(requestTPOTSLOThreshold) + metrics.Registry.MustRegister(requestCounter) metrics.Registry.MustRegister(requestErrCounter) metrics.Registry.MustRegister(requestLatencies) @@ -332,6 +545,30 @@ func Reset() { PrefixCacheHitLength.Reset() flowControlRequestQueueDuration.Reset() flowControlQueueSize.Reset() + + requestTPOT.Reset() + requestTTFT.Reset() + requestTPOTGauge.Reset() + requestTTFTGauge.Reset() + + requestPredictedTPOT.Reset() + requestPredictedTTFT.Reset() + requestPredictedTPOTGauge.Reset() + requestPredictedTTFTGauge.Reset() + + // Reset new prediction duration metrics + requestTPOTPredictionDuration.Reset() + 
requestTPOTPredictionDurationGauge.Reset()
+	requestTTFTPredictionDuration.Reset()
+	requestTTFTPredictionDurationGauge.Reset()
+
+	// Reset SLO violation metrics
+	requestTTFTSLOViolation.Reset()
+	requestTTFTSLOViolationCounter.Reset()
+	requestTPOTSLOViolation.Reset()
+	requestTPOTSLOViolationCounter.Reset()
+	requestTTFTSLOThreshold.Reset()
+	requestTPOTSLOThreshold.Reset()
 }
 
 // RecordRequstCounter records the number of requests.
@@ -363,6 +600,123 @@ func RecordRequestLatencies(ctx context.Context, modelName, targetModelName stri
 	return true
 }
 
+func RecordRequestTPOT(ctx context.Context, modelName, targetModelName string, tpot float64) bool {
+	if tpot < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot)
+		return false
+	}
+	requestTPOT.WithLabelValues(modelName, targetModelName).Observe(tpot)
+	requestTPOTGauge.WithLabelValues(modelName, targetModelName).Set(tpot)
+	return true
+}
+
+// RecordRequestTPOTWithSLO records TPOT and checks for SLO violation.
+// If tpot exceeds the threshold, it records a violation (sets gauge to 1 and increments counter).
+// If tpot is within limits, it sets gauge to 0.
+func RecordRequestTPOTWithSLO(ctx context.Context, modelName, targetModelName string, tpot float64, sloThreshold float64) bool {
+	if tpot < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot)
+		return false
+	}
+
+	// Check for SLO violation (tpot exceeds threshold)
+	if tpot > sloThreshold {
+		requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+		requestTPOTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TPOT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", tpot, "threshold", sloThreshold)
+	} else {
+		requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTPOT records the predicted TPOT of a request in seconds.
+func RecordRequestPredictedTPOT(ctx context.Context, modelName, targetModelName string, predicted_tpot float64) bool {
+	if predicted_tpot < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TPOT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "tpot", predicted_tpot)
+		return false
+	}
+	requestPredictedTPOT.WithLabelValues(modelName, targetModelName).Observe(predicted_tpot)
+	requestPredictedTPOTGauge.WithLabelValues(modelName, targetModelName).Set(predicted_tpot)
+	return true
+}
+
+// RecordRequestTPOTPredictionDuration records the duration taken to generate TPOT predictions.
+func RecordRequestTPOTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TPOT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTPOTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	requestTPOTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+	return true
+}
+
+// RecordRequestTTFT records the time to first token (TTFT) of a request in seconds.
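+// The value is observed into the request_ttft_seconds histogram and also written
+// to a companion gauge that tracks the most recent observation.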
+func RecordRequestTTFT(ctx context.Context, modelName, targetModelName string, ttft float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+	requestTTFT.WithLabelValues(modelName, targetModelName).Observe(ttft)
+	requestTTFTGauge.WithLabelValues(modelName, targetModelName).Set(ttft)
+	return true
+}
+
+// RecordRequestTTFTWithSLO records TTFT and checks for SLO violation.
+// If ttft exceeds the threshold, it records a violation (sets gauge to 1 and increments counter).
+// If ttft is within limits, it sets gauge to 0.
+func RecordRequestTTFTWithSLO(ctx context.Context, modelName, targetModelName string, ttft float64, sloThreshold float64) bool {
+	if ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft)
+		return false
+	}
+
+	// Check for SLO violation (ttft exceeds threshold)
+	if ttft > sloThreshold {
+		requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1)
+		requestTTFTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc()
+		log.FromContext(ctx).V(logutil.DEFAULT).Info("TTFT SLO violation detected",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", ttft, "threshold", sloThreshold)
+	} else {
+		requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0)
+	}
+
+	return true
+}
+
+// RecordRequestPredictedTTFT records the predicted TTFT of a request in seconds.
+func RecordRequestPredictedTTFT(ctx context.Context, modelName, targetModelName string, predicted_ttft float64) bool {
+	if predicted_ttft < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Predicted TTFT value must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "ttft", predicted_ttft)
+		return false
+	}
+	requestPredictedTTFT.WithLabelValues(modelName, targetModelName).Observe(predicted_ttft)
+	requestPredictedTTFTGauge.WithLabelValues(modelName, targetModelName).Set(predicted_ttft)
+	return true
+}
+
+// RecordRequestTTFTPredictionDuration records the duration taken to generate TTFT predictions.
+func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
+	if duration < 0 {
+		log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "TTFT prediction duration must be non-negative",
+			"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
+		return false
+	}
+	requestTTFTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
+	requestTTFTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration)
+	return true
+}
+
 // RecordResponseSizes records the response sizes.
 func RecordResponseSizes(modelName, targetModelName string, size int) {
 	responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))
@@ -480,3 +834,15 @@ func IncFlowControlQueueSize(fairnessID, priority string) {
 func DecFlowControlQueueSize(fairnessID, priority string) {
 	flowControlQueueSize.WithLabelValues(fairnessID, priority).Dec()
 }
+
+// SetTTFTSLOThreshold sets the TTFT SLO threshold for a model.
+// This allows dynamic threshold management and makes the threshold visible in metrics.
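+//
+// A minimal, hypothetical call site (the model names here are placeholders, not
+// part of this patch):
+//
+//	SetTTFTSLOThreshold("my-model", "my-model-v1", 0.5) // expose a 500 ms TTFT SLO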
+func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) { + requestTTFTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold) +} + +// SetTPOTSLOThreshold sets the TPOT SLO threshold for a model. +// This allows dynamic threshold management and makes the threshold visible in metrics. +func SetTPOTSLOThreshold(modelName, targetModelName string, threshold float64) { + requestTPOTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold) +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index 7d41681830..754d6d2947 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -46,6 +46,8 @@ const ( KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization" QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size" PerPodQueueSizeMetrics = InferencePoolComponent + "_per_pod_queue_size" + RequestTTFTSecondsMetric = InferenceObjectiveComponent + "_request_ttft_seconds" + RequestTPOTSecondsMetric = InferenceObjectiveComponent + "_request_tpot_seconds" ) func TestMain(m *testing.M) { From dabecb9e9c7c4a01f1a8cec1855d252ac650a113 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Thu, 13 Nov 2025 23:56:22 +0000 Subject: [PATCH 3/6] Small scorer changes --- .../framework/plugins/multi/slo_aware_router/scorer.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index b476579b5f..1aa2d9fd6a 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -82,7 +82,11 @@ func (s *SLOAwareRouter) epsilonGreedyAffinityGate( prefixStickyThreshold float64, ) ([]PodPredictionResult, bool) { logger := log.FromContext(ctx) - + if prefixStickyThreshold <= 0 { + // Affinity gating disabled + logger.V(logutil.DEBUG).Info("Affinity gating disabled (threshold <= 0)", "path", label) + return candidates, false + } eligible := make([]PodPredictionResult, 0, len(candidates)) for _, p := range candidates { if p.PrefixCacheScore >= prefixStickyThreshold { @@ -301,7 +305,7 @@ func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleSta if err != nil { // The prefix cache plugin might not be enabled, which is a valid scenario. 
-		log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
+		log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String(), "error", err)
 		return 0.0
 	}
 
From 0d235b96adba591af9b7d82276219fa07014f01a Mon Sep 17 00:00:00 2001
From: BenjaminBraunDev
Date: Tue, 18 Nov 2025 01:05:24 +0000
Subject: [PATCH 4/6] Unexport fields not used outside package, consolidate
 gauge and counter metrics for prediction

---
 pkg/epp/metrics/metrics.go                    | 211 +++----------
 .../plugins/multi/slo_aware_router/config.go  |  22 +-
 .../plugins/multi/slo_aware_router/helpers.go |  32 +-
 .../latencypredictor_helper.go                | 144 ++++-----
 .../multi/slo_aware_router/prediction.go      |  42 +--
 .../slo_aware_router/requestcontrol_hooks.go  | 134 ++++----
 .../requestcontrol_hooks_test.go              | 292 +++++++++---------
 .../slo_aware_router/running_request_queue.go |  90 +++---
 .../running_request_queue_test.go             |  54 ++--
 .../plugins/multi/slo_aware_router/sampler.go |  42 +--
 .../plugins/multi/slo_aware_router/scorer.go  |  56 ++--
 .../multi/slo_aware_router/scorer_test.go     |  59 ++--
 .../multi/slo_aware_router/selection.go       |  58 ++--
 .../plugins/multi/slo_aware_router/types.go   |  34 +-
 14 files changed, 569 insertions(+), 701 deletions(-)

diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go
index af422f2b5c..24743ccd38 100644
--- a/pkg/epp/metrics/metrics.go
+++ b/pkg/epp/metrics/metrics.go
@@ -63,6 +63,16 @@ var (
 		[]string{"model_name", "target_model_name", "error_code"},
 	)
 
+	// Gauge for various inference request metrics
+	inferenceGauges = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Subsystem: InferenceObjectiveComponent,
+			Name:      "inference_request_metric",
+			Help:      metricsutil.HelpMsgWithStability("Consolidated gauge for various inference request metrics including TTFT, TPOT, SLOs, and prediction durations.", compbasemetrics.ALPHA),
+		},
+		[]string{"model_name", "target_model_name", "type"},
+	)
+
 	requestTTFT = prometheus.NewHistogramVec(
 		prometheus.HistogramOpts{
 			Subsystem: InferenceObjectiveComponent,
@@ -76,15 +86,6 @@ var (
 		[]string{"model_name", "target_model_name"},
 	)
 
-	requestTTFTGauge = prometheus.NewGaugeVec(
-		prometheus.GaugeOpts{
-			Subsystem: InferenceObjectiveComponent,
-			Name:      "request_ttft_seconds_gauge",
-			Help:      metricsutil.HelpMsgWithStability("Inference model TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
-		},
-		[]string{"model_name", "target_model_name"},
-	)
-
 	requestPredictedTTFT = prometheus.NewHistogramVec(
 		prometheus.HistogramOpts{
 			Subsystem: InferenceObjectiveComponent,
@@ -98,15 +99,6 @@ var (
 		[]string{"model_name", "target_model_name"},
 	)
 
-	requestPredictedTTFTGauge = prometheus.NewGaugeVec(
-		prometheus.GaugeOpts{
-			Subsystem: InferenceObjectiveComponent,
-			Name:      "request_predicted_ttft_seconds_gauge",
-			Help:      metricsutil.HelpMsgWithStability("Inference model Predicted TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
-		},
-		[]string{"model_name", "target_model_name"},
-	)
-
 	// New metrics for TTFT prediction duration
 	requestTTFTPredictionDuration = prometheus.NewHistogramVec(
@@ -120,15 +112,6 @@ var (
 		[]string{"model_name", "target_model_name"},
 	)
 
-	requestTTFTPredictionDurationGauge = prometheus.NewGaugeVec(
-		prometheus.GaugeOpts{
-			Subsystem: InferenceObjectiveComponent,
-			Name:
"request_ttft_prediction_duration_seconds_gauge", - Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - requestTPOT = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -142,14 +125,6 @@ var ( []string{"model_name", "target_model_name"}, ) - requestTPOTGauge = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_tpot_seconds_gauge", - Help: metricsutil.HelpMsgWithStability("Inference model TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) requestPredictedTPOT = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Subsystem: InferenceObjectiveComponent, @@ -163,15 +138,6 @@ var ( []string{"model_name", "target_model_name"}, ) - requestPredictedTPOTGauge = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_predicted_tpot_seconds_gauge", - Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - // New metrics for TPOT prediction duration requestTPOTPredictionDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -185,69 +151,14 @@ var ( []string{"model_name", "target_model_name"}, ) - requestTPOTPredictionDurationGauge = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_tpot_prediction_duration_seconds_gauge", - Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - - // SLO Violation Metrics - requestTTFTSLOViolation = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_ttft_slo_violation", - Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TTFT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - - requestTTFTSLOViolationCounter = prometheus.NewCounterVec( + // Counter for SLO Violations + sloViolationCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Subsystem: InferenceObjectiveComponent, - Name: "request_ttft_slo_violation_total", - Help: metricsutil.HelpMsgWithStability("Counter of TTFT SLO violations for each model and target model.", compbasemetrics.ALPHA), + Name: "request_slo_violation_total", + Help: metricsutil.HelpMsgWithStability("Counter of SLO violations for each model, target model, and violation type.", compbasemetrics.ALPHA), }, - []string{"model_name", "target_model_name"}, - ) - - requestTPOTSLOViolation = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_tpot_slo_violation", - Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TPOT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - - requestTPOTSLOViolationCounter = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Subsystem: 
InferenceObjectiveComponent, - Name: "request_tpot_slo_violation_total", - Help: metricsutil.HelpMsgWithStability("Counter of TPOT SLO violations for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - - // SLO threshold gauges (for dynamic threshold management) - requestTTFTSLOThreshold = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_ttft_slo_threshold_seconds", - Help: metricsutil.HelpMsgWithStability("Current TTFT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, - ) - - requestTPOTSLOThreshold = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Subsystem: InferenceObjectiveComponent, - Name: "request_tpot_slo_threshold_seconds", - Help: metricsutil.HelpMsgWithStability("Current TPOT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA), - }, - []string{"model_name", "target_model_name"}, + []string{"model_name", "target_model_name", "type"}, ) requestLatencies = prometheus.NewHistogramVec( @@ -469,32 +380,21 @@ var registerMetrics sync.Once // Register all metrics. func Register(customCollectors ...prometheus.Collector) { registerMetrics.Do(func() { + // Register inference gauges + metrics.Registry.MustRegister(inferenceGauges) + + // Register Histograms metrics.Registry.MustRegister(requestTPOT) metrics.Registry.MustRegister(requestTTFT) - - metrics.Registry.MustRegister(requestTPOTGauge) - metrics.Registry.MustRegister(requestTTFTGauge) - metrics.Registry.MustRegister(requestPredictedTPOT) metrics.Registry.MustRegister(requestPredictedTTFT) - - metrics.Registry.MustRegister(requestPredictedTPOTGauge) - metrics.Registry.MustRegister(requestPredictedTTFTGauge) - - // Register new prediction duration metrics metrics.Registry.MustRegister(requestTPOTPredictionDuration) - metrics.Registry.MustRegister(requestTPOTPredictionDurationGauge) metrics.Registry.MustRegister(requestTTFTPredictionDuration) - metrics.Registry.MustRegister(requestTTFTPredictionDurationGauge) - // Register SLO violation metrics - metrics.Registry.MustRegister(requestTTFTSLOViolation) - metrics.Registry.MustRegister(requestTTFTSLOViolationCounter) - metrics.Registry.MustRegister(requestTPOTSLOViolation) - metrics.Registry.MustRegister(requestTPOTSLOViolationCounter) - metrics.Registry.MustRegister(requestTTFTSLOThreshold) - metrics.Registry.MustRegister(requestTPOTSLOThreshold) + // Register SLO violation counters + metrics.Registry.MustRegister(sloViolationCounter) + // Register other metrics metrics.Registry.MustRegister(requestCounter) metrics.Registry.MustRegister(requestErrCounter) metrics.Registry.MustRegister(requestLatencies) @@ -524,6 +424,21 @@ func Register(customCollectors ...prometheus.Collector) { // Just for integration test func Reset() { + // Reset inference gauges + inferenceGauges.Reset() + + // Reset Histograms + requestTPOT.Reset() + requestTTFT.Reset() + requestPredictedTPOT.Reset() + requestPredictedTTFT.Reset() + requestTPOTPredictionDuration.Reset() + requestTTFTPredictionDuration.Reset() + + // Reset SLO violation counter + sloViolationCounter.Reset() + + // Reset other metrics requestCounter.Reset() requestErrCounter.Reset() requestLatencies.Reset() @@ -545,30 +460,6 @@ func Reset() { PrefixCacheHitLength.Reset() flowControlRequestQueueDuration.Reset() flowControlQueueSize.Reset() - - requestTPOT.Reset() - requestTTFT.Reset() - requestTPOTGauge.Reset() - 
requestTTFTGauge.Reset() - - requestPredictedTPOT.Reset() - requestPredictedTTFT.Reset() - requestPredictedTPOTGauge.Reset() - requestPredictedTTFTGauge.Reset() - - // Reset new prediction duration metrics - requestTPOTPredictionDuration.Reset() - requestTPOTPredictionDurationGauge.Reset() - requestTTFTPredictionDuration.Reset() - requestTTFTPredictionDurationGauge.Reset() - - // Reset SLO violation metrics - requestTTFTSLOViolation.Reset() - requestTTFTSLOViolationCounter.Reset() - requestTPOTSLOViolation.Reset() - requestTPOTSLOViolationCounter.Reset() - requestTTFTSLOThreshold.Reset() - requestTPOTSLOThreshold.Reset() } // RecordRequstCounter records the number of requests. @@ -607,7 +498,7 @@ func RecordRequestTPOT(ctx context.Context, modelName, targetModelName string, t return false } requestTPOT.WithLabelValues(modelName, targetModelName).Observe(tpot) - requestTPOTGauge.WithLabelValues(modelName, targetModelName).Set(tpot) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot"}).Set(tpot) return true } @@ -623,12 +514,12 @@ func RecordRequestTPOTWithSLO(ctx context.Context, modelName, targetModelName st // Check for SLO violation (tpot exceeds threshold) if tpot > sloThreshold { - requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1) - requestTPOTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc() + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_violation"}).Set(1) + sloViolationCounter.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot"}).Inc() log.FromContext(ctx).V(logutil.DEFAULT).Info("TPOT SLO violation detected", "modelName", modelName, "targetModelName", targetModelName, "tpot", tpot, "threshold", sloThreshold) } else { - requestTPOTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_violation"}).Set(0) } return true @@ -642,7 +533,7 @@ func RecordRequestPredictedTPOT(ctx context.Context, modelName, targetModelName return false } requestPredictedTPOT.WithLabelValues(modelName, targetModelName).Observe(predicted_tpot) - requestPredictedTPOTGauge.WithLabelValues(modelName, targetModelName).Set(predicted_tpot) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "predicted_tpot"}).Set(predicted_tpot) return true } @@ -654,7 +545,7 @@ func RecordRequestTPOTPredictionDuration(ctx context.Context, modelName, targetM return false } requestTPOTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration) - requestTPOTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_prediction_duration"}).Set(duration) return true } @@ -666,7 +557,7 @@ func RecordRequestTTFT(ctx context.Context, modelName, targetModelName string, t return false } requestTTFT.WithLabelValues(modelName, targetModelName).Observe(ttft) - requestTTFTGauge.WithLabelValues(modelName, targetModelName).Set(ttft) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft"}).Set(ttft) return true } @@ -682,12 +573,12 @@ func RecordRequestTTFTWithSLO(ctx context.Context, modelName, 
targetModelName st // Check for SLO violation (ttft exceeds threshold) if ttft > sloThreshold { - requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(1) - requestTTFTSLOViolationCounter.WithLabelValues(modelName, targetModelName).Inc() + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_violation"}).Set(1) + sloViolationCounter.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft"}).Inc() log.FromContext(ctx).V(logutil.DEFAULT).Info("TTFT SLO violation detected", "modelName", modelName, "targetModelName", targetModelName, "ttft", ttft, "threshold", sloThreshold) } else { - requestTTFTSLOViolation.WithLabelValues(modelName, targetModelName).Set(0) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_violation"}).Set(0) } return true @@ -701,7 +592,7 @@ func RecordRequestPredictedTTFT(ctx context.Context, modelName, targetModelName return false } requestPredictedTTFT.WithLabelValues(modelName, targetModelName).Observe(predicted_ttft) - requestPredictedTTFTGauge.WithLabelValues(modelName, targetModelName).Set(predicted_ttft) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "predicted_ttft"}).Set(predicted_ttft) return true } @@ -713,7 +604,7 @@ func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetM return false } requestTTFTPredictionDuration.WithLabelValues(modelName, targetModelName).Observe(duration) - requestTTFTPredictionDurationGauge.WithLabelValues(modelName, targetModelName).Set(duration) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_prediction_duration"}).Set(duration) return true } @@ -838,11 +729,11 @@ func DecFlowControlQueueSize(fairnessID, priority string) { // SetTTFTSLOThreshold sets the TTFT SLO threshold for a model. // This allows dynamic threshold management and makes the threshold visible in metrics. func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) { - requestTTFTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "ttft_slo_threshold"}).Set(threshold) } // SetTPOTSLOThreshold sets the TPOT SLO threshold for a model. // This allows dynamic threshold management and makes the threshold visible in metrics. 
func SetTPOTSLOThreshold(modelName, targetModelName string, threshold float64) { - requestTPOTSLOThreshold.WithLabelValues(modelName, targetModelName).Set(threshold) + inferenceGauges.With(prometheus.Labels{"model_name": modelName, "target_model_name": targetModelName, "type": "tpot_slo_threshold"}).Set(threshold) } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go index fcb4b72236..bbfd772232 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/config.go @@ -86,22 +86,22 @@ var HeadroomTPOTWeight = func() float64 { return 0.2 // default }() -var HeadroomSelectionStrategy = func() HeadroomStrategy { +var HeadroomSelectionStrategy = func() headroomStrategy { if value, exists := os.LookupEnv("HEADROOM_SELECTION_STRATEGY"); exists { switch strings.ToLower(value) { case "least": - return HeadroomStrategyLeast + return headroomStrategyLeast case "most": - return HeadroomStrategyMost + return headroomStrategyMost case "composite-least": - return HeadroomStrategyCompositeLeast + return headroomStrategyCompositeLeast case "composite-most": - return HeadroomStrategyCompositeMost + return headroomStrategyCompositeMost case "composite-only": - return HeadroomStrategyCompositeOnly + return headroomStrategyCompositeOnly } } - return HeadroomStrategyLeast // default to least (better packing) + return headroomStrategyLeast // default to least (better packing) }() // If using composite headroom, weights for each component. Not used by default @@ -176,16 +176,16 @@ var AffinityGateTauGlobal = func() float64 { }() // Read once at init. Values: "linear" (default) or "max". 
-var SelectionMode = func() PodSelectionMode { +var SelectionMode = func() podSelectionMode { if v, ok := os.LookupEnv("POD_SELECTION_MODE"); ok { switch strings.ToLower(v) { case "max": - return PodSelectionMax + return podSelectionMax case "linear": fallthrough default: - return PodSelectionLinear + return podSelectionLinear } } - return PodSelectionLinear + return podSelectionLinear }() diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go index 1d55682435..3b5820610e 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go @@ -26,35 +26,35 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []PodPredictionResult, r *rand.Rand, strategy HeadroomStrategy) schedulingtypes.Pod { +func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []podPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Pod { total := 0 choices := s.buildCompositeChoices( ctx, allPreds, CompositeKVWeight, CompositeQueueWeight, CompositePrefixWeight, &total, ) - if strategy == HeadroomStrategyCompositeLeast { + if strategy == headroomStrategyCompositeLeast { // Invert weights for "least" strategy for i := range choices { - choices[i].Weight = minWeight + Wmax - choices[i].Weight + choices[i].weight = minWeight + wMax - choices[i].weight } } selectedPod := s.performWeightedRandomSelection(choices, total, allPreds, r) return selectedPod } -func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice, total int, candidates []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { +func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { if total == 0 { return nil } logger := log.FromContext(context.Background()) // Check if MAX_SCORE_SELECTION env variable is set - if SelectionMode == PodSelectionMax { + if SelectionMode == podSelectionMax { logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight") maxWeight := 0 var selectedPod schedulingtypes.Pod for _, c := range weightedChoices { - if c.Weight > maxWeight { - maxWeight = c.Weight - selectedPod = c.PodName + if c.weight > maxWeight { + maxWeight = c.weight + selectedPod = c.podName } } if selectedPod != nil { @@ -70,11 +70,11 @@ func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice var selectedPod schedulingtypes.Pod for _, c := range weightedChoices { - if idx < c.Weight { - selectedPod = c.PodName + if idx < c.weight { + selectedPod = c.podName break } - idx -= c.Weight + idx -= c.weight } // If no pod was selected (shouldn't happen), fallback to first pod @@ -86,10 +86,10 @@ func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []Choice } func (s *SLOAwareRouter) buildCompositeChoices( ctx context.Context, - candidates []PodPredictionResult, + candidates []podPredictionResult, wkv, wq, wpref float64, total *int, -) []Choice { +) []choice { // Normalize weights sumw := wkv + wq + wpref @@ -116,7 +116,7 @@ func (s *SLOAwareRouter) buildCompositeChoices( } den := float64(maxQ - minQ) - choices := make([]Choice, 0, len(candidates)) + choices := make([]choice, 0, len(candidates)) for _, p := range 
candidates { q := queueCounts[p.Pod.GetPod().String()] relQueue := 1.0 @@ -129,9 +129,9 @@ func (s *SLOAwareRouter) buildCompositeChoices( prefix := (p.PrefixCacheScore) composite := wkv*kvFree + wq*relQueue + wpref*prefix - w := int(math.Round(float64(minWeight) + (float64(Wmax-minWeight) * composite))) + w := int(math.Round(float64(minWeight) + (float64(wMax-minWeight) * composite))) *total += w - choices = append(choices, Choice{PodName: p.Pod, Weight: w}) + choices = append(choices, choice{podName: p.Pod, weight: w}) log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score", "pod", p.Pod.GetPod().String(), diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go index aa47f93c9d..1a50847bed 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -32,13 +32,13 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) -// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. -func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) { - if sr := sloCtx.SchedulingResult; sr != nil { +// refreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. +func refreshLastSeenMetrics(ctx context.Context, sloCtx *sloRequestContext) { + if sr := sloCtx.schedulingResult; sr != nil { if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil { for profileName, profileResult := range sr.ProfileResults { if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 { - sloCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone() + sloCtx.lastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone() } } } @@ -48,13 +48,13 @@ func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) { } // GetMetricsForPrediction retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics. -func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext) (*backendmetrics.MetricsState, error) { - if len(sloCtx.LastSeenMetrics) == 0 { +func getLatestMetricsForProfile(ctx context.Context, sloCtx *sloRequestContext) (*backendmetrics.MetricsState, error) { + if len(sloCtx.lastSeenMetrics) == 0 { return nil, fmt.Errorf("no last seen metrics available for prediction") } - primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName - if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists { + primaryProfileName := sloCtx.schedulingResult.PrimaryProfileName + if metrics, exists := sloCtx.lastSeenMetrics[primaryProfileName]; exists { return metrics, nil } @@ -62,10 +62,10 @@ func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext) } // ProcessHeader refreshes metrics, applies TTFT prediction, updates sloCtx.PredictedTTFT and timestamp. 
-func ProcessHeaderForLatencyPrediction( +func processHeaderForLatencyPrediction( ctx context.Context, predictor latencypredictor.PredictorInterface, - sloCtx *SLORequestContext, + sloCtx *sloRequestContext, ) error { logger := log.FromContext(ctx) @@ -73,18 +73,18 @@ func ProcessHeaderForLatencyPrediction( //print the raw scores in scheduling result // Build prediction request - m, err := GetLatestMetricsForProfile(ctx, sloCtx) + m, err := getLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) return err } - targetPod := sloCtx.TargetPod - prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()] + targetPod := sloCtx.targetPod + prefix_cache_score := sloCtx.prefixCacheScoresForPods[targetPod.String()] in := latencypredictor.PredictionRequest{ KVCachePercentage: m.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), NumRequestWaiting: m.WaitingQueueSize, NumRequestRunning: m.RunningQueueSize, NumTokensGenerated: 0, @@ -97,55 +97,55 @@ func ProcessHeaderForLatencyPrediction( dur := time.Since(start) if err != nil { logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTTFT = 0 + sloCtx.predictedTTFT = 0 } else if p == nil { logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTTFT = 0 + sloCtx.predictedTTFT = 0 } else { logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds()) - metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) - sloCtx.PredictedTTFT = p.TTFT + sloCtx.predictedTTFT = p.TTFT } // Advance timestamp for first token reference - sloCtx.LastTokenTimestamp = time.Now() - RefreshLastSeenMetrics(ctx, sloCtx) + sloCtx.lastTokenTimestamp = time.Now() + refreshLastSeenMetrics(ctx, sloCtx) return err } // ProcessFirstToken records actual TTFT, trains, predicts first TPOT, updates sloCtx, and advances timestamp. 
-func ProcessFirstTokenForLatencyPrediction( +func processFirstTokenForLatencyPrediction( ctx context.Context, predictor latencypredictor.PredictorInterface, - sloCtx *SLORequestContext, + sloCtx *sloRequestContext, now time.Time, ) { logger := log.FromContext(ctx) // Initialize sampler - if sloCtx.TokenSampler == nil { - requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] - sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) - logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) + if sloCtx.tokenSampler == nil { + requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) } // Actual TTFT - sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds()) - sloCtx.GeneratedTokenCount = 1 - m, err := GetLatestMetricsForProfile(ctx, sloCtx) + sloCtx.ttft = float64(now.Sub(sloCtx.requestReceivedTimestamp).Milliseconds()) + sloCtx.generatedTokenCount = 1 + m, err := getLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) return } - targetPod := sloCtx.TargetPod - prefix_cache_score := sloCtx.PrefixCacheScoresForPods[targetPod.String()] + targetPod := sloCtx.targetPod + prefix_cache_score := sloCtx.prefixCacheScoresForPods[targetPod.String()] // Train TTFT entry := latencypredictor.TrainingEntry{ KVCachePercentage: m.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), - ActualTTFT: sloCtx.TTFT, + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), + ActualTTFT: sloCtx.ttft, ActualTPOT: 0, Timestamp: now, NumRequestWaiting: m.WaitingQueueSize, @@ -156,7 +156,7 @@ func ProcessFirstTokenForLatencyPrediction( if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") } - m, err = GetLatestMetricsForProfile(ctx, sloCtx) + m, err = getLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ -166,10 +166,10 @@ func ProcessFirstTokenForLatencyPrediction( // Predict first TPOT in := latencypredictor.PredictionRequest{ KVCachePercentage: m.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), NumRequestWaiting: m.WaitingQueueSize, NumRequestRunning: m.RunningQueueSize, - NumTokensGenerated: sloCtx.GeneratedTokenCount, + NumTokensGenerated: sloCtx.generatedTokenCount, PrefixCacheScore: 0, } start := time.Now() @@ -177,48 +177,48 @@ func ProcessFirstTokenForLatencyPrediction( dur := time.Since(start) if err != nil || p == nil { logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0) - sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations)) + 
sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, 0) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, 0, len(sloCtx.predictedTPOTObservations)) } else { logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT) - sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations)) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, p.TPOT) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, p.TPOT, len(sloCtx.predictedTPOTObservations)) } - metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) // Advance timestamp - sloCtx.LastTokenTimestamp = now + sloCtx.lastTokenTimestamp = now // Refresh metrics - RefreshLastSeenMetrics(ctx, sloCtx) + refreshLastSeenMetrics(ctx, sloCtx) } // ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. -func ProcessTokenForLatencyPrediction( +func processTokenForLatencyPrediction( ctx context.Context, predictor latencypredictor.PredictorInterface, - sloCtx *SLORequestContext, + sloCtx *sloRequestContext, now time.Time, ) { logger := log.FromContext(ctx) // Initialize sampler if not yet - if sloCtx.TokenSampler == nil { - requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey] - sloCtx.TokenSampler = NewTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) - logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken()) + if sloCtx.tokenSampler == nil { + requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) } // Inter-token latency - latencyMs := float64(now.Sub(sloCtx.LastTokenTimestamp).Milliseconds()) - sloCtx.GeneratedTokenCount++ + latencyMs := float64(now.Sub(sloCtx.lastTokenTimestamp).Milliseconds()) + sloCtx.generatedTokenCount++ //log the inter-token latency for predicted samples - if sloCtx.GeneratedTokenCount == 2 || sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) { //tricky logic, since next sample token is always +1 from current token - sloCtx.TPOTObservations = append(sloCtx.TPOTObservations, latencyMs) - sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations)) + if sloCtx.generatedTokenCount == 2 || sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { //tricky logic, since next sample token is always +1 from current token + sloCtx.tpotObservations = append(sloCtx.tpotObservations, latencyMs) + sloCtx.avgTPOT = calculateRunningAverage(sloCtx.avgTPOT, latencyMs, len(sloCtx.tpotObservations)) } - m, err := GetLatestMetricsForProfile(ctx, sloCtx) + m, err := getLatestMetricsForProfile(ctx, sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ 
-227,13 +227,13 @@ func ProcessTokenForLatencyPrediction( // Record actual TPOT entry := latencypredictor.TrainingEntry{ KVCachePercentage: m.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), ActualTTFT: 0, ActualTPOT: latencyMs, Timestamp: now, NumRequestWaiting: m.WaitingQueueSize, NumRequestRunning: m.RunningQueueSize, - NumTokensGenerated: sloCtx.GeneratedTokenCount - 1, + NumTokensGenerated: sloCtx.generatedTokenCount - 1, PrefixCacheScore: 0, // TPOT does not use prefix cache score } if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { @@ -241,13 +241,13 @@ func ProcessTokenForLatencyPrediction( } // Sampled predict - if sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) { + if sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { in := latencypredictor.PredictionRequest{ KVCachePercentage: m.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)), + InputTokenLength: len(strings.Fields(sloCtx.schedulingRequest.Body.Completions.Prompt)), NumRequestWaiting: m.WaitingQueueSize, NumRequestRunning: m.RunningQueueSize, - NumTokensGenerated: sloCtx.GeneratedTokenCount, + NumTokensGenerated: sloCtx.generatedTokenCount, PrefixCacheScore: 0, // TPOT does not use prefix cache score } start := time.Now() @@ -255,26 +255,26 @@ func ProcessTokenForLatencyPrediction( dur := time.Since(start) if err != nil || p == nil { logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0) - sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations)) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, 0) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, 0, len(sloCtx.predictedTPOTObservations)) } else { logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds()) - sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT) - sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations)) + sloCtx.predictedTPOTObservations = append(sloCtx.predictedTPOTObservations, p.TPOT) + sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, p.TPOT, len(sloCtx.predictedTPOTObservations)) } - metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds()) + metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) - sloCtx.TokenSampler.RecordPrediction(sloCtx.GeneratedTokenCount) + sloCtx.tokenSampler.recordPrediction(sloCtx.generatedTokenCount) } // Advance timestamp - sloCtx.LastTokenTimestamp = now + sloCtx.lastTokenTimestamp = now // Refresh metrics - RefreshLastSeenMetrics(ctx, sloCtx) + refreshLastSeenMetrics(ctx, sloCtx) } -// PredictWithMetrics predicts TTFT or TPOT based on provided metrics state and token count. -func PredictWithMetrics( +// predictWithMetrics predicts TTFT or TPOT based on provided metrics state and token count. 
+func predictWithMetrics( ctx context.Context, predictor latencypredictor.PredictorInterface, metricsState *backendmetrics.MetricsState, @@ -335,9 +335,9 @@ func PredictWithMetrics( return result, nil } -// BulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. +// bulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. // Returns predictions in the same order as the input slices. -func BulkPredictWithMetrics( +func bulkPredictWithMetrics( ctx context.Context, predictor latencypredictor.PredictorInterface, metricsStates []*backendmetrics.MetricsState, diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index 0c2cfa0a95..eaaf56608d 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -26,7 +26,7 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) -type PodPredictionResult struct { +type podPredictionResult struct { Pod schedulingtypes.Pod TTFT float64 TPOT float64 @@ -40,24 +40,24 @@ type PodPredictionResult struct { } // generatePredictions creates prediction results for all candidate pods -func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *SLORequestContext, candidatePods []schedulingtypes.Pod) ([]PodPredictionResult, error) { +func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) { logger := log.FromContext(ctx) - predictions := make([]PodPredictionResult, 0, len(candidatePods)) + predictions := make([]podPredictionResult, 0, len(candidatePods)) for _, pod := range candidatePods { - predResult := PodPredictionResult{Pod: pod} + predResult := podPredictionResult{Pod: pod} logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) // Get prefix cache score for the pod prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - sloCtx.PrefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore + sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) // Generate prediction - prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore) + prediction, err := predictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore) if err != nil { logger.V(logutil.DEBUG).Error(err, "Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err) predResult.Error = err @@ -81,8 +81,8 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul "TPOT", prediction.TPOT, "buffer", SLOBufferFactor, "podMinTPOTSLO", podMinTPOTSLO, - "ttftSLO", sloCtx.TTFTSLO, - "requestTPOTSLO", sloCtx.AvgTPOTSLO, + "ttftSLO", sloCtx.ttftSLO, + "requestTPOTSLO", sloCtx.avgTPOTSLO, "tpotHeadroom", predResult.Headroom, "ttftHeadroom", predResult.TTFTHeadroom, "tpotValid", predResult.TPOTValid, @@ -96,43 +96,43 @@ 
func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul } // updateRequestContextWithPredictions updates the request context with prediction data -func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *SLORequestContext, predictions []PodPredictionResult) { +func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []podPredictionResult) { for _, pred := range predictions { if pred.Error == nil { podKey := pred.Pod.GetPod().String() - if sloCtx.PredictedTTFTForScheduling == nil { - sloCtx.PredictedTTFTForScheduling = make(map[string]float64) + if sloCtx.predictedTTFTForScheduling == nil { + sloCtx.predictedTTFTForScheduling = make(map[string]float64) } - if sloCtx.PredictedTPOTForScheduling == nil { - sloCtx.PredictedTPOTForScheduling = make(map[string]float64) + if sloCtx.predictedTPOTForScheduling == nil { + sloCtx.predictedTPOTForScheduling = make(map[string]float64) } - sloCtx.PredictedTTFTForScheduling[podKey] = pred.TTFT - sloCtx.PredictedTPOTForScheduling[podKey] = pred.TPOT + sloCtx.predictedTTFTForScheduling[podKey] = pred.TTFT + sloCtx.predictedTPOTForScheduling[podKey] = pred.TPOT } } } func (s *SLOAwareRouter) validatePrediction( pred *latencypredictor.PredictionResponse, - sloCtx *SLORequestContext, + sloCtx *sloRequestContext, podMinTPOTSLO float64, ) (ttftOk, tpotOk, isValid bool, headroom float64, ttftHeadroom float64) { - bufferedTPOT := sloCtx.AvgTPOTSLO * SLOBufferFactor + bufferedTPOT := sloCtx.avgTPOTSLO * SLOBufferFactor // a podMinTPOTSLO of 0 means no either no requests, or no TPOT SLOs specified on running requests if podMinTPOTSLO > 0 { - if podMinTPOTSLO < sloCtx.AvgTPOTSLO { + if podMinTPOTSLO < sloCtx.avgTPOTSLO { //print debug message - log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.AvgTPOTSLO) + log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.avgTPOTSLO) } bufferedTPOT = min(bufferedTPOT, podMinTPOTSLO*SLOBufferFactor) } tpotOk = pred.TPOT < bufferedTPOT - ttftOk = pred.TTFT < sloCtx.TTFTSLO + ttftOk = pred.TTFT < sloCtx.ttftSLO isValid = ttftOk && tpotOk headroom = bufferedTPOT - pred.TPOT - ttftHeadroom = sloCtx.TTFTSLO - pred.TTFT + ttftHeadroom = sloCtx.ttftSLO - pred.TTFT return } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index f865bbeb37..92ef6042e0 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -39,60 +39,60 @@ var _ requestcontrol.ResponseReceived = &SLOAwareRouter{} var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{} var _ requestcontrol.ResponseComplete = &SLOAwareRouter{} -type SLORequestContext struct { - SchedulingRequest schedulingtypes.LLMRequest - TargetPod *backend.Pod - SchedulingResult *schedulingtypes.SchedulingResult - LastSeenMetrics map[string]*backendmetrics.MetricsState - LastTokenTimestamp time.Time - RequestReceivedTimestamp time.Time - GeneratedTokenCount int - IncomingModelName string - TTFT float64 - PredictedTTFT float64 - AvgTPOT float64 - AvgPredictedTPOT float64 - TokenSampler *TokenSampler - TPOTObservations []float64 
- PredictedTPOTObservations []float64 - - PrefixCacheScoresForPods map[string]float64 - - // TTFTSLO is the target time to first token SLO for the request. - TTFTSLO float64 +type sloRequestContext struct { + schedulingRequest schedulingtypes.LLMRequest + targetPod *backend.Pod + schedulingResult *schedulingtypes.SchedulingResult + lastSeenMetrics map[string]*backendmetrics.MetricsState + lastTokenTimestamp time.Time + requestReceivedTimestamp time.Time + generatedTokenCount int + incomingModelName string + ttft float64 + predictedTTFT float64 + avgTPOT float64 + avgPredictedTPOT float64 + tokenSampler *tokenSampler + tpotObservations []float64 + predictedTPOTObservations []float64 + + prefixCacheScoresForPods map[string]float64 + + // ttftSLO is the target time to first token SLO for the request. + ttftSLO float64 // TPOTSLO is the target time per output token SLO for the request. - AvgTPOTSLO float64 + avgTPOTSLO float64 - // PredictorBasedScheduling indicates whether to use predictor based scheduling. - PredictorBasedScheduling bool - //PredictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. - PredictedTTFTForScheduling map[string]float64 - // PredictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. - PredictedTPOTForScheduling map[string]float64 + // predictorBasedScheduling indicates whether to use predictor based scheduling. + predictorBasedScheduling bool + //predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. + predictedTTFTForScheduling map[string]float64 + // predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. + predictedTPOTForScheduling map[string]float64 // boolean set if request has valid pod based on predictions - HasValidPod bool + hasValidPod bool } -func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { - return &SLORequestContext{ - SchedulingRequest: *request, - LastSeenMetrics: make(map[string]*backendmetrics.MetricsState), - PrefixCacheScoresForPods: make(map[string]float64), - PredictedTTFTForScheduling: make(map[string]float64), - PredictedTPOTForScheduling: make(map[string]float64), +func newSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { + return &sloRequestContext{ + schedulingRequest: *request, + lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + prefixCacheScoresForPods: make(map[string]float64), + predictedTTFTForScheduling: make(map[string]float64), + predictedTPOTForScheduling: make(map[string]float64), } } -func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) { +func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*sloRequestContext, error) { id := request.Headers[requtil.RequestIdHeaderKey] if ctx, exists := s.sloContextStore.Load(id); exists { - return ctx.(*SLORequestContext), nil + return ctx.(*sloRequestContext), nil } return nil, fmt.Errorf("SLO context not found for request ID: %s", id) } -func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) { +func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *sloRequestContext) { id := request.Headers[requtil.RequestIdHeaderKey] s.sloContextStore.Store(id, ctx) } @@ -113,7 +113,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype } targetPod := 
schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() - if !t.CheckPredictor(logger, targetPod) { + if !t.checkPredictor(logger, targetPod) { return } @@ -130,7 +130,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype id := request.Headers[requtil.RequestIdHeaderKey] podRequestList, ok := t.runningRequestLists[podName] if !ok { - podRequestList = NewRequestPriorityQueue() + podRequestList = newRequestPriorityQueue() t.runningRequestLists[podName] = podRequestList } @@ -141,22 +141,22 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype return } - added := podRequestList.Add(id, sloCtx.AvgTPOTSLO) + added := podRequestList.Add(id, sloCtx.avgTPOTSLO) if !added { logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) } // Set up SLO request context - sloCtx.TargetPod = targetPod - sloCtx.SchedulingResult = schedulingResult - sloCtx.RequestReceivedTimestamp = time.Now() - RefreshLastSeenMetrics(ctx, sloCtx) + sloCtx.targetPod = targetPod + sloCtx.schedulingResult = schedulingResult + sloCtx.requestReceivedTimestamp = time.Now() + refreshLastSeenMetrics(ctx, sloCtx) t.setSLOContextForRequest(request, sloCtx) } func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { logger := log.FromContext(ctx) - if !t.CheckPredictor(logger, targetPod) { + if !t.checkPredictor(logger, targetPod) { return } @@ -168,7 +168,7 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli return } - if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { + if err := processHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil { logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed") } @@ -176,7 +176,7 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { logger := log.FromContext(ctx) - if !t.CheckPredictor(logger, pod) || response.EndOfStream { + if !t.checkPredictor(logger, pod) || response.EndOfStream { return } @@ -188,10 +188,10 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul return } - if sloCtx.TTFT == 0 { - ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + if sloCtx.ttft == 0 { + processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) } else { - ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) + processTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now) } } @@ -199,7 +199,7 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { logger := log.FromContext(ctx) targetPod := pod - if !t.CheckPredictor(logger, targetPod) { + if !t.checkPredictor(logger, targetPod) { return } @@ -210,25 +210,25 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli return } - if sloCtx.TTFT > 0 { - logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT) - metrics.RecordRequestTTFT(ctx, 
sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000) - metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000) - if sloCtx.TTFTSLO > 0 { - metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT, sloCtx.TTFTSLO) + if sloCtx.ttft > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", sloCtx.ttft, "avgPredictedTTFT", sloCtx.predictedTTFT) + metrics.RecordRequestTTFT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.ttft/1000) + metrics.RecordRequestPredictedTTFT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.predictedTTFT/1000) + if sloCtx.ttftSLO > 0 { + metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.ttft, sloCtx.ttftSLO) } } - if sloCtx.AvgTPOT > 0 { - logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT) - metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000) - metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000) - if sloCtx.AvgTPOTSLO > 0 { - metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT, sloCtx.AvgTPOTSLO) + if sloCtx.avgTPOT > 0 { + logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", sloCtx.avgTPOT, "avgPredictedTPOT", sloCtx.avgPredictedTPOT) + metrics.RecordRequestTPOT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgTPOT/1000) + metrics.RecordRequestPredictedTPOT(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgPredictedTPOT/1000) + if sloCtx.avgTPOTSLO > 0 { + metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.incomingModelName, request.TargetModel, sloCtx.avgTPOT, sloCtx.avgTPOTSLO) } } - logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.PredictorBasedScheduling) + logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.predictorBasedScheduling) podName := types.NamespacedName{ Name: targetPod.NamespacedName.Name, @@ -249,7 +249,7 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli t.deleteSLOContextForRequest(request) } -func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool { +func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, targetPod *backend.Pod) bool { if targetPod == nil { logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") return false diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index 96999af2f3..3b297b2bbd 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -54,7 +54,7 @@ func createTestSchedulingResult(pod *backend.Pod, kvUsage float64, runningQueue func createTestRouter() *SLOAwareRouter { return &SLOAwareRouter{ sloContextStore: sync.Map{}, - runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue), latencypredictor: nil, } } @@ -64,22 +64,22 @@ func createTestRouter() *SLOAwareRouter { func TestNewSLORequestContext(t 
*testing.T) { request := createTestLLMRequest("test", 100, 50, true) - ctx := NewSLORequestContext(request) + ctx := newSLORequestContext(request) assert.NotNil(t, ctx) - assert.Equal(t, *request, ctx.SchedulingRequest) - assert.NotNil(t, ctx.LastSeenMetrics) - assert.NotNil(t, ctx.PrefixCacheScoresForPods) - assert.NotNil(t, ctx.PredictedTTFTForScheduling) - assert.NotNil(t, ctx.PredictedTPOTForScheduling) - assert.Empty(t, ctx.LastSeenMetrics) - assert.Empty(t, ctx.PrefixCacheScoresForPods) + assert.Equal(t, *request, ctx.schedulingRequest) + assert.NotNil(t, ctx.lastSeenMetrics) + assert.NotNil(t, ctx.prefixCacheScoresForPods) + assert.NotNil(t, ctx.predictedTTFTForScheduling) + assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.Empty(t, ctx.lastSeenMetrics) + assert.Empty(t, ctx.prefixCacheScoresForPods) } func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) { router := createTestRouter() request := createTestLLMRequest("test", 100, 50, true) - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) // Set context router.setSLOContextForRequest(request, sloCtx) @@ -106,7 +106,7 @@ func TestSLOAwareRouter_GetSLOContext_NotFound(t *testing.T) { func TestSLOAwareRouter_DeleteSLOContext(t *testing.T) { router := createTestRouter() request := createTestLLMRequest("test", 100, 50, true) - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) // Set and then delete context router.setSLOContextForRequest(request, sloCtx) @@ -159,12 +159,12 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) // Create and set initial SLO context - sloCtx := NewSLORequestContext(request) - sloCtx.AvgTPOTSLO = 50 + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 router.setSLOContextForRequest(request, sloCtx) // Initialize the request priority queue - router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = newRequestPriorityQueue() beforeTime := time.Now() router.PreRequest(ctx, request, schedulingResult) @@ -173,12 +173,12 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { // Verify SLO context was updated retrievedCtx, err := router.getSLOContextForRequest(request) require.NoError(t, err) - assert.Equal(t, pod.GetPod(), retrievedCtx.TargetPod) - assert.Equal(t, schedulingResult, retrievedCtx.SchedulingResult) - assert.True(t, retrievedCtx.RequestReceivedTimestamp.After(beforeTime) || - retrievedCtx.RequestReceivedTimestamp.Equal(beforeTime)) - assert.True(t, retrievedCtx.RequestReceivedTimestamp.Before(afterTime) || - retrievedCtx.RequestReceivedTimestamp.Equal(afterTime)) + assert.Equal(t, pod.GetPod(), retrievedCtx.targetPod) + assert.Equal(t, schedulingResult, retrievedCtx.schedulingResult) + assert.True(t, retrievedCtx.requestReceivedTimestamp.After(beforeTime) || + retrievedCtx.requestReceivedTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.requestReceivedTimestamp.Before(afterTime) || + retrievedCtx.requestReceivedTimestamp.Equal(afterTime)) } func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { @@ -192,8 +192,8 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) // Create and set initial SLO context - sloCtx := NewSLORequestContext(request) - sloCtx.AvgTPOTSLO = 50 + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 
router.setSLOContextForRequest(request, sloCtx) // PreRequest should create the queue @@ -217,12 +217,12 @@ func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) // Create and set initial SLO contexts - sloCtx1 := NewSLORequestContext(request1) - sloCtx1.AvgTPOTSLO = 50 + sloCtx1 := newSLORequestContext(request1) + sloCtx1.avgTPOTSLO = 50 router.setSLOContextForRequest(request1, sloCtx1) - sloCtx2 := NewSLORequestContext(request2) - sloCtx2.AvgTPOTSLO = 50 + sloCtx2 := newSLORequestContext(request2) + sloCtx2.avgTPOTSLO = 50 router.setSLOContextForRequest(request2, sloCtx2) // Add first request @@ -246,7 +246,7 @@ func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) router.setSLOContextForRequest(request, sloCtx) // Should not panic and should return early @@ -266,7 +266,7 @@ func TestSLOAwareRouter_ResponseReceived_NoPod(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) router.setSLOContextForRequest(request, sloCtx) // Should not panic with nil pod @@ -302,7 +302,7 @@ func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) router.setSLOContextForRequest(request, sloCtx) // Should not panic and should return early @@ -323,22 +323,22 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { response := &requestcontrol.Response{} schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) - sloCtx := NewSLORequestContext(request) - sloCtx.RequestReceivedTimestamp = time.Now() - sloCtx.SchedulingResult = schedulingResult - sloCtx.SchedulingRequest = *request - sloCtx.TTFTSLO = 100 - sloCtx.AvgTPOTSLO = 50 - sloCtx.IncomingModelName = "test-model" - sloCtx.PredictedTTFT = 80.0 - sloCtx.AvgPredictedTPOT = 30.0 + sloCtx := newSLORequestContext(request) + sloCtx.requestReceivedTimestamp = time.Now() + sloCtx.schedulingResult = schedulingResult + sloCtx.schedulingRequest = *request + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "test-model" + sloCtx.predictedTTFT = 80.0 + sloCtx.avgPredictedTPOT = 30.0 // ADD THIS - populate metrics - sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + sloCtx.lastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ KVCacheUsagePercent: 0.5, WaitingQueueSize: 1, RunningQueueSize: 1, } - sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + sloCtx.lastSeenMetrics["default"] = &backendmetrics.MetricsState{ KVCacheUsagePercent: 0.5, WaitingQueueSize: 1, RunningQueueSize: 1, @@ -346,7 +346,7 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Initialize the queue and add the request - queue := NewRequestPriorityQueue() + queue := newRequestPriorityQueue() queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) router.runningRequestLists[pod.GetPod().NamespacedName] = queue @@ -357,10 +357,10 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { // Verify first token timestamp was set 
retrievedCtx, err := router.getSLOContextForRequest(request) require.NoError(t, err) - assert.True(t, retrievedCtx.LastTokenTimestamp.After(beforeTime) || - retrievedCtx.LastTokenTimestamp.Equal(beforeTime)) - assert.True(t, retrievedCtx.LastTokenTimestamp.Before(afterTime) || - retrievedCtx.LastTokenTimestamp.Equal(afterTime)) + assert.True(t, retrievedCtx.lastTokenTimestamp.After(beforeTime) || + retrievedCtx.lastTokenTimestamp.Equal(beforeTime)) + assert.True(t, retrievedCtx.lastTokenTimestamp.Before(afterTime) || + retrievedCtx.lastTokenTimestamp.Equal(afterTime)) } func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { @@ -374,22 +374,22 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { response := &requestcontrol.Response{} schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) - sloCtx := NewSLORequestContext(request) - sloCtx.RequestReceivedTimestamp = time.Now() - sloCtx.SchedulingResult = schedulingResult - sloCtx.SchedulingRequest = *request - sloCtx.TTFTSLO = 100 - sloCtx.AvgTPOTSLO = 50 - sloCtx.IncomingModelName = "test-model" - sloCtx.PredictedTTFT = 80.0 - sloCtx.AvgPredictedTPOT = 30.0 + sloCtx := newSLORequestContext(request) + sloCtx.requestReceivedTimestamp = time.Now() + sloCtx.schedulingResult = schedulingResult + sloCtx.schedulingRequest = *request + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "test-model" + sloCtx.predictedTTFT = 80.0 + sloCtx.avgPredictedTPOT = 30.0 // ADD THIS - populate metrics - sloCtx.LastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ + sloCtx.lastSeenMetrics["prefill"] = &backendmetrics.MetricsState{ KVCacheUsagePercent: 0.5, WaitingQueueSize: 1, RunningQueueSize: 1, } - sloCtx.LastSeenMetrics["default"] = &backendmetrics.MetricsState{ + sloCtx.lastSeenMetrics["default"] = &backendmetrics.MetricsState{ KVCacheUsagePercent: 0.5, WaitingQueueSize: 1, RunningQueueSize: 1, @@ -399,7 +399,7 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Initialize the queue and add the request - queue := NewRequestPriorityQueue() + queue := newRequestPriorityQueue() queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) router.runningRequestLists[pod.GetPod().NamespacedName] = queue @@ -408,7 +408,7 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { // Verify token timestamp was updated retrievedCtx, err := router.getSLOContextForRequest(request) require.NoError(t, err) - assert.True(t, retrievedCtx.LastTokenTimestamp.After(firstTokenTime)) + assert.True(t, retrievedCtx.lastTokenTimestamp.After(firstTokenTime)) } func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { @@ -421,13 +421,13 @@ func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) - sloCtx.IncomingModelName = "test-model" - sloCtx.TargetPod = pod.GetPod() // ADD THIS to avoid other issues + sloCtx := newSLORequestContext(request) + sloCtx.incomingModelName = "test-model" + sloCtx.targetPod = pod.GetPod() // ADD THIS to avoid other issues router.setSLOContextForRequest(request, sloCtx) // Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior - router.runningRequestLists[pod.GetPod().NamespacedName] = NewRequestPriorityQueue() + router.runningRequestLists[pod.GetPod().NamespacedName] = 
newRequestPriorityQueue() // Should handle gracefully when request is not in queue router.ResponseComplete(ctx, request, response, pod.GetPod()) @@ -464,18 +464,18 @@ func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { response := &requestcontrol.Response{} // Create queue and add request - queue := NewRequestPriorityQueue() + queue := newRequestPriorityQueue() router.runningRequestLists[pod.GetPod().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) - sloCtx := NewSLORequestContext(request) - sloCtx.TTFT = 80 - sloCtx.AvgTPOT = 30 - sloCtx.PredictedTTFT = 85 - sloCtx.AvgPredictedTPOT = 32 - sloCtx.TTFTSLO = 100 - sloCtx.AvgTPOTSLO = 50 - sloCtx.IncomingModelName = "incoming-model" + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.predictedTTFT = 85 + sloCtx.avgPredictedTPOT = 32 + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "incoming-model" router.setSLOContextForRequest(request, sloCtx) router.ResponseComplete(ctx, request, response, pod.GetPod()) @@ -497,7 +497,7 @@ func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) router.setSLOContextForRequest(request, sloCtx) // Should not panic @@ -517,7 +517,7 @@ func TestSLOAwareRouter_ResponseComplete_NoPod(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) router.setSLOContextForRequest(request, sloCtx) // Should not panic with nil pod @@ -556,18 +556,18 @@ func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { response := &requestcontrol.Response{} // Create queue - queue := NewRequestPriorityQueue() + queue := newRequestPriorityQueue() router.runningRequestLists[pod.GetPod().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) - sloCtx := NewSLORequestContext(request) - sloCtx.TTFT = 80 - sloCtx.AvgTPOT = 30 - sloCtx.PredictedTTFT = 85 - sloCtx.AvgPredictedTPOT = 32 - sloCtx.TTFTSLO = 100 - sloCtx.AvgTPOTSLO = 50 - sloCtx.IncomingModelName = "incoming-model" + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.predictedTTFT = 85 + sloCtx.avgPredictedTPOT = 32 + sloCtx.ttftSLO = 100 + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "incoming-model" router.setSLOContextForRequest(request, sloCtx) // Should record metrics without panicking @@ -589,14 +589,14 @@ func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { response := &requestcontrol.Response{} // Create queue - queue := NewRequestPriorityQueue() + queue := newRequestPriorityQueue() router.runningRequestLists[pod.GetPod().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0) - sloCtx := NewSLORequestContext(request) - sloCtx.TTFT = 80 - sloCtx.AvgTPOT = 30 - sloCtx.IncomingModelName = "test-model" + sloCtx := newSLORequestContext(request) + sloCtx.ttft = 80 + sloCtx.avgTPOT = 30 + sloCtx.incomingModelName = "test-model" router.setSLOContextForRequest(request, sloCtx) // Should handle missing SLOs gracefully @@ -611,7 +611,7 @@ func TestSLOAwareRouter_CheckPredictor_NilPod(t *testing.T) { router := createTestRouter() logger := logr.Discard() - result := router.CheckPredictor(logger, nil) + result := 
router.checkPredictor(logger, nil) assert.False(t, result) } @@ -622,7 +622,7 @@ func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) { logger := logr.Discard() pod := createTestPod("test-pod", 1, 1, 1) - result := router.CheckPredictor(logger, pod.GetPod()) + result := router.checkPredictor(logger, pod.GetPod()) assert.False(t, result) } @@ -634,74 +634,74 @@ func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { logger := logr.Discard() pod := createTestPod("test-pod", 1, 1, 1) - result := router.CheckPredictor(logger, pod.GetPod()) + result := router.checkPredictor(logger, pod.GetPod()) assert.True(t, result) } func TestSLORequestContext_Fields(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) - ctx := NewSLORequestContext(request) + ctx := newSLORequestContext(request) // Test all field initialization - assert.NotNil(t, ctx.LastSeenMetrics) - assert.NotNil(t, ctx.PrefixCacheScoresForPods) - assert.NotNil(t, ctx.PredictedTTFTForScheduling) - assert.NotNil(t, ctx.PredictedTPOTForScheduling) - assert.Empty(t, ctx.TPOTObservations) - assert.Empty(t, ctx.PredictedTPOTObservations) - assert.Zero(t, ctx.GeneratedTokenCount) - assert.Zero(t, ctx.TTFT) - assert.Zero(t, ctx.AvgTPOT) - assert.Nil(t, ctx.TargetPod) - assert.Nil(t, ctx.SchedulingResult) - assert.Nil(t, ctx.TokenSampler) + assert.NotNil(t, ctx.lastSeenMetrics) + assert.NotNil(t, ctx.prefixCacheScoresForPods) + assert.NotNil(t, ctx.predictedTTFTForScheduling) + assert.NotNil(t, ctx.predictedTPOTForScheduling) + assert.Empty(t, ctx.tpotObservations) + assert.Empty(t, ctx.predictedTPOTObservations) + assert.Zero(t, ctx.generatedTokenCount) + assert.Zero(t, ctx.ttft) + assert.Zero(t, ctx.avgTPOT) + assert.Nil(t, ctx.targetPod) + assert.Nil(t, ctx.schedulingResult) + assert.Nil(t, ctx.tokenSampler) } func TestSLORequestContext_UpdateMetrics(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) - ctx := NewSLORequestContext(request) + ctx := newSLORequestContext(request) // Add some metrics metricsState := &backendmetrics.MetricsState{ KVCacheUsagePercent: 0.5, WaitingQueueSize: 3, } - ctx.LastSeenMetrics["test-pod"] = metricsState + ctx.lastSeenMetrics["test-pod"] = metricsState - assert.Len(t, ctx.LastSeenMetrics, 1) - assert.Equal(t, 0.5, ctx.LastSeenMetrics["test-pod"].KVCacheUsagePercent) - assert.Equal(t, 3, ctx.LastSeenMetrics["test-pod"].WaitingQueueSize) + assert.Len(t, ctx.lastSeenMetrics, 1) + assert.Equal(t, 0.5, ctx.lastSeenMetrics["test-pod"].KVCacheUsagePercent) + assert.Equal(t, 3, ctx.lastSeenMetrics["test-pod"].WaitingQueueSize) } func TestSLORequestContext_PredictionData(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) - ctx := NewSLORequestContext(request) + ctx := newSLORequestContext(request) // Set prediction data - ctx.PredictedTTFTForScheduling["pod1"] = 80.0 - ctx.PredictedTPOTForScheduling["pod1"] = 30.0 - ctx.PredictedTTFTForScheduling["pod2"] = 90.0 - ctx.PredictedTPOTForScheduling["pod2"] = 35.0 + ctx.predictedTTFTForScheduling["pod1"] = 80.0 + ctx.predictedTPOTForScheduling["pod1"] = 30.0 + ctx.predictedTTFTForScheduling["pod2"] = 90.0 + ctx.predictedTPOTForScheduling["pod2"] = 35.0 - assert.Len(t, ctx.PredictedTTFTForScheduling, 2) - assert.Len(t, ctx.PredictedTPOTForScheduling, 2) - assert.Equal(t, 80.0, ctx.PredictedTTFTForScheduling["pod1"]) - assert.Equal(t, 30.0, ctx.PredictedTPOTForScheduling["pod1"]) + assert.Len(t, ctx.predictedTTFTForScheduling, 2) + assert.Len(t, ctx.predictedTPOTForScheduling, 2) + assert.Equal(t, 
80.0, ctx.predictedTTFTForScheduling["pod1"]) + assert.Equal(t, 30.0, ctx.predictedTPOTForScheduling["pod1"]) } func TestSLORequestContext_PrefixCacheScores(t *testing.T) { request := createTestLLMRequest("test", 100, 50, true) - ctx := NewSLORequestContext(request) + ctx := newSLORequestContext(request) // Set prefix cache scores - ctx.PrefixCacheScoresForPods["pod1"] = 0.8 - ctx.PrefixCacheScoresForPods["pod2"] = 0.6 - ctx.PrefixCacheScoresForPods["pod3"] = 0.9 + ctx.prefixCacheScoresForPods["pod1"] = 0.8 + ctx.prefixCacheScoresForPods["pod2"] = 0.6 + ctx.prefixCacheScoresForPods["pod3"] = 0.9 - assert.Len(t, ctx.PrefixCacheScoresForPods, 3) - assert.Equal(t, 0.8, ctx.PrefixCacheScoresForPods["pod1"]) - assert.Equal(t, 0.9, ctx.PrefixCacheScoresForPods["pod3"]) + assert.Len(t, ctx.prefixCacheScoresForPods, 3) + assert.Equal(t, 0.8, ctx.prefixCacheScoresForPods["pod1"]) + assert.Equal(t, 0.9, ctx.prefixCacheScoresForPods["pod3"]) } func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { @@ -718,7 +718,7 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { requestID := uuid.New().String() request := createTestLLMRequest(requestID, 100, 50, true) - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) // Set context router.setSLOContextForRequest(request, sloCtx) @@ -752,8 +752,8 @@ func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { // Create and set SLO contexts for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} { - sloCtx := NewSLORequestContext(req) - sloCtx.AvgTPOTSLO = 50 + sloCtx := newSLORequestContext(req) + sloCtx.avgTPOTSLO = 50 router.setSLOContextForRequest(req, sloCtx) } @@ -780,9 +780,9 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) // Create initial context - sloCtx := NewSLORequestContext(request) - sloCtx.AvgTPOTSLO = 50 - sloCtx.IncomingModelName = "test-model" + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 + sloCtx.incomingModelName = "test-model" router.setSLOContextForRequest(request, sloCtx) // 1. PreRequest @@ -791,7 +791,7 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { // Verify context exists retrievedCtx, err := router.getSLOContextForRequest(request) require.NoError(t, err) - assert.NotNil(t, retrievedCtx.TargetPod) + assert.NotNil(t, retrievedCtx.targetPod) // 2. ResponseReceived router.ResponseReceived(ctx, request, response, pod.GetPod()) @@ -801,14 +801,14 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { // 4. ResponseStreaming (subsequent tokens) retrievedCtx, _ = router.getSLOContextForRequest(request) - retrievedCtx.TTFT = 100 // Mark first token received + retrievedCtx.ttft = 100 // Mark first token received router.setSLOContextForRequest(request, retrievedCtx) router.ResponseStreaming(ctx, request, response, pod.GetPod()) // 5. 
ResponseComplete retrievedCtx, _ = router.getSLOContextForRequest(request) - retrievedCtx.TTFT = 80 - retrievedCtx.AvgTPOT = 30 + retrievedCtx.ttft = 80 + retrievedCtx.avgTPOT = 30 router.setSLOContextForRequest(request, retrievedCtx) router.ResponseComplete(ctx, request, response, pod.GetPod()) @@ -834,12 +834,12 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { schedulingResult2 := createTestSchedulingResult(pod2.GetPod(), 1, 1, 1) // Create and set SLO contexts - sloCtx1 := NewSLORequestContext(request1) - sloCtx1.AvgTPOTSLO = 50 + sloCtx1 := newSLORequestContext(request1) + sloCtx1.avgTPOTSLO = 50 router.setSLOContextForRequest(request1, sloCtx1) - sloCtx2 := NewSLORequestContext(request2) - sloCtx2.AvgTPOTSLO = 50 + sloCtx2 := newSLORequestContext(request2) + sloCtx2.avgTPOTSLO = 50 router.setSLOContextForRequest(request2, sloCtx2) // Add requests to different pods @@ -893,11 +893,11 @@ func TestSLORequestContext_SLOValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { request := createTestLLMRequest("test-id", tt.ttftSLO, tt.tpotSLO, true) - ctx := NewSLORequestContext(request) - ctx.TTFTSLO = tt.ttftSLO - ctx.AvgTPOTSLO = tt.tpotSLO + ctx := newSLORequestContext(request) + ctx.ttftSLO = tt.ttftSLO + ctx.avgTPOTSLO = tt.tpotSLO - hasBothSLOs := ctx.TTFTSLO > 0 && ctx.AvgTPOTSLO > 0 + hasBothSLOs := ctx.ttftSLO > 0 && ctx.avgTPOTSLO > 0 assert.Equal(t, tt.expectSLOs, hasBothSLOs) }) } @@ -915,8 +915,8 @@ func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { for i := 0; i < b.N; i++ { requestID := uuid.New().String() request := createTestLLMRequest(requestID, 100, 50, true) - sloCtx := NewSLORequestContext(request) - sloCtx.AvgTPOTSLO = 50 + sloCtx := newSLORequestContext(request) + sloCtx.avgTPOTSLO = 50 router.setSLOContextForRequest(request, sloCtx) router.PreRequest(ctx, request, schedulingResult) } @@ -925,7 +925,7 @@ func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { func BenchmarkSLOAwareRouter_ContextOperations(b *testing.B) { router := createTestRouter() request := createTestLLMRequest("test", 100, 50, true) - sloCtx := NewSLORequestContext(request) + sloCtx := newSLORequestContext(request) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -940,6 +940,6 @@ func BenchmarkSLORequestContext_Creation(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = NewSLORequestContext(request) + _ = newSLORequestContext(request) } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go index ce1e997b07..37017fbdcb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go @@ -24,83 +24,83 @@ import ( "sync" ) -// Request represents an element in the priority queue. +// request represents an element in the priority queue. // The index is needed by heap.Remove and is maintained by the heap.Interface methods. -type Request struct { - ID string // Unique identifier - TPOT float64 // The priority value (lower is higher priority) +type request struct { + id string // Unique identifier + tpot float64 // The priority value (lower is higher priority) index int } -// RequestPriorityQueue implements a priority queue with item removal by ID. 
-type RequestPriorityQueue struct { - items []*Request - lookup map[string]*Request +// requestPriorityQueue implements a priority queue with item removal by ID. +type requestPriorityQueue struct { + items []*request + lookup map[string]*request mutex sync.RWMutex } -// NewRequestPriorityQueue initializes and returns a new PriorityQueue. -func NewRequestPriorityQueue() *RequestPriorityQueue { - return &RequestPriorityQueue{ - lookup: make(map[string]*Request), - items: []*Request{}, +// newRequestPriorityQueue initializes and returns a new PriorityQueue. +func newRequestPriorityQueue() *requestPriorityQueue { + return &requestPriorityQueue{ + lookup: make(map[string]*request), + items: []*request{}, } } // Clone creates a deep copy of the priority queue. // The new queue is completely independent of the original. -func (pq *RequestPriorityQueue) Clone() *RequestPriorityQueue { +func (pq *requestPriorityQueue) Clone() *requestPriorityQueue { pq.mutex.RLock() defer pq.mutex.RUnlock() // Initialize a new priority queue with pre-allocated capacity. - clonedPq := &RequestPriorityQueue{ - items: make([]*Request, len(pq.items)), - lookup: make(map[string]*Request, len(pq.lookup)), + clonedPq := &requestPriorityQueue{ + items: make([]*request, len(pq.items)), + lookup: make(map[string]*request, len(pq.lookup)), } // Iterate through the original items to create deep copies. for i, oldItem := range pq.items { // Create a new Request struct, copying all values. - newItem := &Request{ - ID: oldItem.ID, - TPOT: oldItem.TPOT, + newItem := &request{ + id: oldItem.id, + tpot: oldItem.tpot, index: oldItem.index, } // Assign the new item to the cloned queue's items slice. clonedPq.items[i] = newItem // Update the lookup map in the cloned queue to point to the new item. - clonedPq.lookup[newItem.ID] = newItem + clonedPq.lookup[newItem.id] = newItem } return clonedPq } // Len is the number of items in the queue. -func (pq *RequestPriorityQueue) Len() int { return len(pq.items) } +func (pq *requestPriorityQueue) Len() int { return len(pq.items) } // Less reports whether the item with index i should sort before the item with index j. -func (pq *RequestPriorityQueue) Less(i, j int) bool { - return pq.items[i].TPOT < pq.items[j].TPOT +func (pq *requestPriorityQueue) Less(i, j int) bool { + return pq.items[i].tpot < pq.items[j].tpot } // Swap swaps the items with indexes i and j. -func (pq *RequestPriorityQueue) Swap(i, j int) { +func (pq *requestPriorityQueue) Swap(i, j int) { pq.items[i], pq.items[j] = pq.items[j], pq.items[i] pq.items[i].index = i pq.items[j].index = j } // Push adds an item to the heap. -func (pq *RequestPriorityQueue) Push(x any) { - item := x.(*Request) +func (pq *requestPriorityQueue) Push(x any) { + item := x.(*request) item.index = len(pq.items) pq.items = append(pq.items, item) } // Pop removes and returns the minimum item from the heap. -func (pq *RequestPriorityQueue) Pop() any { +func (pq *requestPriorityQueue) Pop() any { n := len(pq.items) item := pq.items[n-1] pq.items[n-1] = nil // avoid memory leak @@ -111,7 +111,7 @@ func (pq *RequestPriorityQueue) Pop() any { // Add adds a new item to the queue. // Returns true if the item was added, false if an item with the same ID already exists. 
-func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool { +func (pq *requestPriorityQueue) Add(id string, tpot float64) bool { pq.mutex.Lock() defer pq.mutex.Unlock() @@ -128,9 +128,9 @@ func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool { return false } - item := &Request{ - ID: id, - TPOT: tpot, + item := &request{ + id: id, + tpot: tpot, } pq.lookup[id] = item heap.Push(pq, item) @@ -139,7 +139,7 @@ func (pq *RequestPriorityQueue) Add(id string, tpot float64) bool { // Update modifies the TPOT value of an existing item in the queue. // If the item doesn't exist, this method does nothing. -func (pq *RequestPriorityQueue) Update(id string, tpot float64) bool { +func (pq *requestPriorityQueue) Update(id string, tpot float64) bool { pq.mutex.Lock() defer pq.mutex.Unlock() @@ -153,13 +153,13 @@ func (pq *RequestPriorityQueue) Update(id string, tpot float64) bool { return false } - item.TPOT = tpot + item.tpot = tpot heap.Fix(pq, item.index) return true } // Remove removes an item from the queue by its ID. -func (pq *RequestPriorityQueue) Remove(id string) (*Request, bool) { +func (pq *requestPriorityQueue) Remove(id string) (*request, bool) { pq.mutex.Lock() defer pq.mutex.Unlock() @@ -167,13 +167,13 @@ func (pq *RequestPriorityQueue) Remove(id string) (*Request, bool) { if !ok { return nil, false } - removed := heap.Remove(pq, item.index).(*Request) + removed := heap.Remove(pq, item.index).(*request) delete(pq.lookup, id) return removed, true } // Peek returns the item with the lowest value without removing it. -func (pq *RequestPriorityQueue) Peek() *Request { +func (pq *requestPriorityQueue) Peek() *request { pq.mutex.RLock() defer pq.mutex.RUnlock() @@ -184,14 +184,14 @@ func (pq *RequestPriorityQueue) Peek() *Request { } // GetSize returns the current number of items in the queue. -func (pq *RequestPriorityQueue) GetSize() int { +func (pq *requestPriorityQueue) GetSize() int { pq.mutex.RLock() defer pq.mutex.RUnlock() return len(pq.items) } // Contains checks if an item with the given ID exists in the queue. -func (pq *RequestPriorityQueue) Contains(id string) bool { +func (pq *requestPriorityQueue) Contains(id string) bool { pq.mutex.RLock() defer pq.mutex.RUnlock() _, exists := pq.lookup[id] @@ -200,24 +200,24 @@ func (pq *RequestPriorityQueue) Contains(id string) bool { // ToSlice returns a copy of all items in the queue, sorted by ID for stable comparison. // This is primarily intended for testing and validation. -func (pq *RequestPriorityQueue) ToSlice() []*Request { +func (pq *requestPriorityQueue) ToSlice() []*request { pq.mutex.RLock() defer pq.mutex.RUnlock() // Create a copy to avoid returning a reference to the internal slice. - itemsCopy := make([]*Request, len(pq.items)) + itemsCopy := make([]*request, len(pq.items)) copy(itemsCopy, pq.items) // Sort by ID to have a deterministic order for comparison in tests. sort.Slice(itemsCopy, func(i, j int) bool { - return itemsCopy[i].ID < itemsCopy[j].ID + return itemsCopy[i].id < itemsCopy[j].id }) return itemsCopy } // String returns a string representation of the queue for debugging. 
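For context on the semantics being renamed in this hunk — a usage sketch (illustrative, not part of the patch) assuming only the methods defined in this file; the lookup map gives O(1) membership checks and O(log n) removal by ID, while the heap keeps the lowest-TPOT request on top:

pq := newRequestPriorityQueue()
pq.Add("req-a", 0.05)     // TPOT SLO; lower value = higher priority
pq.Add("req-b", 0.03)
_ = pq.Add("req-b", 0.10) // false: IDs are deduplicated via the lookup map
pq.Update("req-a", 0.01)  // heap.Fix restores ordering after the change
strictest := pq.Peek()    // req-a (0.01), the tightest SLO on this pod
if _, ok := pq.Remove("req-b"); ok {
	// completed requests are removed by ID regardless of heap position
}
_ = strictest

The String method below prints each ID with its TPOT value for debugging.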
-func (pq *RequestPriorityQueue) String() string { +func (pq *requestPriorityQueue) String() string { pq.mutex.RLock() defer pq.mutex.RUnlock() @@ -232,9 +232,9 @@ func (pq *RequestPriorityQueue) String() string { if i > 0 { builder.WriteString(", ") } - builder.WriteString(item.ID) + builder.WriteString(item.id) builder.WriteString("(") - builder.WriteString(fmt.Sprintf("%.2f", item.TPOT)) + builder.WriteString(fmt.Sprintf("%.2f", item.tpot)) builder.WriteString(")") } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go index a8eba5fe1c..ef34c84b50 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go @@ -24,7 +24,7 @@ import ( ) func TestNewRequestPriorityQueue(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() if pq == nil { t.Fatal("NewRequestPriorityQueue returned nil") @@ -40,7 +40,7 @@ func TestNewRequestPriorityQueue(t *testing.T) { } func TestAdd(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test successful add if !pq.Add("req1", 2.5) { @@ -71,7 +71,7 @@ func TestAdd(t *testing.T) { } func TestPriorityOrdering(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Add items with different priorities pq.Add("high", 1.0) // highest priority (lowest TPOT) @@ -80,7 +80,7 @@ func TestPriorityOrdering(t *testing.T) { // Check that highest priority item is at the top peek := pq.Peek() - if peek == nil || peek.ID != "high" || peek.TPOT != 1.0 { + if peek == nil || peek.id != "high" || peek.tpot != 1.0 { t.Errorf("Expected high priority item at top, got %+v", peek) } @@ -96,19 +96,19 @@ func TestPriorityOrdering(t *testing.T) { for _, exp := range expected { item := pq.Peek() - if item.ID != exp.id || item.TPOT != exp.tpot { - t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.ID, item.TPOT) + if item.id != exp.id || item.tpot != exp.tpot { + t.Errorf("Expected %s(%.1f), got %s(%.1f)", exp.id, exp.tpot, item.id, item.tpot) } - removed, ok := pq.Remove(item.ID) - if !ok || removed.ID != exp.id { + removed, ok := pq.Remove(item.id) + if !ok || removed.id != exp.id { t.Errorf("Failed to remove %s", exp.id) } } } func TestRemove(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test remove from empty queue if _, ok := pq.Remove("nonexistent"); ok { @@ -122,7 +122,7 @@ func TestRemove(t *testing.T) { // Test successful remove removed, ok := pq.Remove("req2") - if !ok || removed.ID != "req2" || removed.TPOT != 2.0 { + if !ok || removed.id != "req2" || removed.tpot != 2.0 { t.Errorf("Expected to remove req2(2.0), got %+v, ok=%v", removed, ok) } @@ -136,13 +136,13 @@ func TestRemove(t *testing.T) { } // Verify remaining items are still in correct order - if peek := pq.Peek(); peek.ID != "req1" { - t.Errorf("Expected req1 at top, got %s", peek.ID) + if peek := pq.Peek(); peek.id != "req1" { + t.Errorf("Expected req1 at top, got %s", peek.id) } } func TestUpdate(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test update nonexistent item if pq.Update("nonexistent", 1.0) { @@ -160,8 +160,8 @@ func TestUpdate(t *testing.T) { } // Check that req3 is now at the top - if peek := pq.Peek(); peek.ID != "req3" || peek.TPOT != 0.5 
{ - t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.ID, peek.TPOT) + if peek := pq.Peek(); peek.id != "req3" || peek.tpot != 0.5 { + t.Errorf("Expected req3(0.5) at top, got %s(%.1f)", peek.id, peek.tpot) } // Test validation @@ -171,7 +171,7 @@ func TestUpdate(t *testing.T) { } func TestContains(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test empty queue if pq.Contains("req1") { @@ -199,7 +199,7 @@ func TestContains(t *testing.T) { } func TestClone(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test clone of empty queue clone := pq.Clone() @@ -241,7 +241,7 @@ func TestClone(t *testing.T) { } func TestString(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Test empty queue str := pq.String() @@ -262,7 +262,7 @@ func TestString(t *testing.T) { } func TestConcurrency(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() const numWorkers = 10 const itemsPerWorker = 100 @@ -304,7 +304,7 @@ func TestConcurrency(t *testing.T) { } func TestLargeQueue(t *testing.T) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() const numItems = 10000 // Add many items @@ -322,11 +322,11 @@ func TestLargeQueue(t *testing.T) { lastTPOT := -1.0 for i := 0; i < numItems; i++ { item := pq.Peek() - if item.TPOT < lastTPOT { - t.Errorf("Priority order violated: %.1f < %.1f", item.TPOT, lastTPOT) + if item.tpot < lastTPOT { + t.Errorf("Priority order violated: %.1f < %.1f", item.tpot, lastTPOT) } - lastTPOT = item.TPOT - pq.Remove(item.ID) + lastTPOT = item.tpot + pq.Remove(item.id) } if pq.GetSize() != 0 { @@ -335,7 +335,7 @@ func TestLargeQueue(t *testing.T) { } func BenchmarkAdd(b *testing.B) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -345,7 +345,7 @@ func BenchmarkAdd(b *testing.B) { } func BenchmarkPeek(b *testing.B) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Pre-populate queue for i := 0; i < 1000; i++ { @@ -359,7 +359,7 @@ func BenchmarkPeek(b *testing.B) { } func BenchmarkRemove(b *testing.B) { - pq := NewRequestPriorityQueue() + pq := newRequestPriorityQueue() // Pre-populate queue for i := 0; i < b.N; i++ { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go index bdeca30378..cd021d36a5 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -23,9 +23,9 @@ import ( "time" ) -// TokenSampler handles Poisson-distributed sampling for predictions only +// tokenSampler handles Poisson-distributed sampling for predictions only // Training happens on every token regardless of sampling -type TokenSampler struct { +type tokenSampler struct { rng *rand.Rand nextSampleToken int samplingMean float64 @@ -33,22 +33,22 @@ type TokenSampler struct { sampleCount int } -// SetSamplingMean sets the sampling mean (lambda) for the Poisson distribution -func (ts *TokenSampler) SetSamplingMean(mean float64) { +// setSamplingMean sets the sampling mean (lambda) for the Poisson distribution +func (ts *tokenSampler) setSamplingMean(mean float64) { ts.samplingMean = mean } -// SetMaxSamples sets the maximum number of samples -func (ts *TokenSampler) SetMaxSamples(max int) { +// setMaxSamples sets the maximum number of samples +func (ts 
*tokenSampler) setMaxSamples(max int) { ts.maxSamples = max } -// SetSampleCount sets the current number of predictions made -func (ts *TokenSampler) SetSampleCount(count int) { +// setSampleCount sets the current number of predictions made +func (ts *tokenSampler) setSampleCount(count int) { ts.sampleCount = count } -func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *TokenSampler { +func newTokenSampler(requestID string, samplingMean float64, maxSamples int) *tokenSampler { // Use request ID hash as seed for reproducibility seed := int64(0) if requestID != "" { @@ -60,7 +60,7 @@ func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *To seed = time.Now().UnixNano() } - sampler := &TokenSampler{ + sampler := &tokenSampler{ rng: rand.New(rand.NewSource(seed)), samplingMean: samplingMean, maxSamples: maxSamples, @@ -73,7 +73,7 @@ func NewTokenSampler(requestID string, samplingMean float64, maxSamples int) *To } // poissonNext generates the next interval using Poisson distribution -func (ts *TokenSampler) poissonNext() int { +func (ts *tokenSampler) poissonNext() int { lambda := ts.samplingMean if lambda <= 0 { return 1 @@ -101,13 +101,13 @@ func (ts *TokenSampler) poissonNext() int { return interval } -// ShouldPredict determines if we should make a prediction for the current token -func (ts *TokenSampler) ShouldPredict(currentToken int) bool { +// shouldPredict determines if we should make a prediction for the current token +func (ts *tokenSampler) shouldPredict(currentToken int) bool { return currentToken == ts.nextSampleToken && ts.sampleCount < ts.maxSamples } -// RecordPrediction records that a prediction was made and calculates the next sample token -func (ts *TokenSampler) RecordPrediction(currentToken int) { +// recordPrediction records that a prediction was made and calculates the next sample token +func (ts *tokenSampler) recordPrediction(currentToken int) { if ts.sampleCount >= ts.maxSamples { return } @@ -120,17 +120,17 @@ func (ts *TokenSampler) RecordPrediction(currentToken int) { } } -// GetNextSampleToken returns the next token to predict for -func (ts *TokenSampler) GetNextSampleToken() int { +// getNextSampleToken returns the next token to predict for +func (ts *tokenSampler) getNextSampleToken() int { return ts.nextSampleToken } -// SetNextSampleToken sets the next token to predict for -func (ts *TokenSampler) SetNextSampleToken(token int) { +// setNextSampleToken sets the next token to predict for +func (ts *tokenSampler) setNextSampleToken(token int) { ts.nextSampleToken = token } -// GetSampleCount returns the current number of predictions made -func (ts *TokenSampler) GetSampleCount() int { +// getSampleCount returns the current number of predictions made +func (ts *tokenSampler) getSampleCount() int { return ts.sampleCount } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index 1aa2d9fd6a..34cf36958e 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -38,18 +38,18 @@ import ( type SLOAwareRouter struct { tn plugins.TypedName latencypredictor latencypredictor.PredictorInterface - runningRequestLists map[types.NamespacedName]*RequestPriorityQueue + runningRequestLists map[types.NamespacedName]*requestPriorityQueue sloContextStore sync.Map // map[string]*SLORequestContext - headroomStrategy HeadroomStrategy + 
headroomStrategy headroomStrategy } var _ framework.Scorer = &SLOAwareRouter{} -func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter { +func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy headroomStrategy) *SLOAwareRouter { return &SLOAwareRouter{ tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType}, latencypredictor: latencypredictor, - runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue), + runningRequestLists: make(map[types.NamespacedName]*requestPriorityQueue), sloContextStore: sync.Map{}, headroomStrategy: strategy, } @@ -64,30 +64,20 @@ func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter { return s } -// SetHeadroomStrategy allows runtime configuration of headroom selection strategy -func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) { - s.headroomStrategy = strategy -} - -// GetHeadroomStrategy returns the current headroom selection strategy -func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy { - return s.headroomStrategy -} - func (s *SLOAwareRouter) epsilonGreedyAffinityGate( ctx context.Context, - candidates []PodPredictionResult, + candidates []podPredictionResult, r *rand.Rand, label string, // e.g. "positive" or "negative" prefixStickyThreshold float64, -) ([]PodPredictionResult, bool) { +) ([]podPredictionResult, bool) { logger := log.FromContext(ctx) if prefixStickyThreshold <= 0 { // Affinity gating disabled logger.V(logutil.DEBUG).Info("Affinity gating disabled (threshold <= 0)", "path", label) return candidates, false } - eligible := make([]PodPredictionResult, 0, len(candidates)) + eligible := make([]podPredictionResult, 0, len(candidates)) for _, p := range candidates { if p.PrefixCacheScore >= prefixStickyThreshold { eligible = append(eligible, p) @@ -132,10 +122,10 @@ func (s *SLOAwareRouter) scoreWithoutPredictions( } // Build prediction results with only prefix cache scores - podResults := make([]PodPredictionResult, 0, len(pods)) + podResults := make([]podPredictionResult, 0, len(pods)) for _, pod := range pods { prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - podResults = append(podResults, PodPredictionResult{ + podResults = append(podResults, podPredictionResult{ Pod: pod, PrefixCacheScore: prefixScore, IsValid: true, // All pods are valid when we don't check predictions @@ -143,7 +133,7 @@ func (s *SLOAwareRouter) scoreWithoutPredictions( } // Select based on composite scores (prefix cache + other non-prediction metrics) - selectedPod := s.selectFromCompositeScores(ctx, podResults, r, HeadroomStrategyCompositeOnly) + selectedPod := s.selectFromCompositeScores(ctx, podResults, r, headroomStrategyCompositeOnly) if selectedPod != nil { scores[selectedPod] = 1 @@ -165,22 +155,22 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle var err error // get request slos // Get Request SLOs from request header - sloCtx.TTFTSLO, _, err = parseFloatHeader(*request, TTFTSLOHeaderKey) + sloCtx.ttftSLO, _, err = parseFloatHeader(*request, ttftSLOHeaderKey) if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TTFTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", ttftSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT 
SLO from header") } - sloCtx.AvgTPOTSLO, _, err = parseFloatHeader(*request, TPOTSLOHeaderKey) + sloCtx.avgTPOTSLO, _, err = parseFloatHeader(*request, tpotSLOHeaderKey) if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", TPOTSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") } - sloCtx.PredictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") + sloCtx.predictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") if err != nil { logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") } // Check if SLOs are provided - if !sloCtx.PredictorBasedScheduling { + if !sloCtx.predictorBasedScheduling { logger.V(logutil.DEBUG).Info("PredictorBasedScheduling turned off, skipping prediction-based filtering") s.setSLOContextForRequest(request, sloCtx) return nil @@ -203,7 +193,7 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle } s.updateRequestContextWithPredictions(sloCtx, predictions) - allPreds := append([]PodPredictionResult(nil), predictions...) + allPreds := append([]podPredictionResult(nil), predictions...) allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal) // Check if all pods are invalid and all have running requests @@ -223,12 +213,12 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle // Set HasValidPod to false if all pods are invalid and all have running requests if allPodsInvalid && allPodsHaveRunningRequests && !sticky { - sloCtx.HasValidPod = false + sloCtx.hasValidPod = false logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false") } // 2) Tiered selection: positive headroom pods get 99% probability, negative get 1% - var posHeadroomPods, negHeadroomPods []PodPredictionResult + var posHeadroomPods, negHeadroomPods []podPredictionResult for _, p := range allPreds { // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom if p.Headroom > 0 && p.TTFTHeadroom > 0 { @@ -245,9 +235,9 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle var selectedPod schedulingtypes.Pod - if s.headroomStrategy == HeadroomStrategyCompositeOnly { + if s.headroomStrategy == headroomStrategyCompositeOnly { logger.V(logutil.DEBUG).Info("Selecting from composite scores only") - selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, HeadroomStrategyCompositeOnly) + selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly) } else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 { // 99% chance to select from positive headroom pods, 1% from negative if r.Float64() < EpsilonExploreNeg { @@ -286,10 +276,10 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle return scores } -func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext { +func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { sloCtx, err := 
t.getSLOContextForRequest(request) if err != nil { - sloCtx = NewSLORequestContext(request) + sloCtx = newSLORequestContext(request) } return sloCtx } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go index da073ff65a..762d62232f 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -141,7 +141,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { tests := []struct { name string predictor *mockPredictor - strategy HeadroomStrategy + strategy headroomStrategy request *schedulingtypes.LLMRequest pods []schedulingtypes.Pod expectedScores map[string]float64 // Map of pod name to expected score @@ -150,7 +150,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { { name: "Prediction-based scheduling disabled", predictor: &mockPredictor{}, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting @@ -161,7 +161,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { { name: "No predictor configured", predictor: nil, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), @@ -177,7 +177,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { "0.3": {TTFT: 0.4, TPOT: 0.02}, // 30% KV cache }, }, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), // 50% KV cache @@ -197,7 +197,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { "0.9": {TTFT: 1.8, TPOT: 0.09}, // 90% KV cache - very high load }, }, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load @@ -214,7 +214,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { "0.9": {TTFT: 1.5, TPOT: 0.08}, // 90% KV cache - Negative headroom }, }, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom @@ -228,7 +228,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { predictor: &mockPredictor{ err: fmt.Errorf("prediction failed"), }, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{ createTestPod("pod1", 0.5, 2, 1), @@ -242,7 +242,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { { name: "Empty pod list", predictor: &mockPredictor{}, - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), pods: []schedulingtypes.Pod{}, // Should return empty scores map @@ -298,27 +298,27 @@ func TestSLOAwareRouter_Score(t *testing.T) { func TestSLOAwareRouter_Strategies(t *testing.T) { tests := []struct { name string - strategy HeadroomStrategy + strategy headroomStrategy }{ { name: "HeadroomStrategyLeast", - strategy: HeadroomStrategyLeast, + strategy: headroomStrategyLeast, }, { name: 
"HeadroomStrategyMost", - strategy: HeadroomStrategyMost, + strategy: headroomStrategyMost, }, { name: "HeadroomStrategyCompositeMost", - strategy: HeadroomStrategyCompositeMost, + strategy: headroomStrategyCompositeMost, }, { name: "HeadroomStrategyCompositeLeast", - strategy: HeadroomStrategyCompositeLeast, + strategy: headroomStrategyCompositeLeast, }, { name: "HeadroomStrategyCompositeOnly", - strategy: HeadroomStrategyCompositeOnly, + strategy: headroomStrategyCompositeOnly, }, } @@ -356,22 +356,9 @@ func TestSLOAwareRouter_Strategies(t *testing.T) { } } -func TestSLOAwareRouter_SetHeadroomStrategy(t *testing.T) { - predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) - - assert.Equal(t, HeadroomStrategyLeast, router.GetHeadroomStrategy(), "Initial strategy should be Least") - - router.SetHeadroomStrategy(HeadroomStrategyMost) - assert.Equal(t, HeadroomStrategyMost, router.GetHeadroomStrategy(), "Strategy should be updated to Most") - - router.SetHeadroomStrategy(HeadroomStrategyCompositeOnly) - assert.Equal(t, HeadroomStrategyCompositeOnly, router.GetHeadroomStrategy(), "Strategy should be updated to CompositeOnly") -} - func TestSLOAwareRouter_TypedName(t *testing.T) { predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) tn := router.TypedName() assert.Equal(t, "slo-aware-routing", tn.Type, "Type should be slo-aware-routing") @@ -380,7 +367,7 @@ func TestSLOAwareRouter_TypedName(t *testing.T) { func TestSLOAwareRouter_WithName(t *testing.T) { predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) customName := "custom-router" router = router.WithName(customName) @@ -408,7 +395,7 @@ func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { Name: p.GetPod().NamespacedName.Name, Namespace: p.GetPod().NamespacedName.Namespace, } - r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName] = newRequestPriorityQueue() r.runningRequestLists[podName].Add("req1", 0.04) }, expectedCount: 1, @@ -420,7 +407,7 @@ func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { Name: p.GetPod().NamespacedName.Name, Namespace: p.GetPod().NamespacedName.Namespace, } - r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName] = newRequestPriorityQueue() r.runningRequestLists[podName].Add("req1", 0.04) r.runningRequestLists[podName].Add("req2", 0.03) r.runningRequestLists[podName].Add("req3", 0.05) @@ -432,7 +419,7 @@ func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) pod := createTestPod("test-pod", 0.5, 2, 1) tt.setupRequests(router, pod) @@ -461,7 +448,7 @@ func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { Name: p.GetPod().NamespacedName.Name, Namespace: p.GetPod().NamespacedName.Namespace, } - r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName] = newRequestPriorityQueue() r.runningRequestLists[podName].Add("req1", 0.04) }, expectedSLO: 0.04, @@ -473,7 +460,7 @@ func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { Name: p.GetPod().NamespacedName.Name, Namespace: 
p.GetPod().NamespacedName.Namespace, } - r.runningRequestLists[podName] = NewRequestPriorityQueue() + r.runningRequestLists[podName] = newRequestPriorityQueue() // Add in any order - heap will maintain minimum at top r.runningRequestLists[podName].Add("req1", 0.05) r.runningRequestLists[podName].Add("req2", 0.03) // This is the minimum @@ -486,7 +473,7 @@ func TestSLOAwareRouter_GetPodMinTPOTSLO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) pod := createTestPod("test-pod", 0.5, 2, 1) tt.setupRequests(router, pod) @@ -513,7 +500,7 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { predictor := &mockPredictor{} - router := NewSLOAwareRouter(predictor, HeadroomStrategyLeast) + router := NewSLOAwareRouter(predictor, headroomStrategyLeast) state := schedulingtypes.NewCycleState() tt.setupState(state) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go index eeab50433f..682473d13d 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -30,7 +30,7 @@ import ( // selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy // Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom. -func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { +func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { logger := log.FromContext(ctx) if len(posHeadroomPods) == 1 { @@ -45,10 +45,10 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos return candidates[0].Pod } switch s.headroomStrategy { - case HeadroomStrategyCompositeMost: - return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost) - case HeadroomStrategyCompositeLeast: - return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeLeast) + case headroomStrategyCompositeMost: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeMost) + case headroomStrategyCompositeLeast: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeLeast) } // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] @@ -90,7 +90,7 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) // Calculate weights for weighted random selection - weightedChoices := make([]Choice, 0, len(candidates)) + weightedChoices := make([]choice, 0, len(candidates)) total := 0 for _, p := range candidates { @@ -110,18 +110,18 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos // Map to integer weights var w int switch s.headroomStrategy { - case HeadroomStrategyLeast: + case headroomStrategyLeast: // prefer smaller combined headroom (pack closer to limits) - w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 - case HeadroomStrategyMost: + w = 
int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 + case headroomStrategyMost: // prefer larger combined headroom (more conservative / spread) - w = int(combined*float64(Wmax-minWeight)) + minWeight + 1 + w = int(combined*float64(wMax-minWeight)) + minWeight + 1 default: // Fallback to least - w = int((1.0-combined)*float64(Wmax-minWeight)) + minWeight + 1 + w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 } - weightedChoices = append(weightedChoices, Choice{PodName: p.Pod, Weight: w}) + weightedChoices = append(weightedChoices, choice{podName: p.Pod, weight: w}) total += w logger.V(logutil.TRACE).Info("Positive headroom blended weight", @@ -137,7 +137,7 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos // selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic // Modified to strictly prefer pods with 0 running requests -func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { +func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { logger := log.FromContext(ctx) if len(negHeadroomPods) == 1 { @@ -145,7 +145,7 @@ func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, neg } // First, separate pods by running request count - var zeroRunningRequestPods, nonZeroRunningRequestPods []PodPredictionResult + var zeroRunningRequestPods, nonZeroRunningRequestPods []podPredictionResult for _, p := range negHeadroomPods { runningRequestCount := s.getPodRunningRequestCount(p.Pod) @@ -172,7 +172,7 @@ func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, neg } // selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods -func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod { +func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { if len(negHeadroomPods) == 1 { return negHeadroomPods[0].Pod } @@ -186,14 +186,14 @@ func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Cont } switch s.headroomStrategy { - case HeadroomStrategyCompositeMost: - return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost) - case HeadroomStrategyCompositeLeast: - return s.selectFromCompositeScores(ctx, candidates, r, HeadroomStrategyCompositeMost) + case headroomStrategyCompositeMost: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeMost) + case headroomStrategyCompositeLeast: + return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeLeast) } // Build weighted choices for selection - weightedChoices := make([]Choice, 0, len(candidates)) + weightedChoices := make([]choice, 0, len(candidates)) total := 0 s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight) @@ -206,8 +206,8 @@ func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Cont // Lower blended deficit => higher weight. 
func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( ctx context.Context, - pods []PodPredictionResult, - choices *[]Choice, + pods []podPredictionResult, + choices *[]choice, total *int, minWeight int, alpha, beta float64, // weights for TTFT and TPOT deficits @@ -223,7 +223,7 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( // Compute raw deficits (only when headroom is negative) type deficits struct { - pod PodPredictionResult + pod podPredictionResult ttftDef float64 tpotDef float64 } @@ -293,7 +293,7 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( // Ensure a floor so no pod is completely excluded within the bucket. w := int((1.0-blended)*float64(Wrange)) + minWeight + 1 - *choices = append(*choices, Choice{PodName: d.pod.Pod, Weight: w}) + *choices = append(*choices, choice{podName: d.pod.Pod, weight: w}) *total += w logger.V(logutil.TRACE).Info("Negative bucket blended weighting", @@ -306,15 +306,15 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit( func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( ctx context.Context, - negHeadroomPods []PodPredictionResult, - choices *[]Choice, + negHeadroomPods []podPredictionResult, + choices *[]choice, total *int, minWeightForNegative int, ) { logger := log.FromContext(ctx) // Categorize pods by their headroom status - var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []PodPredictionResult + var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []podPredictionResult for _, p := range negHeadroomPods { if p.TTFTHeadroom < 0 && p.Headroom < 0 { @@ -355,7 +355,7 @@ func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( // Priority 4: edge-case bucket -> minimal weight for _, p := range nonNegTTFTNonNegTPOT { - *choices = append(*choices, Choice{PodName: p.Pod, Weight: minWeightForNegative}) + *choices = append(*choices, choice{podName: p.Pod, weight: minWeightForNegative}) *total += minWeightForNegative } } @@ -367,7 +367,7 @@ func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 { } if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 { if topReq := runningReqs.Peek(); topReq != nil { - return topReq.TPOT + return topReq.tpot } } return 0 // no running requests or no TPOT SLOs diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go index 8030866d80..03844543f3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go @@ -19,39 +19,39 @@ package slo_aware_router import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" -type HeadroomStrategy string +type headroomStrategy string -type Choice struct { - PodName schedulingtypes.Pod - Weight int +type choice struct { + podName schedulingtypes.Pod + weight int } const ( - // HeadroomStrategyLeast prioritizes pods with least positive headroom (better packing) - HeadroomStrategyLeast HeadroomStrategy = "least" - // HeadroomStrategyMost prioritizes pods with most positive headroom (more conservative) - HeadroomStrategyMost HeadroomStrategy = "most" + // headroomStrategyLeast prioritizes pods with least positive headroom (better packing) + headroomStrategyLeast headroomStrategy = "least" + // headroomStrategyMost prioritizes pods with most positive headroom (more conservative) + headroomStrategyMost headroomStrategy = "most" 
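To make the packing-versus-spreading trade-off above concrete before the composite variants that follow — a worked example (illustrative arithmetic, not part of the patch) of the weight mapping used in selection.go, with the wMax = 100 and minWeight = 1 constants defined just below:

// A pod whose blended, normalized headroom is combined = 0.25 receives:
//   least: w = int((1-0.25)*float64(100-1)) + 1 + 1 = 76 // near-limit pods favored (packing)
//   most:  w = int(0.25*float64(100-1)) + 1 + 1 = 26     // high-headroom pods favored (spreading)
// "least" therefore concentrates load onto nearly-full pods, while "most" spreads it out.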
-	HeadroomStrategyCompositeLeast HeadroomStrategy = "composite-least"
-	HeadroomStrategyCompositeMost  HeadroomStrategy = "composite-most"
-	HeadroomStrategyCompositeOnly  HeadroomStrategy = "composite-only"
+	headroomStrategyCompositeLeast headroomStrategy = "composite-least"
+	headroomStrategyCompositeMost  headroomStrategy = "composite-most"
+	headroomStrategyCompositeOnly  headroomStrategy = "composite-only"

 	// TTFT header string
-	TTFTSLOHeaderKey = "x-slo-ttft-ms"
+	ttftSLOHeaderKey = "x-slo-ttft-ms"
 	// TPOT header string
-	TPOTSLOHeaderKey = "x-slo-tpot-ms"
+	tpotSLOHeaderKey = "x-slo-tpot-ms"
 )

 const (
 	SLOAwareRouterPluginType = "slo-aware-routing"

 	eps       = 1e-9
-	Wmax      = 100
+	wMax      = 100
 	minWeight = 1
 )

-type PodSelectionMode string
+type podSelectionMode string

 const (
-	PodSelectionLinear PodSelectionMode = "linear" // weighted-random (current behavior)
-	PodSelectionMax    PodSelectionMode = "max"    // pick argmax weight
+	podSelectionLinear podSelectionMode = "linear" // weighted-random (current behavior)
+	podSelectionMax    podSelectionMode = "max"    // pick argmax weight
 )

From 17bd1f462c0624c2d0dca0cce961f901bc5463e9 Mon Sep 17 00:00:00 2001
From: BenjaminBraunDev
Date: Tue, 18 Nov 2025 03:02:34 +0000
Subject: [PATCH 5/6] Fix lints

---
 .../plugins/multi/slo_aware_router/headers.go | 13 +++---
 .../latencypredictor_helper.go                | 34 ++++++++-------
 .../multi/slo_aware_router/prediction.go      |  5 ---
 .../slo_aware_router/requestcontrol_hooks.go  |  5 ++-
 .../requestcontrol_hooks_test.go              | 43 +++++++++++--------
 .../plugins/multi/slo_aware_router/scorer.go  | 19 ++++----
 .../multi/slo_aware_router/scorer_test.go     |  6 ++-
 .../multi/slo_aware_router/selection.go       |  9 ++--
 .../profile/slo_aware_profile_handler.go      |  4 +-
 9 files changed, 72 insertions(+), 66 deletions(-)

diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
index 8574ec41b0..da0a86f200 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/headers.go
@@ -18,7 +18,6 @@ limitations under the License.
 package slo_aware_router

 import (
-	"fmt"
 	"strconv"

 	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -27,24 +26,24 @@ import (

 // parseFloatHeader retrieves a header by name, parses it as a float64,
 // and returns the value or an error if the header is missing or invalid.
-func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, bool, error) {
+func parseFloatHeader(request schedulingtypes.LLMRequest, headerName string) (float64, error) {
 	// 1. Get header value from the map
 	headerValue, ok := request.Headers[headerName]
 	if !ok {
-		return 0, false, nil // Header not found, return 0 and false
+		return 0, nil // Header not found; return 0 with no error
 	}

 	// 2. Parse the header value to a float64
 	parsedFloat, err := strconv.ParseFloat(headerValue, 64)
 	if err != nil {
-		return 0, false, errutil.Error{
+		return 0, errutil.Error{
 			Code: errutil.BadRequest,
-			Msg:  fmt.Sprintf("%s must be a float", headerName),
+			Msg:  headerName + " must be a float",
 		}
 	}

 	// 3.
Return the successfully parsed value
-	return parsedFloat, true, nil
+	return parsedFloat, nil
 }

 // parseBoolHeader retrieves a header by name, parses it as a bool,
@@ -61,7 +60,7 @@ func parseBoolHeader(request schedulingtypes.LLMRequest, headerName string) (boo
 	if err != nil {
 		return false, errutil.Error{
 			Code: errutil.BadRequest,
-			Msg:  fmt.Sprintf("%s must be a bool", headerName),
+			Msg:  headerName + " must be a bool",
 		}
 	}

diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
index 1a50847bed..c356bc87b7 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
@@ -19,6 +19,7 @@ package slo_aware_router

 import (
 	"context"
+	"errors"
 	"fmt"
 	"strings"
 	"time"
@@ -48,9 +49,9 @@ func refreshLastSeenMetrics(ctx context.Context, sloCtx *sloRequestContext) {
 }

 // getLatestMetricsForProfile retrieves the latest metrics for prediction from sloCtx.lastSeenMetrics.
-func getLatestMetricsForProfile(ctx context.Context, sloCtx *sloRequestContext) (*backendmetrics.MetricsState, error) {
+func getLatestMetricsForProfile(sloCtx *sloRequestContext) (*backendmetrics.MetricsState, error) {
 	if len(sloCtx.lastSeenMetrics) == 0 {
-		return nil, fmt.Errorf("no last seen metrics available for prediction")
+		return nil, errors.New("no last seen metrics available for prediction")
 	}

 	primaryProfileName := sloCtx.schedulingResult.PrimaryProfileName
@@ -69,11 +70,11 @@ func processHeaderForLatencyPrediction(
 ) error {
 	logger := log.FromContext(ctx)

-	//just for debugging, print the req context scheduling result cycle state
-	//print the raw scores in scheduling result
+	// just for debugging, print the req context scheduling result cycle state
+	// print the raw scores in scheduling result

 	// Build prediction request
-	m, err := getLatestMetricsForProfile(ctx, sloCtx)
+	m, err := getLatestMetricsForProfile(sloCtx)
 	if err != nil {
 		logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
 		return err
@@ -95,13 +96,14 @@ func processHeaderForLatencyPrediction(
 	start := time.Now()
 	p, err := predictor.Predict(ctx, in)
 	dur := time.Since(start)
-	if err != nil {
+	switch {
+	case err != nil:
 		logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds())
 		sloCtx.predictedTTFT = 0
-	} else if p == nil {
+	case p == nil:
 		logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds())
 		sloCtx.predictedTTFT = 0
-	} else {
+	default:
 		logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds())
 		metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds())
@@ -133,7 +135,7 @@ func processFirstTokenForLatencyPrediction(
 	// Actual TTFT
 	sloCtx.ttft = float64(now.Sub(sloCtx.requestReceivedTimestamp).Milliseconds())
 	sloCtx.generatedTokenCount = 1
-	m, err := getLatestMetricsForProfile(ctx, sloCtx)
+	m, err := getLatestMetricsForProfile(sloCtx)
 	if err != nil {
 		logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
 		return
 	}
@@ -156,7 +158,7 @@

 	if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
 		logger.V(logutil.DEBUG).Error(err, "record TTFT training
failed") } - m, err = getLatestMetricsForProfile(ctx, sloCtx) + m, err = getLatestMetricsForProfile(sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ -212,13 +214,13 @@ func processTokenForLatencyPrediction( latencyMs := float64(now.Sub(sloCtx.lastTokenTimestamp).Milliseconds()) sloCtx.generatedTokenCount++ - //log the inter-token latency for predicted samples - if sloCtx.generatedTokenCount == 2 || sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { //tricky logic, since next sample token is always +1 from current token + // log the inter-token latency for predicted samples + if sloCtx.generatedTokenCount == 2 || sloCtx.tokenSampler.shouldPredict(sloCtx.generatedTokenCount) { // tricky logic, since next sample token is always +1 from current token sloCtx.tpotObservations = append(sloCtx.tpotObservations, latencyMs) sloCtx.avgTPOT = calculateRunningAverage(sloCtx.avgTPOT, latencyMs, len(sloCtx.tpotObservations)) } - m, err := getLatestMetricsForProfile(ctx, sloCtx) + m, err := getLatestMetricsForProfile(sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ -285,7 +287,7 @@ func predictWithMetrics( logger := log.FromContext(ctx) if metricsState == nil { - return nil, fmt.Errorf("metrics state cannot be nil") + return nil, errors.New("metrics state cannot be nil") } // Build prediction request @@ -318,7 +320,7 @@ func predictWithMetrics( if result == nil { logger.V(logutil.DEBUG).Info("prediction returned nil", "duration_ms", duration.Milliseconds()) - return nil, fmt.Errorf("prediction returned nil result") + return nil, errors.New("prediction returned nil result") } logger.V(logutil.DEBUG).Info("prediction succeeded", @@ -392,7 +394,7 @@ func bulkPredictWithMetrics( if bulkResponse == nil { logger.V(logutil.DEBUG).Info("bulk prediction returned nil", "duration_ms", duration.Milliseconds()) - return nil, fmt.Errorf("bulk prediction returned nil result") + return nil, errors.New("bulk prediction returned nil result") } // Convert to pointer slice for consistency with single prediction diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index eaaf56608d..19f55d92f7 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -67,10 +67,6 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul predResult.TTFT = prediction.TTFT predResult.TPOT = prediction.TPOT podMinTPOTSLO := 0.0 - //if pod.GetPod().RunningRequests.Peek() != nil { - // podMinTPOTSLO = pod.GetPod().RunningRequests.Peek().TPOT - //} - // Do this: podMinTPOTSLO = s.getPodMinTPOTSLO(pod) predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO) @@ -122,7 +118,6 @@ func (s *SLOAwareRouter) validatePrediction( // a podMinTPOTSLO of 0 means no either no requests, or no TPOT SLOs specified on running requests if podMinTPOTSLO > 0 { if podMinTPOTSLO < sloCtx.avgTPOTSLO { - //print debug message log.FromContext(context.Background()).V(logutil.DEBUG).Info("Pod min TPOT SLO is less than the req SLO, adjusting", "podMinTPOTSLO", podMinTPOTSLO, "bufferedTPOT", sloCtx.avgTPOTSLO) } bufferedTPOT = min(bufferedTPOT, 
podMinTPOTSLO*SLOBufferFactor) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index 92ef6042e0..d505b11e38 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -18,6 +18,7 @@ package slo_aware_router import ( "context" + "errors" "fmt" "time" @@ -65,7 +66,7 @@ type sloRequestContext struct { // predictorBasedScheduling indicates whether to use predictor based scheduling. predictorBasedScheduling bool - //predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. + // predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. predictedTTFTForScheduling map[string]float64 // predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. predictedTPOTForScheduling map[string]float64 @@ -124,7 +125,7 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) if request.Headers[requtil.RequestIdHeaderKey] == "" { - logger.V(logutil.DEBUG).Error(fmt.Errorf("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") + logger.V(logutil.DEBUG).Error(errors.New("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") } id := request.Headers[requtil.RequestIdHeaderKey] diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index 3b297b2bbd..5aaf1a2a24 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go @@ -35,9 +35,16 @@ import ( requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" ) +const ( + testModelName = "test-model" + kvUsage = 1 + runningQueue = 1 + waitingQueue = 1 +) + // Helper functions -func createTestSchedulingResult(pod *backend.Pod, kvUsage float64, runningQueue int, waitingQueue int) *schedulingtypes.SchedulingResult { +func createTestSchedulingResult(pod *backend.Pod) *schedulingtypes.SchedulingResult { mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningQueue, waitingQueue) @@ -156,7 +163,7 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO context sloCtx := newSLORequestContext(request) @@ -189,7 +196,7 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO context sloCtx := newSLORequestContext(request) @@ -214,7 +221,7 @@ func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { 
pod := createTestPod("test-pod", 1, 1, 1) request1 := createTestLLMRequest("test-id-1", 100, 50, true) request2 := createTestLLMRequest("test-id-2", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set initial SLO contexts sloCtx1 := newSLORequestContext(request1) @@ -321,7 +328,7 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) sloCtx := newSLORequestContext(request) sloCtx.requestReceivedTimestamp = time.Now() @@ -329,7 +336,7 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { sloCtx.schedulingRequest = *request sloCtx.ttftSLO = 100 sloCtx.avgTPOTSLO = 50 - sloCtx.incomingModelName = "test-model" + sloCtx.incomingModelName = testModelName sloCtx.predictedTTFT = 80.0 sloCtx.avgPredictedTPOT = 30.0 // ADD THIS - populate metrics @@ -372,7 +379,7 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) sloCtx := newSLORequestContext(request) sloCtx.requestReceivedTimestamp = time.Now() @@ -380,7 +387,7 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { sloCtx.schedulingRequest = *request sloCtx.ttftSLO = 100 sloCtx.avgTPOTSLO = 50 - sloCtx.incomingModelName = "test-model" + sloCtx.incomingModelName = testModelName sloCtx.predictedTTFT = 80.0 sloCtx.avgPredictedTPOT = 30.0 // ADD THIS - populate metrics @@ -422,7 +429,7 @@ func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) - sloCtx.incomingModelName = "test-model" + sloCtx.incomingModelName = testModelName sloCtx.targetPod = pod.GetPod() // ADD THIS to avoid other issues router.setSLOContextForRequest(request, sloCtx) @@ -596,7 +603,7 @@ func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { sloCtx := newSLORequestContext(request) sloCtx.ttft = 80 sloCtx.avgTPOT = 30 - sloCtx.incomingModelName = "test-model" + sloCtx.incomingModelName = testModelName router.setSLOContextForRequest(request, sloCtx) // Should handle missing SLOs gracefully @@ -713,7 +720,7 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { for i := 0; i < numGoroutines; i++ { wg.Add(1) - go func(id int) { + go func() { defer wg.Done() requestID := uuid.New().String() @@ -730,7 +737,7 @@ func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) { // Delete context router.deleteSLOContextForRequest(request) - }(i) + }() } wg.Wait() @@ -748,7 +755,7 @@ func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) { request2 := createTestLLMRequest("test-id-2", 100, 50, true) request3 := createTestLLMRequest("test-id-3", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create and set SLO contexts for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} { @@ -777,12 +784,12 @@ func 
TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) { pod := createTestPod("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) // Create initial context sloCtx := newSLORequestContext(request) sloCtx.avgTPOTSLO = 50 - sloCtx.incomingModelName = "test-model" + sloCtx.incomingModelName = testModelName router.setSLOContextForRequest(request, sloCtx) // 1. PreRequest @@ -830,8 +837,8 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) { request1 := createTestLLMRequest("test-id-1", 100, 50, true) request2 := createTestLLMRequest("test-id-2", 100, 50, true) - schedulingResult1 := createTestSchedulingResult(pod1.GetPod(), 1, 1, 1) - schedulingResult2 := createTestSchedulingResult(pod2.GetPod(), 1, 1, 1) + schedulingResult1 := createTestSchedulingResult(pod1.GetPod()) + schedulingResult2 := createTestSchedulingResult(pod2.GetPod()) // Create and set SLO contexts sloCtx1 := newSLORequestContext(request1) @@ -909,7 +916,7 @@ func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) { router := createTestRouter() ctx := context.Background() pod := createTestPod("test-pod", 1, 1, 1) - schedulingResult := createTestSchedulingResult(pod.GetPod(), 1, 1, 1) + schedulingResult := createTestSchedulingResult(pod.GetPod()) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index 34cf36958e..f69006d5ab 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -155,12 +155,12 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle var err error // get request slos // Get Request SLOs from request header - sloCtx.ttftSLO, _, err = parseFloatHeader(*request, ttftSLOHeaderKey) + sloCtx.ttftSLO, err = parseFloatHeader(*request, ttftSLOHeaderKey) if err != nil { logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", ttftSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") } - sloCtx.avgTPOTSLO, _, err = parseFloatHeader(*request, tpotSLOHeaderKey) + sloCtx.avgTPOTSLO, err = parseFloatHeader(*request, tpotSLOHeaderKey) if err != nil { logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") } @@ -235,10 +235,11 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle var selectedPod schedulingtypes.Pod - if s.headroomStrategy == headroomStrategyCompositeOnly { + switch { + case s.headroomStrategy == headroomStrategyCompositeOnly: logger.V(logutil.DEBUG).Info("Selecting from composite scores only") selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly) - } else if len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0 { + case len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0: // 99% chance to select from positive headroom pods, 1% from negative if r.Float64() < EpsilonExploreNeg { logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") @@ -247,19 +248,19 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state 
*schedulingtypes.Cycle logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) } - } else if len(posHeadroomPods) > 0 { + case len(posHeadroomPods) > 0: // If only positive headroom pods exist, select from them logger.V(logutil.DEBUG).Info("Only positive headroom pods available") selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) - } else if len(negHeadroomPods) > 0 { + case len(negHeadroomPods) > 0: // If only negative headroom pods exist, select from them logger.V(logutil.DEBUG).Info("Only negative headroom pods available") selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) - } else if len(allPreds) > 0 { + case len(allPreds) > 0: // fallback - select randomly from valid pods logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") selectedPod = allPreds[r.Intn(len(allPreds))].Pod - } else { + default: // No valid pods - return all zeros logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores") return scores @@ -295,7 +296,7 @@ func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleSta if err != nil { // The prefix cache plugin might not be enabled, which is a valid scenario. - log.FromContext(ctx).V(logutil.DEBUG).Info("prefix cache state not found in cycle state, returning prefix cache score of 0.0: %v", err, "pod", pod.GetPod().String()) + log.FromContext(ctx).V(logutil.DEBUG).Info("prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String()) return 0.0 } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go index 762d62232f..a15cb29ac4 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go @@ -18,7 +18,9 @@ package slo_aware_router import ( "context" + "errors" "fmt" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -125,7 +127,7 @@ func createTestLLMRequest(reqID string, ttftSLO, tpotSLO float64, predictionBase if tpotSLO > 0 { headers["x-avg-tpot-slo"] = fmt.Sprintf("%f", tpotSLO) } - headers["x-prediction-based-scheduling"] = fmt.Sprintf("%t", predictionBased) + headers["x-prediction-based-scheduling"] = strconv.FormatBool(predictionBased) return &schedulingtypes.LLMRequest{ Headers: headers, @@ -226,7 +228,7 @@ func TestSLOAwareRouter_Score(t *testing.T) { { name: "Prediction errors - fallback to composite scoring", predictor: &mockPredictor{ - err: fmt.Errorf("prediction failed"), + err: errors.New("prediction failed"), }, strategy: headroomStrategyLeast, request: createTestLLMRequest("test", 1.0, 0.05, true), diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go index 682473d13d..c5278e0bcb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -317,13 +317,14 @@ func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical( var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []podPredictionResult for _, p := range negHeadroomPods { - if p.TTFTHeadroom < 0 && p.Headroom < 0 { + switch { + case p.TTFTHeadroom < 
0 && p.Headroom < 0: negTTFTNegTPOT = append(negTTFTNegTPOT, p) - } else if p.TTFTHeadroom < 0 && p.Headroom >= 0 { + case p.TTFTHeadroom < 0 && p.Headroom >= 0: negTTFTNonNegTPOT = append(negTTFTNonNegTPOT, p) - } else if p.TTFTHeadroom >= 0 && p.Headroom < 0 { + case p.TTFTHeadroom >= 0 && p.Headroom < 0: nonNegTTFTNegTPOT = append(nonNegTTFTNegTPOT, p) - } else { + default: nonNegTTFTNonNegTPOT = append(nonNegTTFTNonNegTPOT, p) } } diff --git a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go index 900335c9ef..c4bd638e65 100644 --- a/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go +++ b/pkg/epp/scheduling/framework/plugins/profile/slo_aware_profile_handler.go @@ -57,9 +57,7 @@ func NewSLOAwareProfileHandler() *SLOAwareProfileHandler { // When the request has PredictorBasedScheduling=true, it uses the SLO profile result to select // the destination pod. Otherwise, it uses the default profile result. type SLOAwareProfileHandler struct { - typedName plugins.TypedName - prefixProfile string // the profile that should be executed first - + typedName plugins.TypedName } // TypedName returns the type and name tuple of this plugin instance. From 177b93284f20ec9dab19612a377f1a8b29d68215 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Tue, 18 Nov 2025 22:40:19 +0000 Subject: [PATCH 6/6] Break out larger predictor functions into helpers, switch to using bulk prediction, add bulk prediction tests --- .../latencypredictor_helper.go | 116 ++++++------------ .../latencypredictor_helper_test.go | 100 +++++++++++++++ .../multi/slo_aware_router/prediction.go | 43 ++++--- .../plugins/multi/slo_aware_router/sampler.go | 25 ---- .../plugins/multi/slo_aware_router/scorer.go | 62 +--------- .../multi/slo_aware_router/scorer_helpers.go | 102 +++++++++++++++ .../multi/slo_aware_router/selection.go | 80 +----------- .../slo_aware_router/selection_helpers.go | 114 +++++++++++++++++ 8 files changed, 390 insertions(+), 252 deletions(-) create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go create mode 100644 pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go index c356bc87b7..7482de93c7 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -125,12 +125,7 @@ func processFirstTokenForLatencyPrediction( ) { logger := log.FromContext(ctx) - // Initialize sampler - if sloCtx.tokenSampler == nil { - requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] - sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) - logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) - } + initializeSampler(ctx, sloCtx) // Actual TTFT sloCtx.ttft = float64(now.Sub(sloCtx.requestReceivedTimestamp).Milliseconds()) @@ -141,8 +136,36 @@ func processFirstTokenForLatencyPrediction( return } targetPod := sloCtx.targetPod - prefix_cache_score := 
sloCtx.prefixCacheScoresForPods[targetPod.String()] + prefixCacheScore := sloCtx.prefixCacheScoresForPods[targetPod.String()] + + recordTTFTTrainingData(ctx, predictor, sloCtx, m, now, prefixCacheScore) + predictFirstTPOT(ctx, predictor, sloCtx) + + // Advance timestamp + sloCtx.lastTokenTimestamp = now + // Refresh metrics + refreshLastSeenMetrics(ctx, sloCtx) +} + +func initializeSampler(ctx context.Context, sloCtx *sloRequestContext) { + if sloCtx.tokenSampler == nil { + logger := log.FromContext(ctx) + requestID := sloCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey] + sloCtx.tokenSampler = newTokenSampler(requestID, DefaultSamplingMean, MaxSampledTokens) + logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.tokenSampler.getNextSampleToken()) + } +} + +func recordTTFTTrainingData( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, + m *backendmetrics.MetricsState, + now time.Time, + prefixCacheScore float64, +) { + logger := log.FromContext(ctx) // Train TTFT entry := latencypredictor.TrainingEntry{ KVCachePercentage: m.KVCacheUsagePercent, @@ -153,12 +176,20 @@ func processFirstTokenForLatencyPrediction( NumRequestWaiting: m.WaitingQueueSize, NumRequestRunning: m.RunningQueueSize, NumTokensGenerated: 0, - PrefixCacheScore: prefix_cache_score, + PrefixCacheScore: prefixCacheScore, } if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil { logger.V(logutil.DEBUG).Error(err, "record TTFT training failed") } - m, err = getLatestMetricsForProfile(sloCtx) +} + +func predictFirstTPOT( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + sloCtx *sloRequestContext, +) { + logger := log.FromContext(ctx) + m, err := getLatestMetricsForProfile(sloCtx) if err != nil { logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics", "error", err) @@ -187,11 +218,6 @@ func processFirstTokenForLatencyPrediction( sloCtx.avgPredictedTPOT = calculateRunningAverage(sloCtx.avgPredictedTPOT, p.TPOT, len(sloCtx.predictedTPOTObservations)) } metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.schedulingRequest.TargetModel, sloCtx.incomingModelName, dur.Seconds()) - - // Advance timestamp - sloCtx.lastTokenTimestamp = now - // Refresh metrics - refreshLastSeenMetrics(ctx, sloCtx) } // ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances timestamp. @@ -275,68 +301,6 @@ func processTokenForLatencyPrediction( refreshLastSeenMetrics(ctx, sloCtx) } -// predictWithMetrics predicts TTFT or TPOT based on provided metrics state and token count. 
-func predictWithMetrics( - ctx context.Context, - predictor latencypredictor.PredictorInterface, - metricsState *backendmetrics.MetricsState, - prompt string, - generatedTokenCount int, - prefixcachescore float64, -) (*latencypredictor.PredictionResponse, error) { - logger := log.FromContext(ctx) - - if metricsState == nil { - return nil, errors.New("metrics state cannot be nil") - } - - // Build prediction request - in := latencypredictor.PredictionRequest{ - KVCachePercentage: metricsState.KVCacheUsagePercent, - InputTokenLength: len(strings.Fields(prompt)), - NumRequestWaiting: metricsState.WaitingQueueSize, - NumRequestRunning: metricsState.RunningQueueSize, - NumTokensGenerated: generatedTokenCount, - PrefixCacheScore: prefixcachescore, - } - - // Perform prediction - start := time.Now() - result, err := predictor.Predict(ctx, in) - duration := time.Since(start) - - if err != nil { - logger.V(logutil.DEBUG).Error(err, "prediction failed", - "duration_ms", duration.Milliseconds(), - "input_tokens", in.InputTokenLength, - "generated_tokens", generatedTokenCount, - "kv_cache_percent", in.KVCachePercentage, - "waiting_queue", in.NumRequestWaiting, - "running_queue", in.NumRequestRunning, - "prefix_cache_score", in.PrefixCacheScore) - return nil, err - } - - if result == nil { - logger.V(logutil.DEBUG).Info("prediction returned nil", - "duration_ms", duration.Milliseconds()) - return nil, errors.New("prediction returned nil result") - } - - logger.V(logutil.DEBUG).Info("prediction succeeded", - "tpot_ms", result.TPOT, - "ttft_ms", result.TTFT, - "duration_ms", duration.Milliseconds(), - "input_tokens", in.InputTokenLength, - "generated_tokens", generatedTokenCount, - "kv_cache_percent", in.KVCachePercentage, - "waiting_queue", in.NumRequestWaiting, - "running_queue", in.NumRequestRunning, - "prefix_cache_score", in.PrefixCacheScore) - - return result, nil -} - // bulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. // Returns predictions in the same order as the input slices. func bulkPredictWithMetrics( diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go new file mode 100644 index 0000000000..92227cba62 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper_test.go @@ -0,0 +1,100 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" +) + +func TestBulkPredictWithMetrics(t *testing.T) { + mockPredictor := &mockPredictor{ + predictions: map[string]*latencypredictor.PredictionResponse{ + "0.5": {TTFT: 0.5, TPOT: 0.03}, + "0.6": {TTFT: 0.6, TPOT: 0.04}, + }, + } + + metricsStates := []*backendmetrics.MetricsState{ + {KVCacheUsagePercent: 0.5}, + {KVCacheUsagePercent: 0.6}, + } + prompts := []string{"prompt1", "prompt2"} + generatedTokenCounts := []int{1, 1} + prefixCacheScores := []float64{0.0, 0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, 0.5, results[0].TTFT) + assert.Equal(t, 0.03, results[0].TPOT) + assert.Equal(t, 0.6, results[1].TTFT) + assert.Equal(t, 0.04, results[1].TPOT) +} + +func TestBulkPredictWithMetrics_Error(t *testing.T) { + mockPredictor := &mockPredictor{ + err: errors.New("prediction failed"), + } + + metricsStates := []*backendmetrics.MetricsState{ + {KVCacheUsagePercent: 0.5}, + } + prompts := []string{"prompt1"} + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) +} + +func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) { + mockPredictor := &mockPredictor{} + metricsStates := []*backendmetrics.MetricsState{{}} + prompts := []string{"prompt1", "prompt2"} // Mismatch length + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) + assert.True(t, strings.Contains(err.Error(), "input slice lengths must match")) +} + +func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) { + mockPredictor := &mockPredictor{} + metricsStates := []*backendmetrics.MetricsState{nil} // Nil metrics state + prompts := []string{"prompt1"} + generatedTokenCounts := []int{1} + prefixCacheScores := []float64{0.0} + + results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + + assert.Error(t, err) + assert.Nil(t, results) + assert.True(t, strings.Contains(err.Error(), "metrics state at index 0 cannot be nil")) +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index 19f55d92f7..7d41de95c9 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -21,6 +21,7 @@ import ( "context" "sigs.k8s.io/controller-runtime/pkg/log" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" latencypredictor 
"sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" @@ -44,35 +45,49 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul logger := log.FromContext(ctx) predictions := make([]podPredictionResult, 0, len(candidatePods)) - for _, pod := range candidatePods { - predResult := podPredictionResult{Pod: pod} + // Prepare inputs for bulk prediction + metricsStates := make([]*backendmetrics.MetricsState, len(candidatePods)) + prompts := make([]string, len(candidatePods)) + generatedTokenCounts := make([]int, len(candidatePods)) + prefixCacheScores := make([]float64, len(candidatePods)) + for i, pod := range candidatePods { logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) // Get prefix cache score for the pod prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) - // Generate prediction - prediction, err := predictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore) - if err != nil { - logger.V(logutil.DEBUG).Error(err, "Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err) - predResult.Error = err - return nil, err - } - predResult.PrefixCacheScore = prefixCacheScore + metricsStates[i] = pod.GetMetrics() + prompts[i] = request.Body.Completions.Prompt + generatedTokenCounts[i] = 1 + prefixCacheScores[i] = prefixCacheScore + } + + // Bulk predict + bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores) + if err != nil { + logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed") + return nil, err + } + + // Process results + for i, pod := range candidatePods { + prediction := bulkPredictions[i] + predResult := podPredictionResult{Pod: pod} + + predResult.PrefixCacheScore = prefixCacheScores[i] predResult.TTFT = prediction.TTFT predResult.TPOT = prediction.TPOT - podMinTPOTSLO := 0.0 - podMinTPOTSLO = s.getPodMinTPOTSLO(pod) + + podMinTPOTSLO := s.getPodMinTPOTSLO(pod) predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO) logger.V(logutil.DEBUG).Info("Prediction for scheduling", "pod", pod.GetPod().String(), - "prefixCacheScore", prefixCacheScore, + "prefixCacheScore", predResult.PrefixCacheScore, "TTFT", prediction.TTFT, "TPOT", prediction.TPOT, "buffer", SLOBufferFactor, diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go index cd021d36a5..13d2543a60 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/sampler.go @@ -33,21 +33,6 @@ type tokenSampler struct { sampleCount int } -// setSamplingMean sets the sampling mean (lambda) for the Poisson distribution -func (ts *tokenSampler) setSamplingMean(mean float64) { - ts.samplingMean = mean -} - -// setMaxSamples sets the maximum number of samples -func (ts *tokenSampler) setMaxSamples(max int) { - ts.maxSamples = max -} - -// setSampleCount sets the current number of predictions made -func (ts *tokenSampler) setSampleCount(count int) 
{ - ts.sampleCount = count -} - func newTokenSampler(requestID string, samplingMean float64, maxSamples int) *tokenSampler { // Use request ID hash as seed for reproducibility seed := int64(0) @@ -124,13 +109,3 @@ func (ts *tokenSampler) recordPrediction(currentToken int) { func (ts *tokenSampler) getNextSampleToken() int { return ts.nextSampleToken } - -// setNextSampleToken sets the next token to predict for -func (ts *tokenSampler) setNextSampleToken(token int) { - ts.nextSampleToken = token -} - -// getSampleCount returns the current number of predictions made -func (ts *tokenSampler) getSampleCount() int { - return ts.sampleCount -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go index f69006d5ab..8d05418e43 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go @@ -30,7 +30,6 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) @@ -152,22 +151,7 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle sloCtx := s.getOrMakeSLORequestContext(request) - var err error - // get request slos - // Get Request SLOs from request header - sloCtx.ttftSLO, err = parseFloatHeader(*request, ttftSLOHeaderKey) - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", ttftSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") - } - - sloCtx.avgTPOTSLO, err = parseFloatHeader(*request, tpotSLOHeaderKey) - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") - } - sloCtx.predictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") - if err != nil { - logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") - } + s.parseSLOHeaders(ctx, request, sloCtx) // Check if SLOs are provided if !sloCtx.predictorBasedScheduling { @@ -218,53 +202,13 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle } // 2) Tiered selection: positive headroom pods get 99% probability, negative get 1% - var posHeadroomPods, negHeadroomPods []podPredictionResult - for _, p := range allPreds { - // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom - if p.Headroom > 0 && p.TTFTHeadroom > 0 { - posHeadroomPods = append(posHeadroomPods, p) - } else { - // A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom - negHeadroomPods = append(negHeadroomPods, p) - } - } + posHeadroomPods, negHeadroomPods := s.classifyPodsByHeadroom(allPreds) logger.V(logutil.DEBUG).Info("Pod headroom distribution", "positivePods", len(posHeadroomPods), 
"negativePods", len(negHeadroomPods)) - var selectedPod schedulingtypes.Pod - - switch { - case s.headroomStrategy == headroomStrategyCompositeOnly: - logger.V(logutil.DEBUG).Info("Selecting from composite scores only") - selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly) - case len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0: - // 99% chance to select from positive headroom pods, 1% from negative - if r.Float64() < EpsilonExploreNeg { - logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") - selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) - } else { - logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") - selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) - } - case len(posHeadroomPods) > 0: - // If only positive headroom pods exist, select from them - logger.V(logutil.DEBUG).Info("Only positive headroom pods available") - selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) - case len(negHeadroomPods) > 0: - // If only negative headroom pods exist, select from them - logger.V(logutil.DEBUG).Info("Only negative headroom pods available") - selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) - case len(allPreds) > 0: - // fallback - select randomly from valid pods - logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") - selectedPod = allPreds[r.Intn(len(allPreds))].Pod - default: - // No valid pods - return all zeros - logger.V(logutil.DEBUG).Info("No valid pods available, returning all zero scores") - return scores - } + selectedPod := s.selectPodBasedOnStrategy(ctx, r, allPreds, posHeadroomPods, negHeadroomPods) // Set score = 1 for selected pod, 0 for all others if selectedPod != nil { diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go new file mode 100644 index 0000000000..5b3fea8887 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go @@ -0,0 +1,102 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package slo_aware_router + +import ( + "context" + "fmt" + "math/rand" + + "sigs.k8s.io/controller-runtime/pkg/log" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) parseSLOHeaders(ctx context.Context, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext) { + logger := log.FromContext(ctx) + var err error + + // Get Request SLOs from request header + sloCtx.ttftSLO, err = parseFloatHeader(*request, ttftSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", ttftSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TTFT SLO from header") + } + + sloCtx.avgTPOTSLO, err = parseFloatHeader(*request, tpotSLOHeaderKey) + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("%v must be a float: %v", tpotSLOHeaderKey, err)}, "SLOAwareRouter: Error parsing TPOT SLO from header") + } + sloCtx.predictorBasedScheduling, err = parseBoolHeader(*request, "x-prediction-based-scheduling") + if err != nil { + logger.V(logutil.DEBUG).Error(errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-prediction-based-scheduling must be a bool: %v", err)}, "SLOAwareRouter: Error parsing PredictorBasedScheduling from header") + } +} + +func (s *SLOAwareRouter) classifyPodsByHeadroom(allPreds []podPredictionResult) (posHeadroomPods, negHeadroomPods []podPredictionResult) { + for _, p := range allPreds { + // A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom + if p.Headroom > 0 && p.TTFTHeadroom > 0 { + posHeadroomPods = append(posHeadroomPods, p) + } else { + // A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom + negHeadroomPods = append(negHeadroomPods, p) + } + } + return +} + +func (s *SLOAwareRouter) selectPodBasedOnStrategy( + ctx context.Context, + r *rand.Rand, + allPreds, posHeadroomPods, negHeadroomPods []podPredictionResult, +) schedulingtypes.Pod { + logger := log.FromContext(ctx) + var selectedPod schedulingtypes.Pod + + switch { + case s.headroomStrategy == headroomStrategyCompositeOnly: + logger.V(logutil.DEBUG).Info("Selecting from composite scores only") + selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly) + case len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0: + // 99% chance to select from positive headroom pods, 1% from negative + if r.Float64() < EpsilonExploreNeg { + logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + } else { + logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + } + case len(posHeadroomPods) > 0: + // If only positive headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only positive headroom pods available") + selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r) + case len(negHeadroomPods) > 0: + // If only negative headroom pods exist, select from them + logger.V(logutil.DEBUG).Info("Only negative headroom pods available") + selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r) + case len(allPreds) > 0: + // fallback - select randomly 
from valid pods + logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods") + selectedPod = allPreds[r.Intn(len(allPreds))].Pod + default: + // No valid pods - return nil (caller handles this) + logger.V(logutil.DEBUG).Info("No valid pods available") + return nil + } + return selectedPod +} diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go index c5278e0bcb..02a99cc6e8 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go @@ -31,7 +31,6 @@ import ( // selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy // Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom. func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { - logger := log.FromContext(ctx) if len(posHeadroomPods) == 1 { return posHeadroomPods[0].Pod @@ -52,87 +51,12 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos } // Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1] - minTPOTH, maxTPOTH := math.MaxFloat64, -math.MaxFloat64 - minTTFTH, maxTTFTH := math.MaxFloat64, -math.MaxFloat64 - - for _, p := range candidates { - if p.Headroom < minTPOTH { - minTPOTH = p.Headroom - } - if p.Headroom > maxTPOTH { - maxTPOTH = p.Headroom - } - if p.TTFTHeadroom < minTTFTH { - minTTFTH = p.TTFTHeadroom - } - if p.TTFTHeadroom > maxTTFTH { - maxTTFTH = p.TTFTHeadroom - } - } - - tpotRange := maxTPOTH - minTPOTH - ttftRange := maxTTFTH - minTTFTH - - // Precompute blend weights (renormalize if user sets both to 0) - alpha := HeadroomTTFTWeight - beta := HeadroomTPOTWeight - if alpha+beta <= 0 { - alpha = 1.0 - beta = 0.0 - } - sum := alpha + beta - alpha /= sum - beta /= sum - - logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", - "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, - "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, - "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) + minTPOTH, maxTPOTH, minTTFTH, maxTTFTH := s.calculateHeadroomRanges(candidates) // Calculate weights for weighted random selection - weightedChoices := make([]choice, 0, len(candidates)) - total := 0 - - for _, p := range candidates { - // Normalize to [0,1] within the cohort - nTPOTH := 0.5 - if tpotRange > eps { - nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) - } - nTTFTH := 0.5 - if ttftRange > eps { - nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) - } - - // Blend: larger combined -> "safer"; smaller -> "tighter packing" - combined := alpha*nTTFTH + beta*nTPOTH - - // Map to integer weights - var w int - switch s.headroomStrategy { - case headroomStrategyLeast: - // prefer smaller combined headroom (pack closer to limits) - w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 - case headroomStrategyMost: - // prefer larger combined headroom (more conservative / spread) - w = int(combined*float64(wMax-minWeight)) + minWeight + 1 - default: - // Fallback to least - w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 - } - - weightedChoices = append(weightedChoices, choice{podName: p.Pod, weight: w}) - total += w - - logger.V(logutil.TRACE).Info("Positive headroom blended weight", - 
"pod", p.Pod.GetPod().String(), - "ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH, - "tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH, - "combined", combined, "weight", w) - } + weightedChoices, total := s.calculateWeightedChoices(ctx, candidates, minTPOTH, maxTPOTH, minTTFTH, maxTTFTH) return s.performWeightedRandomSelection(weightedChoices, total, candidates, r) - } // selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go new file mode 100644 index 0000000000..cdf3d965a1 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go @@ -0,0 +1,114 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package slo_aware_router + +import ( + "context" + "math" + + "sigs.k8s.io/controller-runtime/pkg/log" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func (s *SLOAwareRouter) calculateHeadroomRanges(candidates []podPredictionResult) (minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64) { + minTPOTH, maxTPOTH = math.MaxFloat64, -math.MaxFloat64 + minTTFTH, maxTTFTH = math.MaxFloat64, -math.MaxFloat64 + + for _, p := range candidates { + if p.Headroom < minTPOTH { + minTPOTH = p.Headroom + } + if p.Headroom > maxTPOTH { + maxTPOTH = p.Headroom + } + if p.TTFTHeadroom < minTTFTH { + minTTFTH = p.TTFTHeadroom + } + if p.TTFTHeadroom > maxTTFTH { + maxTTFTH = p.TTFTHeadroom + } + } + return +} + +func (s *SLOAwareRouter) calculateWeightedChoices( + ctx context.Context, + candidates []podPredictionResult, + minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64, +) ([]choice, int) { + logger := log.FromContext(ctx) + tpotRange := maxTPOTH - minTPOTH + ttftRange := maxTTFTH - minTTFTH + + // Precompute blend weights (renormalize if user sets both to 0) + alpha := HeadroomTTFTWeight + beta := HeadroomTPOTWeight + if alpha+beta <= 0 { + alpha = 1.0 + beta = 0.0 + } + sum := alpha + beta + alpha /= sum + beta /= sum + + logger.V(logutil.DEBUG).Info("Positive headroom normalization ranges", + "minTPOTHeadroom", minTPOTH, "maxTPOTHeadroom", maxTPOTH, + "minTTFTHeadroom", minTTFTH, "maxTTFTHeadroom", maxTTFTH, + "alphaTTFT", alpha, "betaTPOT", beta, "strategy", s.headroomStrategy) + + weightedChoices := make([]choice, 0, len(candidates)) + total := 0 + + for _, p := range candidates { + // Normalize to [0,1] within the cohort + nTPOTH := 0.5 + if tpotRange > eps { + nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps) + } + nTTFTH := 0.5 + if ttftRange > eps { + nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps) + } + + // Blend: larger combined -> "safer"; smaller -> "tighter packing" + combined := alpha*nTTFTH + beta*nTPOTH + + // Map to integer weights + var w int + switch s.headroomStrategy { + case headroomStrategyLeast: + // prefer smaller combined headroom (pack 
closer to limits) + w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 + case headroomStrategyMost: + // prefer larger combined headroom (more conservative / spread) + w = int(combined*float64(wMax-minWeight)) + minWeight + 1 + default: + // Fallback to least + w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1 + } + + weightedChoices = append(weightedChoices, choice{podName: p.Pod, weight: w}) + total += w + + logger.V(logutil.TRACE).Info("Positive headroom blended weight", + "pod", p.Pod.GetPod().String(), + "ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH, + "tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH, + "combined", combined, "weight", w) + } + return weightedChoices, total +}
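
The weighting logic that PATCH 6 moves into calculateWeightedChoices can be exercised on its own: normalize each candidate's TTFT and TPOT headroom to [0,1] within the cohort, blend the two with the alpha/beta weights, then map the blend onto integer weights so that, under the default "least" strategy, pods packed closer to their SLO limits draw more probability mass. Below is a minimal standalone sketch of that math together with the weighted-random draw that performWeightedRandomSelection is assumed to perform (its body is outside this diff); the pod names and headroom values are illustrative stand-ins, and the 0.8/0.2 blend mirrors the router's default TTFT/TPOT weights.

package main

import (
	"fmt"
	"math/rand"
)

const (
	eps       = 1e-9
	wMax      = 100
	minWeight = 1
)

// cand is a pared-down stand-in for podPredictionResult.
type cand struct {
	name         string
	ttftHeadroom float64 // slack (ms) against the TTFT SLO
	tpotHeadroom float64 // slack (ms) against the TPOT SLO
}

func main() {
	pods := []cand{{"pod-a", 40, 5}, {"pod-b", 10, 12}, {"pod-c", 25, 2}}

	// Cohort-wide ranges, as in calculateHeadroomRanges.
	minT, maxT := pods[0].tpotHeadroom, pods[0].tpotHeadroom
	minF, maxF := pods[0].ttftHeadroom, pods[0].ttftHeadroom
	for _, p := range pods[1:] {
		minT, maxT = min(minT, p.tpotHeadroom), max(maxT, p.tpotHeadroom)
		minF, maxF = min(minF, p.ttftHeadroom), max(maxF, p.ttftHeadroom)
	}

	alpha, beta := 0.8, 0.2 // TTFT-dominant blend
	weights := make([]int, len(pods))
	total := 0
	for i, p := range pods {
		// Normalize each headroom to [0,1] within the cohort; a degenerate
		// range falls back to 0.5, as in calculateWeightedChoices.
		nT, nF := 0.5, 0.5
		if r := maxT - minT; r > eps {
			nT = (p.tpotHeadroom - minT) / (r + eps)
		}
		if r := maxF - minF; r > eps {
			nF = (p.ttftHeadroom - minF) / (r + eps)
		}
		combined := alpha*nF + beta*nT
		// "least" strategy: smaller combined headroom -> larger weight.
		weights[i] = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1
		total += weights[i]
	}

	// Weighted-random draw over the integer weights.
	pick := rand.Intn(total)
	for i, w := range weights {
		if pick < w {
			fmt.Printf("selected %s (weight %d of %d)\n", pods[i].name, w, total)
			return
		}
		pick -= w
	}
}

Running this repeatedly shows pod-b, whose TTFT slack is tightest, drawing the largest share of selections, which is the packing behavior the "least" strategy intends.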
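
The tier split in selectPodBasedOnStrategy is an epsilon-greedy policy: when both tiers are populated, the router selects from the positive-headroom tier with probability 1-epsilon and from the negative tier with probability EpsilonExploreNeg. The toy program below isolates just that branch structure; pickTier is a hypothetical helper written for this sketch, and the hard-coded 0.01 epsilon is a stand-in for the configured value.

package main

import (
	"fmt"
	"math/rand"
)

// pickTier mirrors the branch structure of selectPodBasedOnStrategy for the
// non-composite strategies: exploit the positive-headroom tier, and explore
// the negative tier with probability epsilon when both are non-empty.
func pickTier(r *rand.Rand, nPos, nNeg int, epsilon float64) string {
	switch {
	case nPos > 0 && nNeg > 0:
		if r.Float64() < epsilon {
			return "negative"
		}
		return "positive"
	case nPos > 0:
		return "positive"
	case nNeg > 0:
		return "negative"
	default:
		return "none" // the real router then falls back to a uniform pick or zero scores
	}
}

func main() {
	r := rand.New(rand.NewSource(1))
	counts := map[string]int{}
	for i := 0; i < 100000; i++ {
		counts[pickTier(r, 3, 2, 0.01)]++
	}
	fmt.Println(counts) // roughly 99% positive, 1% negative
}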