Skip to content

Add Fine-Tuning Feature Skeleton #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions gateway/internal/api/v1/finetune.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package v1

import (
"context"
"encoding/json"
"net/http"

"connectrpc.com/connect"
"github.com/missingstudio/ai/gateway/core/provider"
"github.com/missingstudio/ai/gateway/internal/provider/huggingface"
"github.com/missingstudio/common/errors"
llmv1 "github.com/missingstudio/protos/pkg/llm/v1"
)

// InitiateFineTuning starts a fine-tuning job on the HuggingFace provider and
// returns the job ID reported by that provider.
//
// An internal error is returned when the provider cannot be resolved, when the
// resolved provider is not a *huggingface.HuggingFaceProvider, or when the
// provider call itself fails.
func (s *V1Handler) InitiateFineTuning(ctx context.Context, req *connect.Request[llmv1.FineTuneRequest]) (*connect.Response[llmv1.FineTuneResponse], error) {
	p, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
	if err != nil {
		return nil, errors.NewInternal("failed to get HuggingFace provider")
	}

	// Use the checked form of the type assertion: the unchecked form panics
	// the handler if the provider service ever hands back a different type.
	hfProvider, ok := p.(*huggingface.HuggingFaceProvider)
	if !ok {
		return nil, errors.NewInternal("HuggingFace provider has an unexpected type")
	}

	// connect.Request exposes the decoded message as Msg (there is no Payload
	// field), matching how the other handlers in this package read requests.
	// NOTE(review): assumes FineTuneRequest.Parameters is already a
	// map[string]interface{}; if it is a structpb.Struct, pass
	// Parameters.AsMap() instead — confirm against the proto definition.
	jobID, err := hfProvider.InitiateFineTuning(ctx, req.Msg.Model, req.Msg.Parameters)
	if err != nil {
		return nil, errors.NewInternal("failed to initiate fine-tuning: " + err.Error())
	}

	return connect.NewResponse(&llmv1.FineTuneResponse{
		JobId: jobID,
	}), nil
}

// CheckFineTuningStatus looks up a fine-tuning job on the HuggingFace provider
// and returns the string found under the "status" key of the provider's
// response.
//
// An internal error is returned when the provider cannot be resolved, when the
// resolved provider is not a *huggingface.HuggingFaceProvider, when the lookup
// fails, or when the response has no string "status" entry.
func (s *V1Handler) CheckFineTuningStatus(ctx context.Context, req *connect.Request[llmv1.FineTuneStatusRequest]) (*connect.Response[llmv1.FineTuneStatusResponse], error) {
	p, err := s.iProviderService.GetProvider(provider.Provider{Name: "HuggingFace"})
	if err != nil {
		return nil, errors.NewInternal("failed to get HuggingFace provider")
	}

	// Use the checked form of the type assertion: the unchecked form panics
	// the handler if the provider service ever hands back a different type.
	hfProvider, ok := p.(*huggingface.HuggingFaceProvider)
	if !ok {
		return nil, errors.NewInternal("HuggingFace provider has an unexpected type")
	}

	// connect.Request exposes the decoded message as Msg (there is no Payload
	// field), matching how the other handlers in this package read requests.
	result, err := hfProvider.RetrieveFineTuningResults(ctx, req.Msg.JobId)
	if err != nil {
		return nil, errors.NewInternal("failed to retrieve fine-tuning results: " + err.Error())
	}

	// The provider returns loosely-typed JSON, so the status must be verified
	// to be a string before it is copied into the typed response.
	status, ok := result["status"].(string)
	if !ok {
		return nil, errors.NewInternal("unexpected response format from HuggingFace")
	}

	return connect.NewResponse(&llmv1.FineTuneStatusResponse{
		Status: status,
	}), nil
}
5 changes: 5 additions & 0 deletions gateway/internal/api/v1/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ func (s *V1Handler) ListModels(ctx context.Context, req *connect.Request[llmv1.M
allProviderModels := map[string]*llmv1.ProviderModels{}

for name := range base.ProviderRegistry {
// Check if the provider is healthy before fetching models
if !router.DefaultHealthChecker{}.IsHealthy(name) {
continue
}

provider, err := s.iProviderService.GetProvider(provider.Provider{Name: name})
if err != nil {
continue
Expand Down
24 changes: 24 additions & 0 deletions gateway/internal/api/v1/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt

data := []*llmv1.Provider{}
for _, provider := range providers {
// Check if the provider is healthy before adding to the list
if !router.DefaultHealthChecker{}.IsHealthy(provider.Info().Name) {
continue
}
providerInfo := provider.Info()
data = append(data, &llmv1.Provider{
Title: providerInfo.Title,
Expand All @@ -34,6 +38,11 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt
}

func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.GetProviderRequest]) (*connect.Response[llmv1.GetProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name})
if err != nil {
return nil, errors.NewNotFound(err.Error())
Expand Down Expand Up @@ -63,6 +72,11 @@ func (s *V1Handler) GetProvider(ctx context.Context, req *connect.Request[llmv1.
}

func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llmv1.CreateProviderRequest]) (*connect.Response[llmv1.CreateProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()}

p, err := s.iProviderService.GetProvider(provider)
Expand Down Expand Up @@ -111,6 +125,11 @@ func (s *V1Handler) CreateProvider(ctx context.Context, req *connect.Request[llm
}

func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llmv1.UpdateProviderRequest]) (*connect.Response[llmv1.UpdateProviderResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

provider := provider.Provider{Name: req.Msg.Name, Config: req.Msg.Config.AsMap()}

p, err := s.iProviderService.GetProvider(provider)
Expand Down Expand Up @@ -172,6 +191,11 @@ func (s *V1Handler) UpsertProvider(ctx context.Context, req *connect.Request[llm
}

func (s *V1Handler) GetProviderConfig(ctx context.Context, req *connect.Request[llmv1.GetProviderConfigRequest]) (*connect.Response[llmv1.GetProviderConfigResponse], error) {
// First, check if the provider is healthy
if !router.DefaultHealthChecker{}.IsHealthy(req.Msg.Name) {
return nil, errors.NewNotFound("Provider is unhealthy")
}

p, err := s.iProviderService.GetProvider(provider.Provider{Name: req.Msg.Name})
if err != nil {
return nil, errors.NewNotFound(err.Error())
Expand Down
116 changes: 116 additions & 0 deletions gateway/internal/provider/huggingface/huggingface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package huggingface

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"

	"github.com/missingstudio/ai/gateway/internal/provider/base"
	"github.com/missingstudio/common/errors"
)

// HuggingFaceProvider holds the connection settings used by this package's
// HuggingFace API calls (model listing and fine-tuning).
type HuggingFaceProvider struct {
// APIKey is sent as the bearer token in the Authorization header of every
// request made by this provider's methods.
APIKey string
// BaseURL is the API root; endpoint paths such as "/models" and
// "/fine-tune" are appended to it verbatim.
BaseURL string
}

// Info returns the static descriptor for this provider: its registry name
// ("HuggingFace") and a short human-readable description.
func (hfp *HuggingFaceProvider) Info() base.ProviderInfo {
	var info base.ProviderInfo
	info.Name = "HuggingFace"
	info.Description = "Provider for interacting with HuggingFace's transformer models"
	return info
}

// Models fetches the list of model names from GET {BaseURL}/models.
//
// The request carries the API key as a bearer token and is bound to ctx. Any
// response status other than 200 OK is reported as a bad-request error; the
// response body is expected to be a JSON array of strings.
func (hfp *HuggingFaceProvider) Models(ctx context.Context) ([]string, error) {
	endpoint := fmt.Sprintf("%s/models", hfp.BaseURL)

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Add("Authorization", "Bearer "+hfp.APIKey)

	response, err := new(http.Client).Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, errors.NewBadRequest("failed to fetch models from HuggingFace")
	}

	var names []string
	if decodeErr := json.NewDecoder(response.Body).Decode(&names); decodeErr != nil {
		return nil, decodeErr
	}
	return names, nil
}

// InitiateFineTuning submits a fine-tuning job via POST {BaseURL}/fine-tune
// and returns the "job_id" field of the JSON response.
//
// The request body is the JSON object {"model": model, "parameters":
// parameters}. The request carries the API key as a bearer token and is bound
// to ctx. An error is returned when the payload cannot be encoded, the request
// fails, the response status is not 200 OK, the response body is not valid
// JSON, or the response carries no job_id.
func (hfp *HuggingFaceProvider) InitiateFineTuning(ctx context.Context, model string, parameters map[string]interface{}) (string, error) {
	endpoint := fmt.Sprintf("%s/fine-tune", hfp.BaseURL)
	payload, err := json.Marshal(map[string]interface{}{
		"model":      model,
		"parameters": parameters,
	})
	if err != nil {
		return "", err
	}

	// Hand the *bytes.Reader to NewRequestWithContext directly. Wrapping it in
	// a NopCloser hides its concrete type, so net/http could not set
	// ContentLength or GetBody and the body would be sent chunked.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}

	req.Header.Add("Authorization", "Bearer "+hfp.APIKey)
	req.Header.Add("Content-Type", "application/json")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", errors.NewBadRequest("failed to initiate fine-tuning on HuggingFace")
	}

	// Read the body in full before decoding (ioutil is the reader helper this
	// file already imports).
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	// Decode only the field we need. A map[string]string target would reject
	// the whole response as soon as any other field held a non-string value.
	var result struct {
		JobID string `json:"job_id"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", err
	}

	// A response without a job ID is useless to callers; report it rather
	// than returning an empty ID with a nil error.
	if result.JobID == "" {
		return "", errors.NewInternal("HuggingFace response did not include a job_id")
	}

	return result.JobID, nil
}

// RetrieveFineTuningResults fetches the state of a fine-tuning job from
// GET {BaseURL}/fine-tune/{jobID} and returns the decoded JSON object.
//
// The request carries the API key as a bearer token and is bound to ctx. Any
// response status other than 200 OK is reported as a bad-request error.
//
// NOTE(review): jobID is interpolated into the URL path without escaping;
// consider url.PathEscape if job IDs can come from untrusted input.
func (hfp *HuggingFaceProvider) RetrieveFineTuningResults(ctx context.Context, jobID string) (map[string]interface{}, error) {
	endpoint := fmt.Sprintf("%s/fine-tune/%s", hfp.BaseURL, jobID)

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Add("Authorization", "Bearer "+hfp.APIKey)

	response, err := new(http.Client).Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, errors.NewBadRequest("failed to retrieve fine-tuning results from HuggingFace")
	}

	var payload map[string]interface{}
	if decodeErr := json.NewDecoder(response.Body).Decode(&payload); decodeErr != nil {
		return nil, decodeErr
	}
	return payload, nil
}
17 changes: 17 additions & 0 deletions gateway/internal/router/health_checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package router

import (
"log"
)

// HealthChecker reports whether the provider with the given name is
// currently considered usable.
type HealthChecker interface {
	IsHealthy(providerName string) bool
}

// DefaultHealthChecker is a placeholder HealthChecker that treats every
// provider as healthy. Its zero value is ready to use.
type DefaultHealthChecker struct{}

// Compile-time check that the value type satisfies HealthChecker.
var _ HealthChecker = DefaultHealthChecker{}

// IsHealthy always returns true; providerName is currently ignored because no
// real health check is implemented yet.
//
// The method uses a value receiver on purpose: callers invoke it as
// DefaultHealthChecker{}.IsHealthy(name), and a composite literal is not
// addressable, so a pointer-receiver method could not be called that way.
// *DefaultHealthChecker still satisfies HealthChecker with a value receiver.
func (DefaultHealthChecker) IsHealthy(providerName string) bool {
	return true
}
25 changes: 18 additions & 7 deletions gateway/internal/router/priority.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package router

import (
"sync/atomic"
"log"
"gateway/internal/router" // Importing to use HealthChecker
)

const (
Expand All @@ -21,11 +23,20 @@ func NewPriorityRouter(providers []RouterConfig) *PriorityRouter {
}

func (r *PriorityRouter) Next() (*RouterConfig, error) {
idx := int(r.idx.Load())

// Todo: make a check for healthy provider
model := &r.providers[idx]
r.idx.Add(1)

return model, nil
providerLen := len(r.providers)
originalIdx := r.idx.Load()
var healthyProvider *RouterConfig
for i := 0; i < providerLen; i++ {
idx := (originalIdx + uint64(i)) % uint64(providerLen)
if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) {
healthyProvider = &r.providers[idx]
r.idx.Store(idx + 1)
break
}
}
if healthyProvider == nil {
log.Println("Error: No healthy providers available.")
return nil, fmt.Errorf("no healthy providers available")
}
return healthyProvider, nil
}
23 changes: 19 additions & 4 deletions gateway/internal/router/round_robin.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package router

import (
"log"
"sync/atomic"
"gateway/internal/router" // Importing to use HealthChecker
)

const (
Expand All @@ -26,9 +28,22 @@ func (r *RoundRobinRouter) Iterator() RouterIterator {
func (r *RoundRobinRouter) Next() *RouterConfig {
providerLen := len(r.providers)

// Todo: make a check for healthy provider
idx := r.idx.Add(1) - 1
model := &r.providers[idx%uint64(providerLen)]
// Iterate through providers to find a healthy one
var healthyProvider *RouterConfig
originalIdx := r.idx.Load()
for i := 0; i < providerLen; i++ {
idx := (originalIdx + uint64(i)) % uint64(providerLen)
if router.DefaultHealthChecker{}.IsHealthy(r.providers[idx].Name) {
healthyProvider = &r.providers[idx]
r.idx.Add(1)
break
}
}

if healthyProvider == nil {
log.Println("Error: No healthy providers available.")
return nil
}

return model
return healthyProvider
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ interface ModelSelectorProps extends PopoverProps {}

export default function ModelSelector(props: ModelSelectorProps) {
const [open, setOpen] = React.useState(false);
const { providers } = useModelFetch();
const [isFineTuning, setIsFineTuning] = React.useState(false);
const { providers } = useModelFetch(isFineTuning);
const { model, setModel, setProvider } = useStore();

const toggleFineTuning = () => setIsFineTuning(!isFineTuning);

return (
<div className="flex items-center gap-2">
<Label htmlFor="model">Model: </Label>
<Button variant="outline" onClick={toggleFineTuning}>
{isFineTuning ? 'Select for Fine-Tuning' : 'Select Model'}
</Button>
<Popover open={open} onOpenChange={setOpen} {...props}>
<PopoverTrigger asChild>
<Button
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ interface Model {
const BASE_URL = process.env.NEXT_PUBLIC_GATEWAY_URL ?? "http://localhost:3000";
export function useModelFetch() {
const [providers, setProviders] = useState<ModelType[]>([]);
const [isFineTuning, setIsFineTuning] = useState<boolean>(false);

useEffect(() => {
const fetchEndpoint = isFineTuning ? `${BASE_URL}/api/v1/finetune/models` : `${BASE_URL}/api/v1/models`;
async function fetchModels() {
try {
const response = await fetch(`${BASE_URL}/api/v1/models`);
const response = await fetch(fetchEndpoint);
const { models } = await response.json();
const fetchedProviders: ModelType[] = Object.keys(models).map(
(key) => ({
Expand Down