Skip to content

add support for load balancing across multiple models and providers #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gateway/internal/api/v1/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M
allProviderModels := map[string]*llmv1.ProviderModels{}

for name := range base.ProviderRegistry {
// Check if the provider is healthy before fetching models
if !router.DefaultHealthChecker{}.IsHealthy(name) {
continue
}

provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name})
if err != nil {
continue
Expand Down
24 changes: 24 additions & 0 deletions gateway/internal/api/v1/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt

data := []*llmv1.Provider{}
for _, provider := range providers {
// Check if the provider is healthy before adding to the list
if !router.DefaultHealthChecker{}.IsHealthy(provider.Info().Name) {
continue
}
providerInfo := provider.Info()
data = append(data, &llmv1.Provider{
Title: providerInfo.Title,
Expand All @@ -34,6 +38,11 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt
}

func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name})
if err != nil {
return nil, errors.NewNotFound(err.Error())
Expand Down Expand Up @@ -63,6 +72,11 @@ func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.
}

func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llmv1.CreateProviderRequest]) (*connect.Response[llmv1.CreateProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()}

p, err := s.iProviderService.GetProvider(provider)
Expand Down Expand Up @@ -111,6 +125,11 @@ func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llm
}

func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llmv1.UpdateProviderRequest]) (*connect.Response[llmv1.UpdateProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()}

p, err := s.iProviderService.GetProvider(provider)
Expand Down Expand Up @@ -172,6 +191,11 @@ func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llm
}

func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

p, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name})
if err != nil {
return nil, errors.NewNotFound(err.Error())
Expand Down
17 changes: 17 additions & 0 deletions gateway/internal/router/health_checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package router

import (
"log"
)

// HealthChecker reports whether a provider may currently be used.
type HealthChecker interface {
	// IsHealthy returns true when the provider identified by the given
	// name is considered healthy, and false otherwise.
	IsHealthy(provider string) bool
}

// DefaultHealthChecker is a stateless health checker that currently treats
// every provider as healthy.
type DefaultHealthChecker struct{}

// IsHealthy reports whether the named provider is healthy.
//
// It is declared on the value receiver so that it can be invoked on a
// non-addressable value such as the composite literal
// DefaultHealthChecker{}; a pointer-receiver method cannot be called that
// way. The method remains callable through *DefaultHealthChecker as well.
//
// This is a placeholder: no real check is performed and the result is
// always true regardless of providerName.
func (d DefaultHealthChecker) IsHealthy(providerName string) bool {
	return true
}
25 changes: 18 additions & 7 deletions gateway/internal/router/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package router

import (
	"fmt"
	"log"
	"sync/atomic"

	"gateway/internal/router" // NOTE(review): this file is already in package router; importing it here is an import cycle and should be removed
)

const (
Expand All @@ -21,11 +23,20 @@ func NewPriorityRouter(providers []RouterConfig) *PriorityRouter {
}

func (r *PriorityRouter) Next() (*RouterConfig, error) {
idx := int(r.idx.Load())

// Todo: make a check for healthy provider
model := &r.providers[idx]
r.idx.Add(1)

return model, nil
providerLen := len(r.providers)
originalIdx := r.idx.Load()
var healthyProvider *RouterConfig
for i := 0; i < providerLen; i++ {
idx := (originalIdx + uint64(i)) % uint64(providerLen)
if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) {
healthyProvider = &r.providers[idx]
r.idx.Store(idx + 1)
break
}
}
if healthyProvider == nil {
log.Println("Error: No healthy providers available.")
return nil, fmt.Errorf("no healthy providers available")
}
return healthyProvider, nil
}
23 changes: 19 additions & 4 deletions gateway/internal/router/round_robin.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package router

import (
"log"
"sync/atomic"
"gateway/internal/router" // Importing to use HealthChecker
)

const (
Expand All @@ -26,9 +28,22 @@ func (r *RoundRobinRouter) Iterator() RouterIterator {
func (r *RoundRobinRouter) Next() *RouterConfig {
providerLen := len(r.providers)

// Todo: make a check for healthy provider
idx := r.idx.Add(1) - 1
model := &r.providers[idx%uint64(providerLen)]
// Iterate through providers to find a healthy one
var healthyProvider *RouterConfig
originalIdx := r.idx.Load()
for i := 0; i < providerLen; i++ {
idx := (originalIdx + uint64(i)) % uint64(providerLen)
if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) {
healthyProvider = &r.providers[idx]
r.idx.Add(1)
break
}
}

if healthyProvider == nil {
log.Println("Error: No healthy providers available.")
return nil
}

return model
return healthyProvider
}