Skip to content

Commit dc44c57

Browse files
committed
Small refactors to star to use HF for discovery
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 8937998 commit dc44c57

File tree

5 files changed

+75
-47
lines changed

5 files changed

+75
-47
lines changed
Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,57 @@
11
package importers
22

33
import (
4+
"encoding/json"
5+
6+
"github.com/rs/zerolog/log"
7+
48
"github.com/mudler/LocalAI/core/gallery"
5-
"github.com/mudler/LocalAI/core/schema"
9+
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
610
)
711

812
var DefaultImporters = []Importer{
913
&LlamaCPPImporter{},
1014
&MLXImporter{},
1115
}
1216

17+
type Details struct {
18+
HuggingFace *hfapi.ModelDetails
19+
URI string
20+
Preferences json.RawMessage
21+
}
22+
1323
type Importer interface {
14-
Match(uri string, request schema.ImportModelRequest) bool
15-
Import(uri string, request schema.ImportModelRequest) (gallery.ModelConfig, error)
24+
Match(details Details) bool
25+
Import(details Details) (gallery.ModelConfig, error)
26+
}
27+
28+
func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
29+
var err error
30+
var modelConfig gallery.ModelConfig
31+
32+
hf := hfapi.NewClient()
33+
34+
hfDetails, err := hf.GetModelDetails(uri)
35+
if err != nil {
36+
// maybe not a HF repository
37+
// TODO: maybe we can check if the URI is a valid HF repository
38+
log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository")
39+
}
40+
41+
details := Details{
42+
HuggingFace: hfDetails,
43+
URI: uri,
44+
Preferences: preferences,
45+
}
46+
47+
for _, importer := range DefaultImporters {
48+
if importer.Match(details) {
49+
modelConfig, err = importer.Import(details)
50+
if err != nil {
51+
continue
52+
}
53+
break
54+
}
55+
}
56+
return modelConfig, err
1657
}

core/gallery/importers/llama-cpp.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ var _ Importer = &LlamaCPPImporter{}
1616

1717
type LlamaCPPImporter struct{}
1818

19-
func (i *LlamaCPPImporter) Match(uri string, request schema.ImportModelRequest) bool {
20-
preferences, err := request.Preferences.MarshalJSON()
19+
func (i *LlamaCPPImporter) Match(details Details) bool {
20+
preferences, err := details.Preferences.MarshalJSON()
2121
if err != nil {
2222
return false
2323
}
@@ -31,11 +31,11 @@ func (i *LlamaCPPImporter) Match(uri string, request schema.ImportModelRequest)
3131
return true
3232
}
3333

34-
return strings.HasSuffix(uri, ".gguf")
34+
return strings.HasSuffix(details.URI, ".gguf")
3535
}
3636

37-
func (i *LlamaCPPImporter) Import(uri string, request schema.ImportModelRequest) (gallery.ModelConfig, error) {
38-
preferences, err := request.Preferences.MarshalJSON()
37+
func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {
38+
preferences, err := details.Preferences.MarshalJSON()
3939
if err != nil {
4040
return gallery.ModelConfig{}, err
4141
}
@@ -47,12 +47,12 @@ func (i *LlamaCPPImporter) Import(uri string, request schema.ImportModelRequest)
4747

4848
name, ok := preferencesMap["name"].(string)
4949
if !ok {
50-
name = filepath.Base(uri)
50+
name = filepath.Base(details.URI)
5151
}
5252

5353
description, ok := preferencesMap["description"].(string)
5454
if !ok {
55-
description = "Imported from " + uri
55+
description = "Imported from " + details.URI
5656
}
5757

5858
modelConfig := config.ModelConfig{
@@ -62,7 +62,7 @@ func (i *LlamaCPPImporter) Import(uri string, request schema.ImportModelRequest)
6262
Backend: "llama-cpp",
6363
PredictionOptions: schema.PredictionOptions{
6464
BasicModelRequest: schema.BasicModelRequest{
65-
Model: filepath.Base(uri),
65+
Model: filepath.Base(details.URI),
6666
},
6767
},
6868
TemplateConfig: config.TemplateConfig{
@@ -86,8 +86,8 @@ func (i *LlamaCPPImporter) Import(uri string, request schema.ImportModelRequest)
8686
ConfigFile: string(data),
8787
Files: []gallery.File{
8888
{
89-
URI: uri,
90-
Filename: filepath.Base(uri),
89+
URI: details.URI,
90+
Filename: filepath.Base(details.URI),
9191
},
9292
},
9393
}, nil

core/gallery/importers/mlx.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ var _ Importer = &MLXImporter{}
1515

1616
type MLXImporter struct{}
1717

18-
func (i *MLXImporter) Match(uri string, request schema.ImportModelRequest) bool {
19-
preferences, err := request.Preferences.MarshalJSON()
18+
func (i *MLXImporter) Match(details Details) bool {
19+
preferences, err := details.Preferences.MarshalJSON()
2020
if err != nil {
2121
return false
2222
}
@@ -32,15 +32,15 @@ func (i *MLXImporter) Match(uri string, request schema.ImportModelRequest) bool
3232
}
3333

3434
// All https://huggingface.co/mlx-community/*
35-
if strings.Contains(uri, "mlx-community/") {
35+
if strings.Contains(details.URI, "mlx-community/") {
3636
return true
3737
}
3838

3939
return false
4040
}
4141

42-
func (i *MLXImporter) Import(uri string, request schema.ImportModelRequest) (gallery.ModelConfig, error) {
43-
preferences, err := request.Preferences.MarshalJSON()
42+
func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
43+
preferences, err := details.Preferences.MarshalJSON()
4444
if err != nil {
4545
return gallery.ModelConfig{}, err
4646
}
@@ -52,12 +52,12 @@ func (i *MLXImporter) Import(uri string, request schema.ImportModelRequest) (gal
5252

5353
name, ok := preferencesMap["name"].(string)
5454
if !ok {
55-
name = filepath.Base(uri)
55+
name = filepath.Base(details.URI)
5656
}
5757

5858
description, ok := preferencesMap["description"].(string)
5959
if !ok {
60-
description = "Imported from " + uri
60+
description = "Imported from " + details.URI
6161
}
6262

6363
backend := "mlx"
@@ -73,7 +73,7 @@ func (i *MLXImporter) Import(uri string, request schema.ImportModelRequest) (gal
7373
Backend: backend,
7474
PredictionOptions: schema.PredictionOptions{
7575
BasicModelRequest: schema.BasicModelRequest{
76-
Model: uri,
76+
Model: details.URI,
7777
},
7878
},
7979
TemplateConfig: config.TemplateConfig{

core/http/endpoints/localai/import_model.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,9 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
3030
return err
3131
}
3232

33-
var err error
34-
var modelConfig gallery.ModelConfig
35-
36-
for _, importer := range importers.DefaultImporters {
37-
if importer.Match(input.URI, *input) {
38-
modelConfig, err = importer.Import(input.URI, *input)
39-
if err != nil {
40-
continue
41-
}
42-
break
43-
}
33+
modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences)
34+
if err != nil {
35+
return fmt.Errorf("failed to discover model config: %w", err)
4436
}
4537

4638
uuid, err := uuid.NewUUID()

core/startup/model_preload.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package startup
22

33
import (
4+
"encoding/json"
45
"errors"
56
"fmt"
67
"os"
@@ -13,7 +14,6 @@ import (
1314
"github.com/mudler/LocalAI/core/config"
1415
"github.com/mudler/LocalAI/core/gallery"
1516
"github.com/mudler/LocalAI/core/gallery/importers"
16-
"github.com/mudler/LocalAI/core/schema"
1717
"github.com/mudler/LocalAI/core/services"
1818
"github.com/mudler/LocalAI/pkg/downloader"
1919
"github.com/mudler/LocalAI/pkg/model"
@@ -165,22 +165,17 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
165165
continue
166166
}
167167

168-
// TODO: start autoimporter
169-
var err error
170-
var modelConfig gallery.ModelConfig
171-
for _, importer := range importers.DefaultImporters {
172-
if importer.Match(url, schema.ImportModelRequest{}) {
173-
modelConfig, err = importer.Import(url, schema.ImportModelRequest{})
174-
if err != nil {
175-
continue
176-
}
177-
break
178-
}
168+
// TODO: we should just use the discoverModelConfig here and default to this.
169+
modelConfig, discoverErr := importers.DiscoverModelConfig(url, json.RawMessage{})
170+
if discoverErr != nil {
171+
err = errors.Join(discoverErr, fmt.Errorf("failed to discover model config: %w", err))
172+
continue
179173
}
180174

181-
uuid, err := uuid.NewUUID()
182-
if err != nil {
183-
return err
175+
uuid, uuidErr := uuid.NewUUID()
176+
if uuidErr != nil {
177+
err = errors.Join(uuidErr, fmt.Errorf("failed to generate UUID: %w", uuidErr))
178+
continue
184179
}
185180

186181
galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{

0 commit comments

Comments
 (0)