Skip to content

Commit 22e4561

Browse files
committed
feat: support hugging face inference router
1 parent 03485fd commit 22e4561

File tree

4 files changed

+650
-0
lines changed

4 files changed

+650
-0
lines changed

cmd/huggingface/main.go

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Package main provides a command-line tool to fetch models from Hugging Face Router
2+
// and generate a configuration file for the provider.
3+
package main
4+
5+
import (
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"log"
11+
"net/http"
12+
"os"
13+
"slices"
14+
"time"
15+
16+
"github.com/charmbracelet/catwalk/pkg/catwalk"
17+
)
18+
19+
// SupportedProviders lists the Hugging Face Router provider IDs whose models
// are included in the generated configuration.
//
// Add or remove entries to control which providers are emitted. Entries that
// are commented out are disabled; where a reason is known it is noted inline.
var SupportedProviders = []string{
	// "together", // Multiple issues
	"fireworks-ai",
	// "nebius", // Tool Call ID not unique
	// "novita", // Usage report is wrong
	"groq",
	"cerebras",
	// "hyperbolic",
	// "nscale",
	// "sambanova",
	// "cohere",
	"hf-inference",
}
34+
35+
// HFModel represents a model from the Hugging Face Router API.
36+
type HFModel struct {
37+
ID string `json:"id"`
38+
Object string `json:"object"`
39+
Created int64 `json:"created"`
40+
OwnedBy string `json:"owned_by"`
41+
Providers []HFProvider `json:"providers"`
42+
}
43+
44+
// HFProvider represents a provider configuration for a model.
45+
type HFProvider struct {
46+
Provider string `json:"provider"`
47+
Status string `json:"status"`
48+
ContextLength int64 `json:"context_length,omitempty"`
49+
Pricing *HFPricing `json:"pricing,omitempty"`
50+
SupportsTools bool `json:"supports_tools"`
51+
SupportsStructuredOutput bool `json:"supports_structured_output"`
52+
}
53+
54+
// HFPricing holds the input and output prices a provider reports for a model.
//
// NOTE(review): the generator copies these values unchanged into the
// per-1M-token cost fields, so they are assumed to already be expressed per
// one million tokens — confirm against the router API documentation.
type HFPricing struct {
	Input  float64 `json:"input"`
	Output float64 `json:"output"`
}
59+
60+
// HFModelsResponse is the response structure for the Hugging Face Router models API.
61+
type HFModelsResponse struct {
62+
Object string `json:"object"`
63+
Data []HFModel `json:"data"`
64+
}
65+
66+
func fetchHuggingFaceModels() (*HFModelsResponse, error) {
67+
client := &http.Client{Timeout: 30 * time.Second}
68+
req, _ := http.NewRequestWithContext(
69+
context.Background(),
70+
"GET",
71+
"https://router.huggingface.co/v1/models",
72+
nil,
73+
)
74+
req.Header.Set("User-Agent", "Crush-Client/1.0")
75+
resp, err := client.Do(req)
76+
if err != nil {
77+
return nil, err //nolint:wrapcheck
78+
}
79+
defer resp.Body.Close() //nolint:errcheck
80+
if resp.StatusCode != 200 {
81+
body, _ := io.ReadAll(resp.Body)
82+
return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
83+
}
84+
var mr HFModelsResponse
85+
if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
86+
return nil, err //nolint:wrapcheck
87+
}
88+
return &mr, nil
89+
}
90+
91+
// findContextWindow looks for a context window from any provider for the given model.
92+
func findContextWindow(model HFModel) int64 {
93+
for _, provider := range model.Providers {
94+
if provider.ContextLength > 0 {
95+
return provider.ContextLength
96+
}
97+
}
98+
return 0
99+
}
100+
101+
func main() {
102+
modelsResp, err := fetchHuggingFaceModels()
103+
if err != nil {
104+
log.Fatal("Error fetching Hugging Face models:", err)
105+
}
106+
107+
hfProvider := catwalk.Provider{
108+
Name: "Hugging Face",
109+
ID: catwalk.InferenceProviderHuggingFace,
110+
APIKey: "$HUGGINGFACE_API_KEY",
111+
APIEndpoint: "https://router.huggingface.co/v1",
112+
Type: catwalk.TypeOpenAI,
113+
DefaultLargeModelID: "moonshotai/Kimi-K2-Instruct-0905:groq",
114+
DefaultSmallModelID: "openai/gpt-oss-20b",
115+
Models: []catwalk.Model{},
116+
DefaultHeaders: map[string]string{
117+
"HTTP-Referer": "https://charm.land",
118+
"X-Title": "Crush",
119+
},
120+
}
121+
122+
for _, model := range modelsResp.Data {
123+
// Find context window from any provider for this model
124+
fallbackContextLength := findContextWindow(model)
125+
if fallbackContextLength == 0 {
126+
fmt.Printf("Skipping model %s - no context window found in any provider\n", model.ID)
127+
continue
128+
}
129+
130+
for _, provider := range model.Providers {
131+
// Skip unsupported providers
132+
if !slices.Contains(SupportedProviders, provider.Provider) {
133+
continue
134+
}
135+
136+
// Skip providers that don't support tools
137+
if !provider.SupportsTools {
138+
continue
139+
}
140+
141+
// Skip non-live providers
142+
if provider.Status != "live" {
143+
continue
144+
}
145+
146+
// Create model with provider-specific ID and name
147+
modelID := fmt.Sprintf("%s:%s", model.ID, provider.Provider)
148+
modelName := fmt.Sprintf("%s (%s)", model.ID, provider.Provider)
149+
150+
// Use provider's context length, or fallback if not available
151+
contextLength := provider.ContextLength
152+
if contextLength == 0 {
153+
contextLength = fallbackContextLength
154+
}
155+
156+
// Calculate pricing (convert from per-token to per-1M tokens)
157+
var costPer1MIn, costPer1MOut float64
158+
if provider.Pricing != nil {
159+
costPer1MIn = provider.Pricing.Input
160+
costPer1MOut = provider.Pricing.Output
161+
}
162+
163+
// Set default max tokens (conservative estimate)
164+
defaultMaxTokens := min(contextLength/4, 10000)
165+
166+
m := catwalk.Model{
167+
ID: modelID,
168+
Name: modelName,
169+
CostPer1MIn: costPer1MIn,
170+
CostPer1MOut: costPer1MOut,
171+
CostPer1MInCached: 0, // Not provided by HF Router
172+
CostPer1MOutCached: 0, // Not provided by HF Router
173+
ContextWindow: contextLength,
174+
DefaultMaxTokens: defaultMaxTokens,
175+
CanReason: false, // Not provided by HF Router
176+
SupportsImages: false, // Not provided by HF Router
177+
}
178+
179+
hfProvider.Models = append(hfProvider.Models, m)
180+
fmt.Printf("Added model %s with context window %d from provider %s\n",
181+
modelID, contextLength, provider.Provider)
182+
}
183+
}
184+
185+
// Save the JSON in internal/providers/configs/huggingface.json
186+
data, err := json.MarshalIndent(hfProvider, "", " ")
187+
if err != nil {
188+
log.Fatal("Error marshaling Hugging Face provider:", err)
189+
}
190+
191+
if err := os.WriteFile("internal/providers/configs/huggingface.json", data, 0o600); err != nil {
192+
log.Fatal("Error writing Hugging Face provider config:", err)
193+
}
194+
195+
fmt.Printf("Generated huggingface.json with %d models\n", len(hfProvider.Models))
196+
}

0 commit comments

Comments
 (0)