Skip to content

Commit 6b84d77

Browse files
authored
Fix OpenAI init model handling and defaults (#181)
1 parent feeb909 commit 6b84d77

File tree

4 files changed

+165
-19
lines changed

4 files changed

+165
-19
lines changed

cli/init.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ This command will:
4040

4141
func init() {
4242
initCmd.Flags().StringVarP(&initProvider, "provider", "p", "", "Embedding provider (ollama, lmstudio, openai, synthetic, or openrouter)")
43-
initCmd.Flags().StringVarP(&initModel, "model", "m", "", "Embedding model (for openrouter: text-embedding-3-small, text-embedding-3-large, qwen3-embedding-8b)")
43+
initCmd.Flags().StringVarP(&initModel, "model", "m", "", "Embedding model (for openai/openrouter: text-embedding-3-small, text-embedding-3-large; openrouter also supports qwen3-embedding-8b)")
4444
initCmd.Flags().StringVarP(&initBackend, "backend", "b", "", "Storage backend (gob, postgres, or qdrant)")
4545
initCmd.Flags().BoolVar(&initNonInteractive, "yes", false, "Use defaults without prompting")
4646
initCmd.Flags().BoolVar(&initInherit, "inherit", false, "Inherit configuration from main worktree (for git worktrees)")
@@ -143,8 +143,9 @@ func runInit(cmd *cobra.Command, args []string) error {
143143
cfg.Embedder.Dimensions = &dim
144144
case "3", "openai":
145145
cfg.Embedder.Provider = "openai"
146-
cfg.Embedder.Model = "text-embedding-3-small"
146+
cfg.Embedder.Model = config.DefaultOpenAIEmbeddingModel
147147
cfg.Embedder.Endpoint = "https://api.openai.com/v1"
148+
cfg.Embedder.Parallelism = config.DefaultOpenAIParallelism
148149
// OpenAI: leave Dimensions nil to use model's native dimensions
149150
case "4", "synthetic":
150151
cfg.Embedder.Provider = "synthetic"
@@ -194,16 +195,17 @@ func runInit(cmd *cobra.Command, args []string) error {
194195
dim := lmStudioEmbeddingDimensions
195196
cfg.Embedder.Dimensions = &dim
196197
case "openai":
197-
cfg.Embedder.Model = "text-embedding-3-small"
198+
cfg.Embedder.Model = resolveInitModel(initProvider, initModel)
198199
cfg.Embedder.Endpoint = "https://api.openai.com/v1"
200+
cfg.Embedder.Parallelism = config.DefaultOpenAIParallelism
199201
// OpenAI: leave Dimensions nil to use model's native dimensions
200202
case "synthetic":
201203
cfg.Embedder.Model = "hf:nomic-ai/nomic-embed-text-v1.5"
202204
cfg.Embedder.Endpoint = "https://api.synthetic.new/openai/v1"
203205
dim := 768
204206
cfg.Embedder.Dimensions = &dim
205207
case "openrouter":
206-
cfg.Embedder.Model = "openai/text-embedding-3-small"
208+
cfg.Embedder.Model = resolveInitModel(initProvider, initModel)
207209
cfg.Embedder.Endpoint = "https://openrouter.ai/api/v1"
208210
// OpenRouter: leave Dimensions nil to use model's native dimensions
209211
}
@@ -280,9 +282,10 @@ func runInit(cmd *cobra.Command, args []string) error {
280282
dim := lmStudioEmbeddingDimensions
281283
cfg.Embedder.Dimensions = &dim
282284
case "openai":
283-
cfg.Embedder.Model = "text-embedding-3-small"
285+
cfg.Embedder.Model = resolveInitModel(initProvider, initModel)
284286
cfg.Embedder.Endpoint = "https://api.openai.com/v1"
285287
cfg.Embedder.Dimensions = nil
288+
cfg.Embedder.Parallelism = config.DefaultOpenAIParallelism
286289
case "synthetic":
287290
cfg.Embedder.Model = "hf:nomic-ai/nomic-embed-text-v1.5"
288291
cfg.Embedder.Endpoint = "https://api.synthetic.new/openai/v1"
@@ -291,15 +294,7 @@ func runInit(cmd *cobra.Command, args []string) error {
291294
case "openrouter":
292295
cfg.Embedder.Endpoint = "https://openrouter.ai/api/v1"
293296
cfg.Embedder.Dimensions = nil
294-
// Use provided model flag or default
295-
switch initModel {
296-
case "text-embedding-3-large":
297-
cfg.Embedder.Model = "openai/text-embedding-3-large"
298-
case "qwen3-embedding-8b":
299-
cfg.Embedder.Model = "qwen/qwen3-embedding-8b"
300-
default:
301-
cfg.Embedder.Model = "openai/text-embedding-3-small"
302-
}
297+
cfg.Embedder.Model = resolveInitModel(initProvider, initModel)
303298
}
304299
}
305300
if initBackend != "" {
@@ -353,3 +348,27 @@ func runInit(cmd *cobra.Command, args []string) error {
353348
func shouldPromptInheritChoice(shouldInherit, nonInteractive, uiMode bool) bool {
354349
return !shouldInherit && !nonInteractive && !uiMode
355350
}
351+
352+
func resolveInitModel(provider, requestedModel string) string {
353+
requestedModel = strings.TrimSpace(requestedModel)
354+
switch provider {
355+
case "openai":
356+
if requestedModel != "" {
357+
return requestedModel
358+
}
359+
return config.DefaultOpenAIEmbeddingModel
360+
case "openrouter":
361+
switch requestedModel {
362+
case "text-embedding-3-large":
363+
return config.OpenRouterEmbeddingModelLarge
364+
case "qwen3-embedding-8b":
365+
return config.OpenRouterEmbeddingModelQwen8B
366+
case "text-embedding-3-small", "":
367+
return config.DefaultOpenRouterEmbeddingModel
368+
default:
369+
return requestedModel
370+
}
371+
default:
372+
return requestedModel
373+
}
374+
}

cli/init_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package cli
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/yoanbernabeu/grepai/config"
8+
)
9+
10+
func withInitTestState(t *testing.T, dir string, configure func()) {
11+
t.Helper()
12+
13+
prevCwd, err := os.Getwd()
14+
if err != nil {
15+
t.Fatalf("getwd: %v", err)
16+
}
17+
if err := os.Chdir(dir); err != nil {
18+
t.Fatalf("chdir(%s): %v", dir, err)
19+
}
20+
21+
prevProvider := initProvider
22+
prevModel := initModel
23+
prevBackend := initBackend
24+
prevNonInteractive := initNonInteractive
25+
prevInherit := initInherit
26+
prevUI := initUI
27+
28+
initProvider = ""
29+
initModel = ""
30+
initBackend = ""
31+
initNonInteractive = false
32+
initInherit = false
33+
initUI = false
34+
configure()
35+
36+
t.Cleanup(func() {
37+
_ = os.Chdir(prevCwd)
38+
initProvider = prevProvider
39+
initModel = prevModel
40+
initBackend = prevBackend
41+
initNonInteractive = prevNonInteractive
42+
initInherit = prevInherit
43+
initUI = prevUI
44+
})
45+
}
46+
47+
func TestRunInit_OpenAIExplicitModelHonored(t *testing.T) {
48+
tmpDir := t.TempDir()
49+
withInitTestState(t, tmpDir, func() {
50+
initProvider = "openai"
51+
initModel = "text-embedding-3-large"
52+
initBackend = "gob"
53+
initNonInteractive = true
54+
})
55+
56+
if err := runInit(nil, nil); err != nil {
57+
t.Fatalf("runInit: %v", err)
58+
}
59+
60+
cfg, err := config.Load(tmpDir)
61+
if err != nil {
62+
t.Fatalf("config.Load: %v", err)
63+
}
64+
if cfg.Embedder.Model != "text-embedding-3-large" {
65+
t.Fatalf("model = %q, want text-embedding-3-large", cfg.Embedder.Model)
66+
}
67+
if cfg.Embedder.Parallelism != config.DefaultOpenAIParallelism {
68+
t.Fatalf("parallelism = %d, want %d", cfg.Embedder.Parallelism, config.DefaultOpenAIParallelism)
69+
}
70+
if cfg.Embedder.GetDimensions() != config.DefaultOpenAILargeDimensions {
71+
t.Fatalf("dimensions = %d, want %d", cfg.Embedder.GetDimensions(), config.DefaultOpenAILargeDimensions)
72+
}
73+
}
74+
75+
func TestRunInit_OpenAIDefaultsToOpenAISmallModel(t *testing.T) {
76+
tmpDir := t.TempDir()
77+
withInitTestState(t, tmpDir, func() {
78+
initProvider = "openai"
79+
initBackend = "gob"
80+
initNonInteractive = true
81+
})
82+
83+
if err := runInit(nil, nil); err != nil {
84+
t.Fatalf("runInit: %v", err)
85+
}
86+
87+
cfg, err := config.Load(tmpDir)
88+
if err != nil {
89+
t.Fatalf("config.Load: %v", err)
90+
}
91+
if cfg.Embedder.Model != config.DefaultOpenAIEmbeddingModel {
92+
t.Fatalf("model = %q, want %q", cfg.Embedder.Model, config.DefaultOpenAIEmbeddingModel)
93+
}
94+
if cfg.Embedder.Parallelism != config.DefaultOpenAIParallelism {
95+
t.Fatalf("parallelism = %d, want %d", cfg.Embedder.Parallelism, config.DefaultOpenAIParallelism)
96+
}
97+
}

config/config.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ const (
2424
DefaultOpenAIEmbeddingModel = "text-embedding-3-small"
2525
DefaultSyntheticEmbeddingModel = "hf:nomic-ai/nomic-embed-text-v1.5"
2626
DefaultOpenRouterEmbeddingModel = "openai/text-embedding-3-small"
27+
OpenAIEmbeddingModelLarge = "text-embedding-3-large"
28+
OpenRouterEmbeddingModelLarge = "openai/text-embedding-3-large"
29+
OpenRouterEmbeddingModelQwen8B = "qwen/qwen3-embedding-8b"
2730

2831
DefaultOllamaEndpoint = "http://localhost:11434"
2932
DefaultLMStudioEndpoint = "http://127.0.0.1:1234"
@@ -33,6 +36,9 @@ const (
3336

3437
DefaultLocalEmbeddingDimensions = 768
3538
DefaultOpenAIDimensions = 1536
39+
DefaultOpenAILargeDimensions = 3072
40+
DefaultQwen8BDimensions = 4096
41+
DefaultOpenAIParallelism = 4
3642

3743
DefaultPostgresDSN = "postgres://localhost:5432/grepai"
3844
DefaultQdrantEndpoint = "localhost"
@@ -110,7 +116,14 @@ func (e *EmbedderConfig) GetDimensions() int {
110116
}
111117
switch e.Provider {
112118
case "openai", "openrouter":
113-
return DefaultOpenAIDimensions
119+
switch strings.TrimSpace(e.Model) {
120+
case OpenAIEmbeddingModelLarge, OpenRouterEmbeddingModelLarge:
121+
return DefaultOpenAILargeDimensions
122+
case OpenRouterEmbeddingModelQwen8B, "qwen3-embedding-8b":
123+
return DefaultQwen8BDimensions
124+
default:
125+
return DefaultOpenAIDimensions
126+
}
114127
default:
115128
return DefaultLocalEmbeddingDimensions
116129
}
@@ -143,10 +156,11 @@ func DefaultEmbedderForProvider(provider string) EmbedderConfig {
143156
}
144157
case "openai":
145158
return EmbedderConfig{
146-
Provider: "openai",
147-
Model: DefaultOpenAIEmbeddingModel,
148-
Endpoint: DefaultOpenAIEndpoint,
149-
Dimensions: nil,
159+
Provider: "openai",
160+
Model: DefaultOpenAIEmbeddingModel,
161+
Endpoint: DefaultOpenAIEndpoint,
162+
Dimensions: nil,
163+
Parallelism: DefaultOpenAIParallelism,
150164
}
151165
case "ollama":
152166
fallthrough

config/config_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ func TestDefaultEmbedderForProvider(t *testing.T) {
8484
if openai.Dimensions != nil {
8585
t.Fatalf("openai dimensions should be nil, got %v", openai.Dimensions)
8686
}
87+
if openai.Parallelism != DefaultOpenAIParallelism {
88+
t.Fatalf("openai parallelism = %d, want %d", openai.Parallelism, DefaultOpenAIParallelism)
89+
}
8790
}
8891

8992
func TestDefaultStoreForBackend(t *testing.T) {
@@ -477,6 +480,19 @@ store:
477480
expectedNil: true,
478481
expectedDimensions: 1536, // GetDimensions() returns default
479482
},
483+
{
484+
name: "openai large without dimensions infers large dimensions",
485+
configYAML: `version: 1
486+
embedder:
487+
provider: openai
488+
model: text-embedding-3-large
489+
api_key: sk-test
490+
store:
491+
backend: gob
492+
`,
493+
expectedNil: true,
494+
expectedDimensions: 3072,
495+
},
480496
{
481497
name: "openai with explicit dimensions sets pointer",
482498
configYAML: `version: 1

0 commit comments

Comments
 (0)