Skip to content

Commit 769ec59

Browse files
tidely, as-cii, nathansobo
authored
ollama: Add tool call support (#29563)
The goal of this PR is to support tool calls using ollama. A lot of the serialization work was done in #15803 however the abstraction over language models always disables tools. ## Changelog: - Use `serde_json::Value` inside `OllamaFunctionCall` just as it's used in `OllamaFunctionCall`. This fixes deserialization of ollama tool calls. - Added deserialization tests using json from official ollama api docs. - Fetch model capabilities during model enumeration from ollama provider - Added `supports_tools` setting to manually configure if a model supports tools ## TODO: - [x] Fix tool call serialization/deserialization - [x] Fetch model capabilities from ollama api - [x] Add tests for parsing model capabilities - [ ] Documentation for `supports_tools` field for ollama language model config - [ ] Convert between generic language model types - [x] Pass tools to ollama Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <[email protected]> Co-authored-by: Nathan Sobo <[email protected]>
1 parent e961625 commit 769ec59

File tree

3 files changed

+360
-88
lines changed

3 files changed

+360
-88
lines changed

crates/assistant_settings/src/assistant_settings.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,12 @@ impl AssistantSettingsContent {
315315
_ => None,
316316
};
317317
settings.provider = Some(AssistantProviderContentV1::Ollama {
318-
default_model: Some(ollama::Model::new(&model, None, None)),
318+
default_model: Some(ollama::Model::new(
319+
&model,
320+
None,
321+
None,
322+
language_model.supports_tools(),
323+
)),
319324
api_url,
320325
});
321326
}

crates/language_models/src/provider/ollama.rs

Lines changed: 133 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
use anyhow::{Result, anyhow};
22
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3+
use futures::{Stream, TryFutureExt, stream};
34
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
45
use http_client::HttpClient;
56
use language_model::{
67
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
8+
LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
79
};
810
use language_model::{
911
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
1012
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
1113
LanguageModelRequest, RateLimiter, Role,
1214
};
1315
use ollama::{
14-
ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
15-
stream_chat_completion,
16+
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
17+
OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
1618
};
1719
use schemars::JsonSchema;
1820
use serde::{Deserialize, Serialize};
1921
use settings::{Settings, SettingsStore};
22+
use std::pin::Pin;
23+
use std::sync::atomic::{AtomicU64, Ordering};
2024
use std::{collections::BTreeMap, sync::Arc};
2125
use ui::{ButtonLike, Indicator, List, prelude::*};
2226
use util::ResultExt;
@@ -47,6 +51,8 @@ pub struct AvailableModel {
4751
pub max_tokens: usize,
4852
/// The number of seconds to keep the connection open after the last request
4953
pub keep_alive: Option<KeepAlive>,
54+
/// Whether the model supports tools
55+
pub supports_tools: bool,
5056
}
5157

5258
pub struct OllamaLanguageModelProvider {
@@ -68,26 +74,44 @@ impl State {
6874

6975
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
7076
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
71-
let http_client = self.http_client.clone();
77+
let http_client = Arc::clone(&self.http_client);
7278
let api_url = settings.api_url.clone();
7379

7480
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
7581
cx.spawn(async move |this, cx| {
7682
let models = get_models(http_client.as_ref(), &api_url, None).await?;
7783

78-
let mut models: Vec<ollama::Model> = models
84+
let tasks = models
7985
.into_iter()
8086
// Since there is no metadata from the Ollama API
8187
// indicating which models are embedding models,
8288
// simply filter out models with "-embed" in their name
8389
.filter(|model| !model.name.contains("-embed"))
84-
.map(|model| ollama::Model::new(&model.name, None, None))
85-
.collect();
90+
.map(|model| {
91+
let http_client = Arc::clone(&http_client);
92+
let api_url = api_url.clone();
93+
async move {
94+
let name = model.name.as_str();
95+
let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
96+
let ollama_model =
97+
ollama::Model::new(name, None, None, capabilities.supports_tools());
98+
Ok(ollama_model)
99+
}
100+
});
101+
102+
// Rate-limit capability fetches
103+
// since there is an arbitrary number of models available
104+
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
105+
.buffer_unordered(5)
106+
.collect::<Vec<Result<_>>>()
107+
.await
108+
.into_iter()
109+
.collect::<Result<Vec<_>>>()?;
86110

87-
models.sort_by(|a, b| a.name.cmp(&b.name));
111+
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
88112

89113
this.update(cx, |this, cx| {
90-
this.available_models = models;
114+
this.available_models = ollama_models;
91115
cx.notify();
92116
})
93117
})
@@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
189213
display_name: model.display_name.clone(),
190214
max_tokens: model.max_tokens,
191215
keep_alive: model.keep_alive.clone(),
216+
supports_tools: model.supports_tools,
192217
},
193218
);
194219
}
@@ -269,7 +294,7 @@ impl OllamaLanguageModel {
269294
temperature: request.temperature.or(Some(1.0)),
270295
..Default::default()
271296
}),
272-
tools: vec![],
297+
tools: request.tools.into_iter().map(tool_into_ollama).collect(),
273298
}
274299
}
275300
}
@@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel {
292317
}
293318

294319
fn supports_tools(&self) -> bool {
295-
false
320+
self.model.supports_tools
296321
}
297322

298323
fn telemetry_id(&self) -> String {
@@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel {
341366
};
342367

343368
let future = self.request_limiter.stream(async move {
344-
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
345-
let stream = response
346-
.filter_map(|response| async move {
347-
match response {
348-
Ok(delta) => {
349-
let content = match delta.message {
350-
ChatMessage::User { content } => content,
351-
ChatMessage::Assistant { content, .. } => content,
352-
ChatMessage::System { content } => content,
353-
};
354-
Some(Ok(content))
355-
}
356-
Err(error) => Some(Err(error)),
357-
}
358-
})
359-
.boxed();
369+
let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
370+
let stream = map_to_language_model_completion_events(stream);
360371
Ok(stream)
361372
});
362373

363-
async move {
364-
Ok(future
365-
.await?
366-
.map(|result| {
367-
result
368-
.map(LanguageModelCompletionEvent::Text)
369-
.map_err(LanguageModelCompletionError::Other)
370-
})
371-
.boxed())
372-
}
373-
.boxed()
374+
future.map_ok(|f| f.boxed()).boxed()
374375
}
375376
}
376377

378+
fn map_to_language_model_completion_events(
379+
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
380+
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
381+
// Used for creating unique tool use ids
382+
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
383+
384+
struct State {
385+
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
386+
used_tools: bool,
387+
}
388+
389+
// We need to create a ToolUse and Stop event from a single
390+
// response from the original stream
391+
let stream = stream::unfold(
392+
State {
393+
stream,
394+
used_tools: false,
395+
},
396+
async move |mut state| {
397+
let response = state.stream.next().await?;
398+
399+
let delta = match response {
400+
Ok(delta) => delta,
401+
Err(e) => {
402+
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
403+
return Some((vec![event], state));
404+
}
405+
};
406+
407+
let mut events = Vec::new();
408+
409+
match delta.message {
410+
ChatMessage::User { content } => {
411+
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
412+
}
413+
ChatMessage::System { content } => {
414+
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
415+
}
416+
ChatMessage::Assistant {
417+
content,
418+
tool_calls,
419+
} => {
420+
// Check for tool calls
421+
if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
422+
match tool_call {
423+
OllamaToolCall::Function(function) => {
424+
let tool_id = format!(
425+
"{}-{}",
426+
&function.name,
427+
TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
428+
);
429+
let event =
430+
LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
431+
id: LanguageModelToolUseId::from(tool_id),
432+
name: Arc::from(function.name),
433+
raw_input: function.arguments.to_string(),
434+
input: function.arguments,
435+
is_input_complete: true,
436+
});
437+
events.push(Ok(event));
438+
state.used_tools = true;
439+
}
440+
}
441+
} else {
442+
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
443+
}
444+
}
445+
};
446+
447+
if delta.done {
448+
if state.used_tools {
449+
state.used_tools = false;
450+
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
451+
} else {
452+
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
453+
}
454+
}
455+
456+
Some((events, state))
457+
},
458+
);
459+
460+
stream.flat_map(futures::stream::iter)
461+
}
462+
377463
struct ConfigurationView {
378464
state: gpui::Entity<State>,
379465
loading_models_task: Option<Task<()>>,
@@ -509,3 +595,13 @@ impl Render for ConfigurationView {
509595
}
510596
}
511597
}
598+
599+
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
600+
ollama::OllamaTool::Function {
601+
function: OllamaFunctionTool {
602+
name: tool.name,
603+
description: Some(tool.description),
604+
parameters: Some(tool.input_schema),
605+
},
606+
}
607+
}

0 commit comments

Comments
 (0)