Skip to content

Commit 827b32d

Browse files
authored
Merge pull request #143 from Bessouat40/feature/langfuse-fix
fix langfuse integration
2 parents 5344590 + 26fc173 commit 827b32d

16 files changed

Total lines changed across all 16 files: 304 additions & 732 deletions
Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,28 @@
11
from __future__ import annotations
2-
from typing import Optional, List, Any
2+
from typing import Optional, List
33
from typing_extensions import override
44

5-
from google.genai import Client
6-
from google.genai import types
5+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
76

87
from ..config.settings import Settings
98
from .embeddings_model import EmbeddingsModel
109

1110

1211
class GeminiEmbeddingsModel(EmbeddingsModel):
13-
"""
14-
Concrete implementation of the EmbeddingsModel for Gemini models using the official Google GenAI library.
15-
"""
16-
1712
def __init__(self, model_name: str, api_base: Optional[str] = None) -> None:
18-
"""
19-
Initializes a GeminiEmbeddingsModel instance.
20-
21-
Args:
22-
model_name (str): The name of the Gemini model to load (e.g., "models/embedding-001").
23-
api_base (Optional[str]): Not strictly used by the official lib as it relies on global config,
24-
but kept for interface consistency.
25-
"""
2613
super().__init__(model_name, api_base)
2714

2815
@override
29-
def load(self) -> Any:
30-
"""
31-
Configures the Google GenAI library.
32-
Returns the module reference as the 'client'.
33-
"""
34-
return Client(
35-
api_key=Settings.GEMINI_API_KEY,
36-
http_options=types.HttpOptions(base_url=self.api_base),
16+
def load(self) -> GoogleGenerativeAIEmbeddings:
17+
return GoogleGenerativeAIEmbeddings(
18+
model=self.model_name,
19+
google_api_key=Settings.GEMINI_API_KEY,
3720
)
3821

3922
@override
4023
def embed_documents(self, texts: List[str]) -> List[List[float]]:
41-
"""
42-
Embed list of documents using Google GenAI.
43-
Specifies 'retrieval_document' task type for optimized document storage embeddings.
44-
"""
45-
result = self.model.embed_content(
46-
model=self.model_name, content=texts, task_type="retrieval_document"
47-
)
48-
return result["embedding"]
24+
return self.model.embed_documents(texts)
4925

5026
@override
5127
def embed_query(self, text: str) -> List[float]:
52-
"""
53-
Embed a single query text.
54-
Specifies 'retrieval_query' task type for optimized search query embeddings.
55-
"""
56-
result = self.model.embed_content(
57-
model=self.model_name, content=text, task_type="retrieval_query"
58-
)
59-
return result["embedding"]
28+
return self.model.embed_query(text)

src/raglight/embeddings/ollama_embeddings.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22
from typing import Optional, List, Dict, Any
33
from typing_extensions import override
44

5+
from langchain_ollama import OllamaEmbeddings
6+
57
from ..config.settings import Settings
68
from .embeddings_model import EmbeddingsModel
7-
from ollama import Client
89

910

1011
class OllamaEmbeddingsModel(EmbeddingsModel):
11-
"""
12-
Concrete implementation of the EmbeddingsModel for Ollama models using the official python library.
13-
"""
14-
1512
def __init__(
1613
self,
1714
model_name: str,
@@ -20,37 +17,23 @@ def __init__(
2017
) -> None:
2118
resolved_api_base = api_base or Settings.DEFAULT_OLLAMA_CLIENT
2219
super().__init__(model_name, api_base=resolved_api_base)
23-
2420
self.options = options or {}
25-
26-
# Keep critical config to prevent internal Ollama "panic" on large docs
2721
if "num_batch" not in self.options:
2822
self.options["num_batch"] = 8192
2923
if "num_ctx" not in self.options:
3024
self.options["num_ctx"] = 8192
3125

3226
@override
33-
def load(self) -> Client:
34-
return Client(host=self.api_base)
27+
def load(self) -> OllamaEmbeddings:
28+
return OllamaEmbeddings(
29+
model=self.model_name,
30+
base_url=self.api_base,
31+
)
3532

3633
@override
3734
def embed_documents(self, texts: List[str]) -> List[List[float]]:
38-
"""
39-
Embed list of documents using the optimized batch 'embed' method.
40-
"""
41-
# OPTIMIZATION: Use 'embed' (not 'embeddings') to process the whole list at once.
42-
# This sends a single request and leverages GPU batch processing.
43-
response = self.model.embed(
44-
model=self.model_name, input=texts, options=self.options
45-
)
46-
return response["embeddings"]
35+
return self.model.embed_documents(texts)
4736

4837
@override
4938
def embed_query(self, text: str) -> List[float]:
50-
"""
51-
Embed a single query text.
52-
"""
53-
response = self.model.embeddings(
54-
model=self.model_name, prompt=text, options=self.options
55-
)
56-
return response["embedding"]
39+
return self.model.embed_query(text)

src/raglight/embeddings/openai_embeddings.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,29 @@
22
from typing import Optional, List
33
from typing_extensions import override
44

5-
from openai import OpenAI
5+
from langchain_openai import OpenAIEmbeddings
66

77
from ..config.settings import Settings
88
from .embeddings_model import EmbeddingsModel
99

1010

1111
class OpenAIEmbeddingsModel(EmbeddingsModel):
12-
"""
13-
Concrete implementation of the EmbeddingsModel for OpenAI models using the official python library.
14-
"""
15-
1612
def __init__(self, model_name: str, api_base: Optional[str] = None) -> None:
17-
"""
18-
Initializes an OpenAIEmbeddingsModel instance.
19-
20-
Args:
21-
model_name (str): The name of the OpenAI model to load.
22-
api_base (Optional[str]): The base URL for the API (optional).
23-
"""
2413
resolved_api_base = api_base or Settings.DEFAULT_OPENAI_CLIENT
2514
super().__init__(model_name, api_base=resolved_api_base)
2615

2716
@override
28-
def load(self) -> OpenAI:
29-
"""
30-
Loads the OpenAI client.
31-
32-
Returns:
33-
OpenAI: The initialized OpenAI client.
34-
"""
35-
return OpenAI(
17+
def load(self) -> OpenAIEmbeddings:
18+
return OpenAIEmbeddings(
19+
model=self.model_name,
3620
api_key=Settings.OPENAI_API_KEY,
3721
base_url=self.api_base,
3822
)
3923

4024
@override
4125
def embed_documents(self, texts: List[str]) -> List[List[float]]:
42-
"""
43-
Embed list of documents using the official OpenAI client.
44-
"""
45-
response = self.model.embeddings.create(input=texts, model=self.model_name)
46-
return [data.embedding for data in response.data]
26+
return self.model.embed_documents(texts)
4727

4828
@override
4929
def embed_query(self, text: str) -> List[float]:
50-
"""
51-
Embed a single query text.
52-
"""
53-
response = self.model.embeddings.create(input=[text], model=self.model_name)
54-
return response.data[0].embedding
30+
return self.model.embed_query(text)

src/raglight/llm/bedrock_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def generate(self, input: Dict[str, Any]) -> str:
9393
return response.content
9494

9595
@override
96-
def generate_streaming(self, input: Dict[str, Any]) -> Iterable[str]:
96+
def generate_streaming(self, input: Dict[str, Any], callbacks=None) -> Iterable[str]:
9797
history = input.get("history", [])
9898
messages = []
9999

@@ -108,6 +108,7 @@ def generate_streaming(self, input: Dict[str, Any]) -> Iterable[str]:
108108

109109
messages.append(HumanMessage(content=input.get("question", "")))
110110

111-
for chunk in self.model.stream(messages):
111+
stream_config = {"callbacks": callbacks} if callbacks else {}
112+
for chunk in self.model.stream(messages, config=stream_config):
112113
if chunk.content:
113114
yield chunk.content

src/raglight/llm/gemini_model.py

Lines changed: 37 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,12 @@
55

66
from ..config.settings import Settings
77
from .llm import LLM
8-
from google.genai import Client
9-
from google.genai import types
108

9+
from langchain_google_genai import ChatGoogleGenerativeAI
10+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
1111

12-
class GeminiModel(LLM):
13-
"""
14-
A subclass of LLM that uses Google's Gemini as the backend for text generation.
15-
16-
This class provides an interface to interact with Google's Generative AI,
17-
enabling text generation with various models supported by the Gemini API.
18-
19-
Attributes:
20-
model_name (str): The name of the model to use with Gemini.
21-
role (str): The role of the user in the conversation, typically "user".
22-
system_prompt (str): The system prompt to use for text generation.
23-
model (Client): The Gemini client configured to interact with the API.
24-
"""
2512

13+
class GeminiModel(LLM):
2614
def __init__(
2715
self,
2816
model_name: str,
@@ -31,106 +19,49 @@ def __init__(
3119
api_base: Optional[str] = None,
3220
role: str = "user",
3321
) -> None:
34-
"""
35-
Initializes an instance of GeminiModel.
36-
37-
Args:
38-
model_name (str): The name of the model to use with Gemini.
39-
system_prompt (Optional[str]): The system prompt to use. If not provided, it will be loaded from system_prompt_file or use the default value.
40-
system_prompt_file (Optional[str]): The path to the file to load the system prompt from. If provided, it takes precedence over system_prompt.
41-
role (str): The role of the user in the conversation, defaults to "user".
42-
"""
4322
self.api_base = api_base or Settings.DEFAULT_GOOGLE_CLIENT
4423
super().__init__(model_name, system_prompt, system_prompt_file, self.api_base)
4524
logging.info(f"Using Gemini with {model_name} model 🤖")
4625
self.role: str = role
4726

4827
@override
49-
def load(self) -> Client:
50-
"""
51-
Loads the Gemini client using the modern google.generativeai SDK.
52-
53-
Returns:
54-
Client: The client object to interact with Gemini API.
55-
"""
56-
return Client(
57-
api_key=Settings.GEMINI_API_KEY,
58-
http_options=types.HttpOptions(base_url=self.api_base),
59-
)
60-
61-
@override
62-
def generate(self, input: Dict[str, Any]) -> str:
63-
"""
64-
Generates text using the Gemini model.
65-
It constructs a structured 'contents' payload using the 'types' module
66-
as requested for proper input management.
67-
68-
Args:
69-
input (Dict[str, Any]): The input data for text generation.
70-
71-
Returns:
72-
str: The text generated by the model.
73-
"""
74-
history = input.get("history", [])
75-
contents = []
76-
77-
for msg in history:
78-
role = "model" if msg["role"] == "assistant" else "user"
79-
contents.append(
80-
types.Content(role=role, parts=[types.Part(text=msg["content"])])
81-
)
82-
83-
contents.append(
84-
types.Content(
85-
role="user", parts=[types.Part(text=input.get("question", ""))]
86-
)
28+
def load(self) -> ChatGoogleGenerativeAI:
29+
return ChatGoogleGenerativeAI(
30+
model=self.model_name,
31+
google_api_key=Settings.GEMINI_API_KEY,
8732
)
8833

89-
config = None
34+
def _build_messages(self, input: Dict[str, Any]):
35+
messages = []
9036
if self.system_prompt:
91-
config = types.GenerateContentConfig(system_instruction=self.system_prompt)
92-
93-
try:
94-
response = self.model.models.generate_content(
95-
model=self.model_name,
96-
contents=contents,
97-
config=config,
98-
)
99-
if not response.candidates:
100-
logging.warning("Response was blocked. Checking prompt feedback.")
101-
if response.prompt_feedback:
102-
logging.warning(f"Prompt Feedback: {response.prompt_feedback}")
103-
return "Response blocked due to safety settings."
104-
return response.text
105-
except Exception as e:
106-
logging.error(f"An error occurred during Gemini content generation: {e}")
107-
return f"Error: {e}"
37+
messages.append(SystemMessage(content=self.system_prompt))
38+
for msg in input.get("history", []):
39+
if msg["role"] == "assistant":
40+
messages.append(AIMessage(content=msg["content"]))
41+
else:
42+
messages.append(HumanMessage(content=msg["content"]))
43+
44+
question = input.get("question", "")
45+
if "images" in input:
46+
content = [{"type": "text", "text": question}]
47+
for image in input["images"]:
48+
try:
49+
content.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{image['base64']}"})
50+
except Exception as e:
51+
logging.error(f"Could not read image: {e}")
52+
messages.append(HumanMessage(content=content))
53+
else:
54+
messages.append(HumanMessage(content=question))
55+
return messages
10856

10957
@override
110-
def generate_streaming(self, input: Dict[str, Any]) -> Iterable[str]:
111-
history = input.get("history", [])
112-
contents = []
113-
114-
for msg in history:
115-
role = "model" if msg["role"] == "assistant" else "user"
116-
contents.append(
117-
types.Content(role=role, parts=[types.Part(text=msg["content"])])
118-
)
119-
120-
contents.append(
121-
types.Content(
122-
role="user", parts=[types.Part(text=input.get("question", ""))]
123-
)
124-
)
125-
126-
config = None
127-
if self.system_prompt:
128-
config = types.GenerateContentConfig(system_instruction=self.system_prompt)
58+
def generate(self, input: Dict[str, Any]) -> str:
59+
response = self.model.invoke(self._build_messages(input))
60+
return response.content
12961

130-
for chunk in self.model.models.generate_content_stream(
131-
model=self.model_name,
132-
contents=contents,
133-
config=config,
134-
):
135-
if chunk.text:
136-
yield chunk.text
62+
@override
63+
def generate_streaming(self, input: Dict[str, Any], callbacks=None) -> Iterable[str]:
64+
config = {"callbacks": callbacks} if callbacks else {}
65+
for chunk in self.model.stream(self._build_messages(input), config=config):
66+
if chunk.content:
67+
yield chunk.content

src/raglight/llm/llm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,15 @@ def generate(self, input: Dict[str, Any]) -> str:
7676
pass
7777

7878
@abstractmethod
79-
def generate_streaming(self, input: Dict[str, Any]) -> Iterable[str]:
79+
def generate_streaming(
80+
self, input: Dict[str, Any], callbacks: Optional[list] = None
81+
) -> Iterable[str]:
8082
"""
8183
Abstract method to generate text in streaming mode.
8284
8385
Args:
8486
input (Dict[str, Any]): A dictionary containing the input data for text generation.
87+
callbacks (Optional[list]): Optional list of LangChain callbacks (e.g. Langfuse).
8588
8689
Yields:
8790
str: Successive chunks of the generated output.

0 commit comments

Comments
 (0)