@@ -27,6 +27,7 @@ def __init__(
2727 options : Optional [Dict ] = None ,
2828 system_prompt : Optional [str ] = None ,
2929 system_prompt_file : Optional [str ] = None ,
30+ preload_model : Optional [bool ] = False ,
3031 api_base : Optional [str ] = None ,
3132 role : str = "user" ,
3233 headers : Optional [Mapping [str , str ]] = None ,
@@ -44,10 +45,11 @@ def __init__(
4445 """
4546 self .api_base = api_base or Settings .DEFAULT_OLLAMA_CLIENT
4647 self .headers = headers
48+ self .preload_model = preload_model
49+ self .options = options
4750 super ().__init__ (model_name , system_prompt , system_prompt_file , self .api_base )
4851 logging .info (f"Using Ollama with { model_name } model 🤖" )
4952 self .role : str = role
50- self .options = options
5153
5254 @override
5355 def load (self ) -> Client :
@@ -57,7 +59,15 @@ def load(self) -> Client:
5759 Returns:
5860 Client: An instance of the Ollama model client, configured with the necessary host and headers.
5961 """
60- return Client (host = self .api_base , headers = self .headers )
62+ ollama_client = Client (host = self .api_base , headers = self .headers )
63+
64+ if self .preload_model :
65+ ollama_client .chat (
66+ model = self .model_name ,
67+ messages = [],
68+ options = self .options ,
69+ )
70+ return ollama_client
6171
6272 @override
6373 def generate (self , input : Dict [str , Any ]) -> str :
0 commit comments