Skip to content

Commit 8efabfe

Browse files
authored
feat: Add Google Gemini provider (#79)
This commit adds a new provider for Google Gemini, allowing users to select Gemini as their LLM provider in ShellOracle. The following changes were made: - Created `src/shelloracle/providers/google.py` with the `Google` provider class. - Added `google-generativeai` as a dependency in `pyproject.toml`. - Modified `src/shelloracle/providers/__init__.py` to include the new provider. - Updated `~/.shelloracle/config.toml` to include a configuration section for the Google provider.
1 parent fb79a6d commit 8efabfe

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ dependencies = [
3232
"prompt-toolkit",
3333
"yaspin",
3434
"tomlkit",
35-
"tomli >= 1.1.0; python_version < '3.11'"
35+
"tomli >= 1.1.0; python_version < '3.11'",
36+
"google-generativeai"
3637
]
3738

3839
[project.scripts]

src/shelloracle/providers/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,16 @@ def _providers() -> dict[str, type[Provider]]:
7979
from shelloracle.providers.ollama import Ollama
8080
from shelloracle.providers.openai import OpenAI
8181
from shelloracle.providers.xai import XAI
82-
83-
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI, Deepseek.name: Deepseek}
82+
from shelloracle.providers.google import Google
83+
84+
return {
85+
Ollama.name: Ollama,
86+
OpenAI.name: OpenAI,
87+
LocalAI.name: LocalAI,
88+
XAI.name: XAI,
89+
Deepseek.name: Deepseek,
90+
Google.name: Google,
91+
}
8492

8593

8694
def get_provider(name: str) -> type[Provider]:
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from collections.abc import AsyncIterator
2+
3+
import google.generativeai as genai
4+
5+
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
6+
7+
8+
class Google(Provider):
9+
name = "Google"
10+
11+
api_key = Setting(default="")
12+
model = Setting(default="gemini-pro") # Assuming a default model name
13+
14+
def __init__(self):
15+
if not self.api_key:
16+
msg = "No API key provided"
17+
raise ProviderError(msg)
18+
genai.configure(api_key=self.api_key)
19+
self.model_instance = genai.GenerativeModel(self.model)
20+
21+
22+
async def generate(self, prompt: str) -> AsyncIterator[str]:
23+
try:
24+
response = await self.model_instance.generate_content_async(
25+
[
26+
{"role": "user", "parts": [system_prompt]},
27+
{"role": "model", "parts": ["Okay."]}, # Gemini requires a model response before user input
28+
{"role": "user", "parts": [prompt]},
29+
],
30+
stream=True
31+
)
32+
33+
async for chunk in response:
34+
yield chunk.text
35+
except Exception as e:
36+
msg = f"Something went wrong while querying Google Gemini: {e}"
37+
raise ProviderError(msg) from e

0 commit comments

Comments
 (0)