Skip to content
Merged
28 changes: 19 additions & 9 deletions pkg/epp/datalayer/plugins/approximateprefix/data_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,38 @@ const (
)

type PrefixCacheMatchInfo struct {
matchLength int
// matched prefix length in blocks
matchBlocks int
// total length in blocks
totalBlocks int
// block length in tokens
blockSizeTokens int
}

func NewPrefixCacheMatchInfo(matchLen int, blockHashLen int) *PrefixCacheMatchInfo {
func NewPrefixCacheMatchInfo(matchBlocks int, totalBlocks int, blockSizeTokens int) *PrefixCacheMatchInfo {
return &PrefixCacheMatchInfo{
matchLength: matchLen,
totalBlocks: blockHashLen,
matchBlocks: matchBlocks,
totalBlocks: totalBlocks,
blockSizeTokens: blockSizeTokens,
}
}

func (p *PrefixCacheMatchInfo) MatchLength() int {
return p.matchLength
func (p *PrefixCacheMatchInfo) MatchBlocks() int {
return p.matchBlocks
}

func (p *PrefixCacheMatchInfo) TotalLength() int {
func (p *PrefixCacheMatchInfo) TotalBlocks() int {
return p.totalBlocks
}

func (p *PrefixCacheMatchInfo) BlockSizeTokens() int {
return p.blockSizeTokens
}

func (p *PrefixCacheMatchInfo) Clone() datalayer.Cloneable {
return &PrefixCacheMatchInfo{
matchLength: p.matchLength,
totalBlocks: p.totalBlocks,
matchBlocks: p.matchBlocks,
totalBlocks: p.totalBlocks,
blockSizeTokens: p.blockSizeTokens,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (s *PredictedLatency) PrepareRequestData(ctx context.Context, request *sche

if prefixCacheInfoRaw, ok := endpoint.Get(approximateprefix.PrefixCacheMatchInfoKey); ok {
prefixCacheInfo := prefixCacheInfoRaw.(*approximateprefix.PrefixCacheMatchInfo)
prefixCacheScore = float64(prefixCacheInfo.MatchLength()) / float64(prefixCacheInfo.TotalLength())
prefixCacheScore = float64(prefixCacheInfo.MatchBlocks()) / float64(prefixCacheInfo.TotalBlocks())
if !math.IsNaN(prefixCacheScore) {
logger.V(logutil.DEBUG).Info("Found prefix cache score in pod attribute", "pod", endpoint.GetMetadata().NamespacedName.Name, "score", prefixCacheScore)
} else {
Expand Down
100 changes: 64 additions & 36 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
Expand All @@ -37,8 +38,8 @@ import (
)

const (
// vLLM default token block size is 16, and a good guess of average characters per token is 4.
DefaultBlockSize = 64
// vLLM default token block size is 16 tokens
DefaultBlockSizeTokens = 16
// The maximum number of blocks to match. Two long requests with the same prefix up to this
// limit will be indistinguishable.
// This parameter provides a trade-off between cache size, prefix matching speed and matching
Expand Down Expand Up @@ -73,7 +74,8 @@ const (

var DefaultConfig = Config{
AutoTune: true,
BlockSize: DefaultBlockSize,
BlockSize: 0,
BlockSizeTokens: 0,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
Expand All @@ -82,8 +84,12 @@ type Config struct {
// If set to true, the plugin will automatically adjust the configuration based on various
// metrics from the model servers.
AutoTune bool `json:"autoTune"`
// The input prompt is broken into sizes of BlockSize to calculate block hashes . Requests
// The input prompt is broken into sizes of BlockSizeTokens to calculate block hashes. Requests
// with length shorter than the block size will be ignored.
BlockSizeTokens int `json:"blockSizeTokens"`
// Deprecated: Legacy block size defined in number of characters. Use BlockSizeTokens instead.
// In case only BlockSize is defined in the configuration - plugin initialization will fail.
// In case both BlockSize and BlockSizeTokens are defined - BlockSizeTokens is used.
BlockSize int `json:"blockSize"`
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
// be ignored.
Expand Down Expand Up @@ -131,7 +137,7 @@ var _ plugin.StateData = &SchedulingContextState{}
type SchedulingContextState struct {
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
PrefixHashes []BlockHash
// A map of server to its longest prefix cache match length.
// A map of server to its longest prefix cache match length in blocks.
PrefixCacheServers map[ServerID]int
}

Expand Down Expand Up @@ -165,13 +171,25 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
}
}

p := New(handle.Context(), parameters).WithName(name)
p, err := New(handle.Context(), parameters)
if err != nil {
return nil, err
}

p.WithName(name)
go p.CleanUpInactivePods(handle.Context(), handle)
return p, nil
}

// New initializes a new prefix Plugin and returns its pointer.
func New(ctx context.Context, config Config) *Plugin {
func New(ctx context.Context, config Config) (*Plugin, error) {
// invalid configuration: only BlockSize is defined
if config.BlockSize > 0 && config.BlockSizeTokens <= 0 {
err := errors.New("BlockSize is deprecated, use BlockSizeTokens instead, the value should be defined in tokens")
log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "invalid prefix plugin configuration")
return nil, err
}

if config.LRUCapacityPerServer <= 0 {
config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
log.FromContext(ctx).V(logutil.DEFAULT).Info(
Expand All @@ -180,10 +198,10 @@ func New(ctx context.Context, config Config) *Plugin {
)
}

if config.BlockSize <= 0 {
config.BlockSize = DefaultBlockSize
if config.BlockSizeTokens <= 0 {
config.BlockSizeTokens = DefaultBlockSizeTokens
log.FromContext(ctx).V(logutil.DEFAULT).Info("BlockSize is not positive, using default value",
"default", DefaultBlockSize)
"default", DefaultBlockSizeTokens)
}

if config.MaxPrefixBlocksToMatch <= 0 {
Expand All @@ -198,7 +216,7 @@ func New(ctx context.Context, config Config) *Plugin {
config: config,
pluginState: plugin.NewPluginState(ctx),
indexer: newIndexer(ctx, config.LRUCapacityPerServer),
}
}, nil
}

// TypedName returns the type and name tuple of this plugin instance.
Expand Down Expand Up @@ -227,7 +245,8 @@ func (p *Plugin) Consumes() map[string]any {

// PrepareRequestData hashes prompt, finds longest prefix match and stores it in endpoint as attribute.
func (p *Plugin) PrepareRequestData(ctx context.Context, request *framework.LLMRequest, endpoints []framework.Endpoint) error {
hashes := hashPrompt(ctx, request, getBlockSize(endpoints, p.config), p.config.MaxPrefixBlocksToMatch)
blockSize := getBlockSize(endpoints, p.config)
hashes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
Expand All @@ -236,8 +255,9 @@ func (p *Plugin) PrepareRequestData(ctx context.Context, request *framework.LLMR

for _, endpoint := range endpoints {
matchLen := state.PrefixCacheServers[ServerID(endpoint.GetMetadata().NamespacedName)]
endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total))
endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total, blockSize))
}

// Store the state in plugin state for later use.
p.pluginState.Write(request.RequestId, plugin.StateKey(p.TypedName().String()), state)
return nil
Expand All @@ -246,7 +266,8 @@ func (p *Plugin) PrepareRequestData(ctx context.Context, request *framework.LLMR
// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *framework.CycleState, request *framework.LLMRequest, endpoints []framework.Endpoint) map[framework.Endpoint]float64 {
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, getBlockSize(endpoints, p.config), p.config.MaxPrefixBlocksToMatch)
blockSize := getBlockSize(endpoints, p.config)
hashes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
Expand All @@ -257,20 +278,21 @@ func (p *Plugin) Score(ctx context.Context, cycleState *framework.CycleState, re
// store the state in plugin state for later use in PreRequest. This may go away once we default to prepare request data plugin hook.
p.pluginState.Write(request.RequestId, plugin.StateKey(p.TypedName().String()), state)
log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes)
// calculate the scores of pods
// calculate the scores of endpoints
scores := make(map[framework.Endpoint]float64, len(endpoints))

// total prefix length in blocks (number of block hashes)
total := len(state.PrefixHashes)
podScoreFunc := func(endpoint framework.Endpoint) float64 {
endpointScoreFunc := func(endpoint framework.Endpoint) float64 {
if total == 0 {
return 0
}
matchLen := state.PrefixCacheServers[ServerID(endpoint.GetMetadata().NamespacedName)]
return float64(matchLen) / float64(total)
}

for _, pod := range endpoints {
scores[pod] = podScoreFunc(pod)
for _, endpoint := range endpoints {
scores[endpoint] = endpointScoreFunc(endpoint)
}
return scores
}
Expand Down Expand Up @@ -308,7 +330,8 @@ func (p *Plugin) PreRequest(ctx context.Context, request *framework.LLMRequest,
matchLen := state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)]

blockSize := getBlockSize(primaryProfileResult.TargetEndpoints, p.config)
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
// report matched and total prefix length in chars
metrics.RecordPrefixCacheMatch(matchLen*blockSize*averageCharactersPerToken, total*blockSize*averageCharactersPerToken)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@liu-cong I recommend to change the semantics of the metric and report in tokens instead (meaning remove this multiplication by avg chars per token), while it is not "backward compatible" I think it is safe to do that, wdyt?

}

func (p *Plugin) makeServer(targetEndpoint framework.Endpoint) Server {
Expand All @@ -322,21 +345,20 @@ func (p *Plugin) makeServer(targetEndpoint framework.Endpoint) Server {
}
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
// matchLongestPrefix returns a map of servers and length of prefix that each server caches, prefix length is defined in blocks.
func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map[ServerID]int {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
res := make(map[ServerID]int)
// Use a greedy strategy to search from the longest prefix.
// NOTE: It's possible to further optimize this with a binary search.
for i := 0; i < len(hashes); i++ {
hash := hashes[i]
for i, hash := range hashes {
cachedServers := p.indexer.Get(hash)
if len(cachedServers) == 0 {
break
} else {
loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
for server := range cachedServers {
// Update servers with their longest prefix match.
// Update servers with their longest prefix match, prefix length is in blocks.
res[server]++
}
}
Expand Down Expand Up @@ -374,7 +396,7 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugin.Handle)
// hashPrompt divides the prompt into blocks and calculates the prefix hash for each block.
// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
func hashPrompt(ctx context.Context, request *framework.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
func hashPrompt(ctx context.Context, request *framework.LLMRequest, blockSizeTokens int, maxPrefixBlocks int) []BlockHash {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
if request == nil || request.Body == nil {
loggerDebug.Info("Request or request data is nil, skipping hashing")
Expand All @@ -387,17 +409,20 @@ func hashPrompt(ctx context.Context, request *framework.LLMRequest, cacheBlockSi
return nil
}

if len(userInput) < cacheBlockSize {
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
// convert block size from tokens to characters
cacheBlockSizeChars := blockSizeTokens * averageCharactersPerToken

if len(userInput) < cacheBlockSizeChars {
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size in chars", cacheBlockSizeChars)
return nil
}
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
if len(userInput) > cacheBlockSizeChars*maxPrefixBlocks {
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size in chars", cacheBlockSizeChars)
userInput = userInput[:maxPrefixBlocks*cacheBlockSizeChars]
}
// Split the body into blocks of size cacheBlockSizeChars.
// If the last block is smaller than cacheBlockSizeChars, it will be ignored.
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
res := make([]BlockHash, 0, len(userInput)/cacheBlockSizeChars)
// Add the model to the first block hash so that different models have different hashes even with the same body.
h := xxhash.New()
_, _ = h.Write([]byte(request.TargetModel))
Expand All @@ -406,14 +431,15 @@ func hashPrompt(ctx context.Context, request *framework.LLMRequest, cacheBlockSi
}

prevBlockHash := BlockHash(h.Sum64())
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
for i := 0; i+cacheBlockSizeChars <= len(userInput); i += cacheBlockSizeChars {
h.Reset()
_, _ = h.Write(userInput[i : i+cacheBlockSize])
_, _ = h.Write(userInput[i : i+cacheBlockSizeChars])
_, _ = h.Write(toBytes(prevBlockHash))
res = append(res, BlockHash(h.Sum64()))

prevBlockHash = res[len(res)-1]
}

return res
}

Expand All @@ -432,23 +458,25 @@ func getUserInputBytes(request *framework.LLMRequest) ([]byte, error) {
return json.Marshal(request.Body.ChatCompletions.Messages)
}

// getBlockSize returns the block size in tokens.
// In case of auto-tune uses the block size from the first endpoint, otherwise uses the block size from the configuration
func getBlockSize(endpoints []framework.Endpoint, config Config) int {
if !config.AutoTune {
return config.BlockSize
return config.BlockSizeTokens
}

// Fallback to BlockSize if no metrics are available.
if len(endpoints) == 0 {
return config.BlockSize
return config.BlockSizeTokens
}

// Since all Endpoints originate from the same inference pool, they are considered to have identical configurations.
// Therefore, using the CacheBlockSize value from the first Endpoint suffices.
if endpoint := endpoints[0]; endpoint.GetMetrics() != nil {
cacheBlockSize := endpoint.GetMetrics().CacheBlockSize
if cacheBlockSize > 0 {
return cacheBlockSize * averageCharactersPerToken
return cacheBlockSize
}
}
return config.BlockSize
return config.BlockSizeTokens
}
Loading