Skip to content

Commit 3826af3

Browse files
committed
improve ai.catalog config
1 parent 4bc2644 commit 3826af3

File tree

3 files changed

+58
-36
lines changed

3 files changed

+58
-36
lines changed

default_config.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,15 @@ enabled = false
230230
# Assign models to tasks by their alias (which must be defined in the Model Catalog below)
231231
[ai.models]
232232
# The alias of the model to use for generating command templates from natural language
233-
suggest = "gemini"
233+
suggest = "main"
234234
# The alias of the model used to fix or explain a failing command
235-
fix = "gemini"
235+
fix = "main"
236236
# The alias of the model to use when importing commands
237-
import = "gemini"
237+
import = "main"
238238
# The alias of the model to use when generating a command for a dynamic variable completion
239-
completion = "gemini"
239+
completion = "main"
240240
# The alias of a model to use as a fallback if the primary model fails due to rate limits
241-
fallback = "gemini-fallback"
241+
fallback = "fallback"
242242

243243
# --- Model Catalog ---
244244
# This is where you define the specific configuration for each AI model alias used above.
@@ -249,11 +249,11 @@ fallback = "gemini-fallback"
249249
# - "anthropic": ANTHROPIC_API_KEY
250250
# - "ollama": OLLAMA_API_KEY (often not required for local instances)
251251

252-
[ai.catalog.gemini]
252+
[ai.catalog.main]
253253
provider = "gemini"
254254
model = "gemini-2.5-flash"
255255

256-
[ai.catalog.gemini-fallback]
256+
[ai.catalog.fallback]
257257
provider = "gemini"
258258
model = "gemini-2.0-flash-lite"
259259

docs/src/configuration/ai.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,10 @@ environment variables associated with each one:
6565

6666
#### Configuration Examples
6767

68-
Here are some examples of how to configure different models in your catalog.
69-
Each model you define must be under `ai.catalog.<your-alias-name>`.
68+
Below are several examples for configuring different models within your `[ai.catalog]`.
7069

71-
> ⚠️ **IMPORTANT**
72-
>
73-
> When you add your first model to `[ai.catalog]`, it replaces the _entire_ default catalog. Therefore, you must ensure
74-
> the model aliases you create match the ones assigned in the `[ai.models]` section above.
70+
> **💡 Shortcut**: You can also just overwrite the default `main` and `fallback` aliases directly in your catalog. This
71+
> changes the model used without needing to edit `[ai.models]`.
7572
7673
- **OpenAI**
7774

src/config.rs

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ pub struct AiConfig {
299299
///
300300
/// Each entry maps a custom alias (e.g., `fast-model`, `smart-model`) to its specific provider settings. These
301301
/// aliases are then referenced by the `suggest`, `fix`, `import`, and `fallback` fields.
302+
#[serde(deserialize_with = "deserialize_catalog_with_defaults")]
302303
pub catalog: BTreeMap<String, AiModelConfig>,
303304
}
304305

@@ -842,41 +843,44 @@ impl Default for SearchVariableContextTuning {
842843
Self { points: 700 }
843844
}
844845
}
846+
fn default_ai_catalog() -> BTreeMap<String, AiModelConfig> {
847+
BTreeMap::from([
848+
(
849+
"main".to_string(),
850+
AiModelConfig::Gemini(GeminiModelConfig {
851+
model: "gemini-2.5-flash".to_string(),
852+
url: default_gemini_url(),
853+
api_key_env: default_gemini_api_key_env(),
854+
}),
855+
),
856+
(
857+
"fallback".to_string(),
858+
AiModelConfig::Gemini(GeminiModelConfig {
859+
model: "gemini-2.0-flash-lite".to_string(),
860+
url: default_gemini_url(),
861+
api_key_env: default_gemini_api_key_env(),
862+
}),
863+
),
864+
])
865+
}
845866
impl Default for AiConfig {
846867
fn default() -> Self {
847868
Self {
848869
enabled: false,
849870
models: AiModelsConfig::default(),
850871
prompts: AiPromptsConfig::default(),
851-
catalog: BTreeMap::from([
852-
(
853-
"gemini".to_string(),
854-
AiModelConfig::Gemini(GeminiModelConfig {
855-
model: "gemini-2.5-flash".to_string(),
856-
url: default_gemini_url(),
857-
api_key_env: default_gemini_api_key_env(),
858-
}),
859-
),
860-
(
861-
"gemini-fallback".to_string(),
862-
AiModelConfig::Gemini(GeminiModelConfig {
863-
model: "gemini-2.0-flash-lite".to_string(),
864-
url: default_gemini_url(),
865-
api_key_env: default_gemini_api_key_env(),
866-
}),
867-
),
868-
]),
872+
catalog: default_ai_catalog(),
869873
}
870874
}
871875
}
872876
impl Default for AiModelsConfig {
873877
fn default() -> Self {
874878
Self {
875-
suggest: "gemini".to_string(),
876-
fix: "gemini".to_string(),
877-
import: "gemini".to_string(),
878-
completion: "gemini".to_string(),
879-
fallback: "gemini-fallback".to_string(),
879+
suggest: "main".to_string(),
880+
fix: "main".to_string(),
881+
import: "main".to_string(),
882+
completion: "main".to_string(),
883+
fallback: "fallback".to_string(),
880884
}
881885
}
882886
}
@@ -1332,6 +1336,27 @@ fn parse_color_inner(raw: &str) -> Result<Color, String> {
13321336
})
13331337
}
13341338

1339+
/// Custom deserialization for the AI model catalog that merges user-defined models with default models.
1340+
///
1341+
/// User-defined models in the configuration file will override any defaults with the same name.
1342+
/// Any default models not defined by the user will be added to the final catalog.
1343+
fn deserialize_catalog_with_defaults<'de, D>(deserializer: D) -> Result<BTreeMap<String, AiModelConfig>, D::Error>
1344+
where
1345+
D: Deserializer<'de>,
1346+
{
1347+
#[allow(unused_mut)]
1348+
// Deserialize the map as provided in the user's config
1349+
let mut user_catalog = BTreeMap::<String, AiModelConfig>::deserialize(deserializer)?;
1350+
1351+
// Get the default catalog and merge it in
1352+
#[cfg(not(test))]
1353+
for (key, default_model) in default_ai_catalog() {
1354+
user_catalog.entry(key).or_insert(default_model);
1355+
}
1356+
1357+
Ok(user_catalog)
1358+
}
1359+
13351360
#[cfg(test)]
13361361
mod tests {
13371362
use pretty_assertions::assert_eq;

0 commit comments

Comments
 (0)