Skip to content
Merged
4 changes: 2 additions & 2 deletions pkg/bbr/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
reqenvoy "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/request"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

const (
Expand Down Expand Up @@ -197,7 +197,7 @@ func (s *Server) HandleRequestHeaders(reqCtx *RequestContext, headers *eppb.Http

if headers != nil && headers.Headers != nil {
for _, header := range headers.Headers.Headers {
reqCtx.Request.Headers[header.Key] = requtil.GetHeaderValue(header)
reqCtx.Request.Headers[header.Key] = reqenvoy.GetHeaderValue(header)
}
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/bbr/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework"
reqenvoy "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/request"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
reqcommon "sigs.k8s.io/gateway-api-inference-extension/pkg/common/request"
)

type Datastore interface {
Expand Down Expand Up @@ -108,8 +109,8 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
// If streaming and the body is not empty, then headers are handled when processing request body.
loggerVerbose.Info("Received headers, passing off header processing until body arrives...")
} else {
if requestId := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestId) > 0 {
logger = logger.WithValues(requtil.RequestIdHeaderKey, requestId)
if requestId := reqenvoy.ExtractHeaderValue(v, reqcommon.RequestIdHeaderKey); len(requestId) > 0 {
logger = logger.WithValues(reqcommon.RequestIdHeaderKey, requestId)
loggerVerbose = logger.V(logutil.VERBOSE)
ctx = log.IntoContext(ctx, logger)
}
Expand Down
47 changes: 47 additions & 0 deletions pkg/common/envoy/request/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package request

import (
"strings"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
)

// GetHeaderValue safely extracts the string value from an Envoy HeaderValue field.
// Envoy may populate either RawValue (bytes) or Value (string); RawValue takes
// precedence when non-empty. Returns "" for a nil header, so callers may pass
// through lookup results without their own nil check.
func GetHeaderValue(header *corev3.HeaderValue) string {
	if header == nil {
		return ""
	}
	if len(header.RawValue) > 0 {
		return string(header.RawValue)
	}
	return header.Value
}

// ExtractHeaderValue searches for a specific header key in the processing request and returns its value.
// The lookup is case-insensitive (HTTP header keys are case-insensitive).
// Returns an empty string if the header is missing or if the request structure is nil.
func ExtractHeaderValue(req *extProcPb.ProcessingRequest_RequestHeaders, headerKey string) string {
	if req == nil || req.RequestHeaders == nil || req.RequestHeaders.Headers == nil {
		return ""
	}
	for _, headerKv := range req.RequestHeaders.Headers.Headers {
		// EqualFold compares case-insensitively without allocating lowered
		// copies of both keys on every iteration.
		if strings.EqualFold(headerKv.Key, headerKey) {
			return GetHeaderValue(headerKv)
		}
	}
	return ""
}
File renamed without changes.
21 changes: 21 additions & 0 deletions pkg/common/request/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package request
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: what do you think about putting this under pkg/common/envoy/...?
the functions here are using envoy structs.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

I moved most of the code under pkg/common/request to pkg/common/envoy/request. I left behind the definition of the header key, which isn't Envoy related.


const (
	// RequestIdHeaderKey is the header key under which the request ID is
	// carried ("x-request-id"). It lives here rather than in the envoy
	// package because the key itself is not Envoy-specific.
	RequestIdHeaderKey = "x-request-id"
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log"

logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
reqcommon "sigs.k8s.io/gateway-api-inference-extension/pkg/common/request"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
)

Expand Down Expand Up @@ -203,7 +203,7 @@ func processFirstTokenForLatencyPrediction(
func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatencyCtx, samplingMean float64, maxSampledTokens int) {
if predictedLatencyCtx.tokenSampler == nil {
logger := log.FromContext(ctx)
requestID := predictedLatencyCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey]
requestID := predictedLatencyCtx.schedulingRequest.Headers[reqcommon.RequestIdHeaderKey]
predictedLatencyCtx.tokenSampler = newTokenSampler(requestID, samplingMean, maxSampledTokens)
logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", predictedLatencyCtx.tokenSampler.getNextSampleToken())
}
Expand Down Expand Up @@ -268,7 +268,7 @@ func processTokenForLatencyPrediction(

// Initialize sampler if not yet
if predictedLatencyCtx.tokenSampler == nil {
requestID := predictedLatencyCtx.schedulingRequest.Headers[requtil.RequestIdHeaderKey]
requestID := predictedLatencyCtx.schedulingRequest.Headers[reqcommon.RequestIdHeaderKey]
predictedLatencyCtx.tokenSampler = newTokenSampler(requestID, samplingMean, maxSampledTokens)
logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", predictedLatencyCtx.tokenSampler.getNextSampleToken())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log"

logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
reqcommon "sigs.k8s.io/gateway-api-inference-extension/pkg/common/request"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

const (
Expand Down Expand Up @@ -100,20 +100,20 @@ func newPredictedLatencyContext(request *schedulingtypes.LLMRequest) *predictedL
}

func (s *PredictedLatency) getPredictedLatencyContextForRequest(request *schedulingtypes.LLMRequest) (*predictedLatencyCtx, error) {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
if item := s.sloContextStore.Get(id); item != nil {
return item.Value(), nil
}
return nil, fmt.Errorf("SLO context not found for request ID: %s", id)
}

func (s *PredictedLatency) setPredictedLatencyContextForRequest(request *schedulingtypes.LLMRequest, ctx *predictedLatencyCtx) {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
s.sloContextStore.Set(id, ctx, ttlcache.DefaultTTL)
}

func (s *PredictedLatency) deletePredictedLatencyContextForRequest(request *schedulingtypes.LLMRequest) {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
s.sloContextStore.Delete(id)
}

Expand Down Expand Up @@ -141,21 +141,21 @@ func (t *PredictedLatency) PreRequest(ctx context.Context, request *schedulingty
Namespace: targetMetadata.NamespacedName.Namespace,
}

logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "endpointName", endpointName)
if request.Headers[requtil.RequestIdHeaderKey] == "" {
logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[reqcommon.RequestIdHeaderKey], "endpointName", endpointName)
if request.Headers[reqcommon.RequestIdHeaderKey] == "" {
logger.V(logutil.DEBUG).Error(errors.New("missing request ID"), "PredictedLatency.PreRequest: Request is missing request ID header")
return
}

id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]

// Get or create queue for this endpoint using sync.Map
actual, _ := t.runningRequestLists.LoadOrStore(endpointName, newRequestPriorityQueue())
endpointRequestList := actual.(*requestPriorityQueue)

predictedLatencyCtx, err := t.getPredictedLatencyContextForRequest(request)
if err != nil {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
logger.V(logutil.DEBUG).Info("PredictedLatency.PreRequest: Failed to get SLO context for request", "error", err.Error(), "requestID", id)
return
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func (t *PredictedLatency) ResponseStreaming(ctx context.Context, request *sched
now := time.Now()
predictedLatencyCtx, err := t.getPredictedLatencyContextForRequest(request)
if err != nil {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
logger.V(logutil.DEBUG).Info("PredictedLatency.ResponseStreaming: Failed to get SLO context for request", "error", err.Error(), "requestID", id)
return
}
Expand All @@ -229,7 +229,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu

predictedLatencyCtx, err := t.getPredictedLatencyContextForRequest(request)
if err != nil {
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
logger.V(logutil.DEBUG).Info("PredictedLatency.ResponseComplete: Failed to get SLO context for request", "error", err.Error(), "requestID", id)
return
}
Expand All @@ -256,7 +256,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
}
}

id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
t.removeRequestFromQueue(id, predictedLatencyCtx)
t.deletePredictedLatencyContextForRequest(request)
}
Expand Down Expand Up @@ -284,14 +284,14 @@ func (t *PredictedLatency) AdmitRequest(ctx context.Context, request *scheduling
predictedLatencyCtx, err := t.getPredictedLatencyContextForRequest(request)
if err != nil {
// If we can't find the predictedLatency context, we log the error but allow the request to proceed. This is a fail-open approach to avoid rejecting requests due to internal errors in our plugin.
id := request.Headers[requtil.RequestIdHeaderKey]
id := request.Headers[reqcommon.RequestIdHeaderKey]
logger.V(logutil.DEBUG).Error(err, "PredictedLatency.AdmitRequest: Failed to get PredictedLatency context for request", "requestID", id)
return nil
}

// If there is no valid pod for the request, reject it
if !predictedLatencyCtx.hasValidEndpoint && request.Objectives.Priority < 0 {
logger.V(logutil.DEBUG).Info("PredictedLatency.AdmitRequest: Rejecting a sheddable request as no valid endpoint available due to slo violation", "requestID", request.Headers[requtil.RequestIdHeaderKey])
logger.V(logutil.DEBUG).Info("PredictedLatency.AdmitRequest: Rejecting a sheddable request as no valid endpoint available due to slo violation", "requestID", request.Headers[reqcommon.RequestIdHeaderKey])
return errors.New("no valid endpoint available to serve the request")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

reqcommon "sigs.k8s.io/gateway-api-inference-extension/pkg/common/request"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

const (
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestNewPredictedLatencyContext(t *testing.T) {

func TestNewPredictedLatencyContext_NilBody(t *testing.T) {
request := &schedulingtypes.LLMRequest{
Headers: map[string]string{requtil.RequestIdHeaderKey: "test-nil-body"},
Headers: map[string]string{reqcommon.RequestIdHeaderKey: "test-nil-body"},
Body: nil,
}
ctx := newPredictedLatencyContext(request)
Expand Down Expand Up @@ -378,7 +378,7 @@ func TestPredictedLatency_ResponseStreaming_FirstToken(t *testing.T) {

// Initialize the queue and add the request
queue := newRequestPriorityQueue()
queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
queue.Add(request.Headers[reqcommon.RequestIdHeaderKey], 50.0)
router.runningRequestLists.Store(endpoint.GetMetadata().NamespacedName, queue)

beforeTime := time.Now()
Expand Down Expand Up @@ -432,7 +432,7 @@ func TestPredictedLatency_ResponseStreaming_SubsequentTokens(t *testing.T) {

// Initialize the queue and add the request
queue := newRequestPriorityQueue()
queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
queue.Add(request.Headers[reqcommon.RequestIdHeaderKey], 50.0)
router.runningRequestLists.Store(endpoint.GetMetadata().NamespacedName, queue)

router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata())
Expand Down Expand Up @@ -498,7 +498,7 @@ func TestPredictedLatency_ResponseComplete_Success(t *testing.T) {
// Create queue and add request
queue := newRequestPriorityQueue()
router.runningRequestLists.Store(endpoint.GetMetadata().NamespacedName, queue)
queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
queue.Add(request.Headers[reqcommon.RequestIdHeaderKey], 50.0)

predictedLatencyCtx := newPredictedLatencyContext(request)
predictedLatencyCtx.ttft = 80
Expand Down Expand Up @@ -591,7 +591,7 @@ func TestPredictedLatency_ResponseComplete_WithMetrics(t *testing.T) {
// Create queue
queue := newRequestPriorityQueue()
router.runningRequestLists.Store(endpoint.GetMetadata().NamespacedName, queue)
queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0)
queue.Add(request.Headers[reqcommon.RequestIdHeaderKey], 50.0)

predictedLatencyCtx := newPredictedLatencyContext(request)
predictedLatencyCtx.ttft = 80
Expand Down Expand Up @@ -624,7 +624,7 @@ func TestPredictedLatency_ResponseComplete_NoSLOs(t *testing.T) {
// Create queue
queue := newRequestPriorityQueue()
router.runningRequestLists.Store(endpoint.GetMetadata().NamespacedName, queue)
queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0)
queue.Add(request.Headers[reqcommon.RequestIdHeaderKey], 0)

predictedLatencyCtx := newPredictedLatencyContext(request)
predictedLatencyCtx.ttft = 80
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import (
"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/types"

reqcommon "sigs.k8s.io/gateway-api-inference-extension/pkg/common/request"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
fwksched "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
)
Expand Down Expand Up @@ -140,7 +140,7 @@ func createTestChatCompletionsLLMRequest(reqID string, ttftSLO, tpotSLO float64)

func createTestLLMRequestWithBody(reqID string, ttftSLO, tpotSLO float64, body *fwksched.LLMRequestBody) *fwksched.LLMRequest {
headers := make(map[string]string)
headers[requtil.RequestIdHeaderKey] = reqID
headers[reqcommon.RequestIdHeaderKey] = reqID
if ttftSLO > 0 {
headers["x-ttft-slo"] = fmt.Sprintf("%f", ttftSLO)
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func setupPredictionContext(router *PredictedLatency, request *fwksched.LLMReque
}

// Store the context using the request ID
reqID := request.Headers[requtil.RequestIdHeaderKey]
reqID := request.Headers[reqcommon.RequestIdHeaderKey]
router.sloContextStore.Set(reqID, predictedLatencyCtx, ttlcache.DefaultTTL)
}

Expand Down Expand Up @@ -772,7 +772,7 @@ func TestSloContextStoreEviction(t *testing.T) {

req := &fwksched.LLMRequest{
Headers: map[string]string{
requtil.RequestIdHeaderKey: requestID,
reqcommon.RequestIdHeaderKey: requestID,
},
}

Expand Down
12 changes: 3 additions & 9 deletions pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,12 @@ import (
"google.golang.org/protobuf/types/known/structpb"

"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
reqenvoy "sigs.k8s.io/gateway-api-inference-extension/pkg/common/envoy/request"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)

const (
// defaultFairnessID is the default fairness ID used when no ID is provided in the request.
// This ensures that requests without explicit fairness identifiers are still grouped and managed by the Flow Control
// system.
defaultFairnessID = "default-flow"
)

func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
reqCtx.RequestReceivedTimestamp = time.Now()

Expand All @@ -61,7 +55,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
}

for _, header := range req.RequestHeaders.Headers.Headers {
reqCtx.Request.Headers[header.Key] = request.GetHeaderValue(header)
reqCtx.Request.Headers[header.Key] = reqenvoy.GetHeaderValue(header)
switch header.Key {
case metadata.FlowFairnessIDKey:
reqCtx.FairnessID = reqCtx.Request.Headers[header.Key]
Expand All @@ -73,7 +67,7 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
}

if reqCtx.FairnessID == "" {
reqCtx.FairnessID = defaultFairnessID
reqCtx.FairnessID = metadata.DefaultFairnessID
}

return nil
Expand Down
Loading