from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from ..config.settings import Settings
from .llm import LLM
class GeminiModel(LLM):
    """
    A subclass of LLM that uses Google's Gemini as the backend for text
    generation, accessed through LangChain's ``ChatGoogleGenerativeAI``.

    Attributes:
        model_name (str): The name of the Gemini model to use.
        api_base (str): Base URL for the Google API endpoint (falls back to
            ``Settings.DEFAULT_GOOGLE_CLIENT``).
        role (str): The role of the user in the conversation, typically "user".
        model (ChatGoogleGenerativeAI): The LangChain chat model bound to the
            Gemini API.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: Optional[str] = None,
        system_prompt_file: Optional[str] = None,
        api_base: Optional[str] = None,
        role: str = "user",
    ) -> None:
        """
        Initializes an instance of GeminiModel.

        Args:
            model_name (str): The name of the model to use with Gemini.
            system_prompt (Optional[str]): The system prompt to use. If not
                provided, it will be loaded from system_prompt_file or use the
                default value.
            system_prompt_file (Optional[str]): The path to the file to load
                the system prompt from. If provided, it takes precedence over
                system_prompt.
            api_base (Optional[str]): Base URL for the Google API endpoint;
                defaults to ``Settings.DEFAULT_GOOGLE_CLIENT`` when omitted.
            role (str): The role of the user in the conversation, defaults to
                "user".
        """
        # Fall back to the project-wide default endpoint when none is given.
        self.api_base = api_base or Settings.DEFAULT_GOOGLE_CLIENT
        super().__init__(model_name, system_prompt, system_prompt_file, self.api_base)
        logging.info(f"Using Gemini with {model_name} model 🤖")
        self.role: str = role

    @override
    def load(self) -> ChatGoogleGenerativeAI:
        """
        Loads the LangChain chat model for the configured Gemini model.

        NOTE(review): ``self.api_base`` is stored in ``__init__`` but no longer
        applied here — the previous google.genai client passed it as
        ``http_options.base_url``. Confirm whether ``ChatGoogleGenerativeAI``
        should also receive a custom endpoint.

        Returns:
            ChatGoogleGenerativeAI: The chat model used to talk to the Gemini API.
        """
        return ChatGoogleGenerativeAI(
            model=self.model_name,
            google_api_key=Settings.GEMINI_API_KEY,
        )

    def _build_messages(self, input: Dict[str, Any]) -> list:
        """
        Converts the input payload into a list of LangChain messages.

        The optional system prompt becomes a ``SystemMessage``, each entry of
        ``input["history"]`` becomes an ``AIMessage`` (role "assistant") or a
        ``HumanMessage`` (any other role), and ``input["question"]`` is
        appended as the final ``HumanMessage``. When ``input["images"]`` is
        present, the question is sent as multimodal content with each image
        attached as a base64 data URL.

        Args:
            input (Dict[str, Any]): Payload with optional ``history`` (list of
                ``{"role", "content"}`` dicts), ``question`` (str), and
                ``images`` (list of dicts with a ``"base64"`` key — assumed
                JPEG; TODO confirm against callers).

        Returns:
            list: The ordered LangChain message objects for the model call.
        """
        messages = []
        if self.system_prompt:
            messages.append(SystemMessage(content=self.system_prompt))
        for msg in input.get("history", []):
            if msg["role"] == "assistant":
                messages.append(AIMessage(content=msg["content"]))
            else:
                messages.append(HumanMessage(content=msg["content"]))

        question = input.get("question", "")
        if "images" in input:
            # Multimodal turn: text part first, then one image_url part per image.
            content = [{"type": "text", "text": question}]
            for image in input["images"]:
                try:
                    content.append(
                        {
                            "type": "image_url",
                            "image_url": f"data:image/jpeg;base64,{image['base64']}",
                        }
                    )
                except Exception as e:
                    # Best-effort: skip a malformed image entry instead of
                    # failing the whole request.
                    logging.error(f"Could not read image: {e}")
            messages.append(HumanMessage(content=content))
        else:
            messages.append(HumanMessage(content=question))
        return messages

    @override
    def generate(self, input: Dict[str, Any]) -> str:
        """
        Generates text using the Gemini model.

        Args:
            input (Dict[str, Any]): The input data for text generation
                (see ``_build_messages`` for the expected keys).

        Returns:
            str: The text generated by the model.
        """
        response = self.model.invoke(self._build_messages(input))
        return response.content

    @override
    def generate_streaming(self, input: Dict[str, Any], callbacks=None) -> Iterable[str]:
        """
        Streams generated text from the Gemini model chunk by chunk.

        Args:
            input (Dict[str, Any]): The input data for text generation
                (see ``_build_messages`` for the expected keys).
            callbacks: Optional LangChain callback handlers forwarded to the
                model via the run config.

        Yields:
            str: Non-empty text chunks as they are produced by the model.
        """
        config = {"callbacks": callbacks} if callbacks else {}
        for chunk in self.model.stream(self._build_messages(input), config=config):
            if chunk.content:
                yield chunk.content
0 commit comments