 use anyhow::{Result, anyhow};
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
+use futures::{Stream, TryFutureExt, stream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
+    LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, RateLimiter, Role,
 };
 use ollama::{
-    ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
-    stream_chat_completion,
+    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
+    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::{collections::BTreeMap, sync::Arc};
 use ui::{ButtonLike, Indicator, List, prelude::*};
 use util::ResultExt;
@@ -47,6 +51,8 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     /// The number of seconds to keep the connection open after the last request
     pub keep_alive: Option<KeepAlive>,
+    /// Whether the model supports tools
+    pub supports_tools: bool,
 }
 
 pub struct OllamaLanguageModelProvider {
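
Since AvailableModel is deserialized from user settings, the new supports_tools flag can be enabled per configured model. A hypothetical settings entry (field names come from the struct above; the language_models.ollama path and the model name are assumptions based on Zed's usual settings layout):

{
  "language_models": {
    "ollama": {
      "available_models": [
        {
          "name": "llama3.1:8b",
          "max_tokens": 32768,
          "supports_tools": true
        }
      ]
    }
  }
}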
@@ -68,26 +74,44 @@ impl State {
 
     fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-        let http_client = self.http_client.clone();
+        let http_client = Arc::clone(&self.http_client);
         let api_url = settings.api_url.clone();
 
         // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
         cx.spawn(async move |this, cx| {
             let models = get_models(http_client.as_ref(), &api_url, None).await?;
 
-            let mut models: Vec<ollama::Model> = models
+            let tasks = models
                 .into_iter()
                 // Since there is no metadata from the Ollama API
                 // indicating which models are embedding models,
                 // simply filter out models with "-embed" in their name
                 .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name, None, None))
-                .collect();
+                .map(|model| {
+                    let http_client = Arc::clone(&http_client);
+                    let api_url = api_url.clone();
+                    async move {
+                        let name = model.name.as_str();
+                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
+                        let ollama_model =
+                            ollama::Model::new(name, None, None, capabilities.supports_tools());
+                        Ok(ollama_model)
+                    }
+                });
+
+            // Rate-limit capability fetches,
+            // since there is an arbitrary number of models available
+            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
+                .buffer_unordered(5)
+                .collect::<Vec<Result<_>>>()
+                .await
+                .into_iter()
+                .collect::<Result<Vec<_>>>()?;
 
-            models.sort_by(|a, b| a.name.cmp(&b.name));
+            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
 
             this.update(cx, |this, cx| {
-                this.available_models = models;
+                this.available_models = ollama_models;
                 cx.notify();
             })
         })
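
The buffer_unordered(5) call above is the stock futures-crate pattern for bounding concurrency across a batch of async requests. A minimal self-contained sketch of the same pattern, with a hypothetical fetch_capabilities standing in for the per-model show_model call:

use futures::{StreamExt, stream};

// Hypothetical stand-in for the per-model show_model request.
async fn fetch_capabilities(name: &str) -> anyhow::Result<String> {
    Ok(format!("capabilities for {name}"))
}

async fn fetch_all(names: Vec<String>) -> anyhow::Result<Vec<String>> {
    stream::iter(
        names
            .into_iter()
            .map(|name| async move { fetch_capabilities(&name).await }),
    )
    // At most five requests are in flight at once; results arrive in
    // completion order, which is why the caller sorts afterwards.
    .buffer_unordered(5)
    .collect::<Vec<_>>()
    .await
    .into_iter()
    .collect()
}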
@@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     keep_alive: model.keep_alive.clone(),
+                    supports_tools: model.supports_tools,
                 },
             );
         }
@@ -269,7 +294,7 @@ impl OllamaLanguageModel {
                 temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
-            tools: vec![],
+            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
         }
     }
 }
@@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        self.model.supports_tools
     }
 
     fn telemetry_id(&self) -> String {
@@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content, .. } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
+            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
+            let stream = map_to_language_model_completion_events(stream);
             Ok(stream)
         });
 
-        async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(LanguageModelCompletionEvent::Text)
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
-        }
-        .boxed()
+        future.map_ok(|f| f.boxed()).boxed()
     }
 }
 
+fn map_to_language_model_completion_events(
+    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+    // Used for creating unique tool use ids
+    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+
+    struct State {
+        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
+        used_tools: bool,
+    }
+
+    // We need to create a ToolUse and Stop event from a single
+    // response from the original stream
+    let stream = stream::unfold(
+        State {
+            stream,
+            used_tools: false,
+        },
+        async move |mut state| {
+            let response = state.stream.next().await?;
+
+            let delta = match response {
+                Ok(delta) => delta,
+                Err(e) => {
+                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+                    return Some((vec![event], state));
+                }
+            };
+
+            let mut events = Vec::new();
+
+            match delta.message {
+                ChatMessage::User { content } => {
+                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                }
+                ChatMessage::System { content } => {
+                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                }
+                ChatMessage::Assistant {
+                    content,
+                    tool_calls,
+                } => {
+                    // Check for tool calls
+                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
+                        match tool_call {
+                            OllamaToolCall::Function(function) => {
+                                let tool_id = format!(
+                                    "{}-{}",
+                                    &function.name,
+                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
+                                );
+                                let event =
+                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                                        id: LanguageModelToolUseId::from(tool_id),
+                                        name: Arc::from(function.name),
+                                        raw_input: function.arguments.to_string(),
+                                        input: function.arguments,
+                                        is_input_complete: true,
+                                    });
+                                events.push(Ok(event));
+                                state.used_tools = true;
+                            }
+                        }
+                    } else {
+                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                    }
+                }
+            };
+
+            if delta.done {
+                if state.used_tools {
+                    state.used_tools = false;
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                } else {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+            }
+
+            Some((events, state))
+        },
+    );
+
+    stream.flat_map(futures::stream::iter)
+}
+
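
The shape of this function follows a one-to-many streaming pattern: stream::unfold pulls one delta at a time from the inner stream, each step may yield a batch of events (for example a ToolUse immediately followed by a Stop), and flat_map(stream::iter) flattens the batches back into a single event stream. A minimal sketch of that pattern, assuming plain integers in place of ChatResponseDelta:

use futures::{Stream, StreamExt, stream};

fn expand(input: impl Stream<Item = u32> + Unpin) -> impl Stream<Item = u32> {
    stream::unfold(input, |mut input| async move {
        // `?` on None ends the output stream once the input is exhausted.
        let item = input.next().await?;
        // One upstream item becomes a batch of downstream events.
        Some((vec![item, item * 10], input))
    })
    .flat_map(stream::iter)
}

// expand(stream::iter(vec![1, 2])) yields 1, 10, 2, 20.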
 struct ConfigurationView {
     state: gpui::Entity<State>,
     loading_models_task: Option<Task<()>>,
@@ -509,3 +595,13 @@ impl Render for ConfigurationView {
         }
     }
 }
+
+fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
+    ollama::OllamaTool::Function {
+        function: OllamaFunctionTool {
+            name: tool.name,
+            description: Some(tool.description),
+            parameters: Some(tool.input_schema),
+        },
+    }
+}
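
On the wire, this conversion maps onto the OpenAI-style tools array that Ollama's /api/chat endpoint accepts. Assuming the obvious serde representation of OllamaTool::Function, a converted tool serializes roughly like the following (the tool name and schema are illustrative):

{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {
      "type": "object",
      "properties": { "city": { "type": "string" } },
      "required": ["city"]
    }
  }
}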