Skip to content

Commit 26fc173

Browse files
committed
fix tests
1 parent 6e1b6d6 commit 26fc173

5 files changed

Lines changed: 93 additions & 168 deletions

File tree

tests/tests_embeddings/test_gemini_embeddings.py

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,40 @@
11
import unittest
22
from unittest.mock import MagicMock, patch
3+
34
from raglight.embeddings.gemini_embeddings import GeminiEmbeddingsModel
45
from ..test_config import TestsConfig
5-
from google.genai import types
66

77

88
class TestGeminiEmbeddings(unittest.TestCase):
99

10-
@patch("raglight.embeddings.gemini_embeddings.Client")
11-
def test_model_load(self, MockClient: MagicMock):
12-
"""
13-
Test that the client is instantiated with the correct parameters.
14-
Note: The code uses Client(...) instead of configure().
15-
"""
10+
@patch("raglight.embeddings.gemini_embeddings.GoogleGenerativeAIEmbeddings")
11+
def test_model_load(self, MockEmbeddings: MagicMock):
1612
model = GeminiEmbeddingsModel(TestsConfig.GEMINI_EMBEDDING_MODEL)
13+
self.assertTrue(MockEmbeddings.called)
14+
self.assertIsNotNone(model.model)
1715

18-
self.assertTrue(MockClient.called)
19-
20-
self.assertIsNotNone(
21-
model.model, "Model (genai client instance) should be loaded."
22-
)
23-
24-
@patch("raglight.embeddings.gemini_embeddings.Client")
25-
def test_embed_documents(self, MockClient: MagicMock):
26-
"""Test document embedding with the correct task_type."""
27-
mock_client_instance = MockClient.return_value
28-
29-
mock_client_instance.embed_content.return_value = {
30-
"embedding": [[0.1, 0.2], [0.3, 0.4]]
31-
}
16+
@patch("raglight.embeddings.gemini_embeddings.GoogleGenerativeAIEmbeddings")
17+
def test_embed_documents(self, MockEmbeddings: MagicMock):
18+
mock_instance = MockEmbeddings.return_value
19+
mock_instance.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
3220

3321
model = GeminiEmbeddingsModel(TestsConfig.GEMINI_EMBEDDING_MODEL)
3422
texts = ["doc1", "doc2"]
35-
3623
result = model.embed_documents(texts)
3724

3825
self.assertEqual(len(result), 2)
26+
mock_instance.embed_documents.assert_called_with(texts)
3927

40-
mock_client_instance.embed_content.assert_called_with(
41-
model=TestsConfig.GEMINI_EMBEDDING_MODEL,
42-
content=texts,
43-
task_type="retrieval_document",
44-
)
45-
46-
@patch("raglight.embeddings.gemini_embeddings.Client")
47-
def test_embed_query(self, MockClient: MagicMock):
48-
"""Test query embedding with the correct task_type."""
49-
mock_client_instance = MockClient.return_value
50-
mock_client_instance.embed_content.return_value = {"embedding": [0.1, 0.2]}
28+
@patch("raglight.embeddings.gemini_embeddings.GoogleGenerativeAIEmbeddings")
29+
def test_embed_query(self, MockEmbeddings: MagicMock):
30+
mock_instance = MockEmbeddings.return_value
31+
mock_instance.embed_query.return_value = [0.1, 0.2]
5132

5233
model = GeminiEmbeddingsModel(TestsConfig.GEMINI_EMBEDDING_MODEL)
53-
text = "query"
54-
55-
result = model.embed_query(text)
34+
result = model.embed_query("query")
5635

5736
self.assertEqual(len(result), 2)
58-
59-
mock_client_instance.embed_content.assert_called_with(
60-
model=TestsConfig.GEMINI_EMBEDDING_MODEL,
61-
content=text,
62-
task_type="retrieval_query",
63-
)
37+
mock_instance.embed_query.assert_called_with("query")
6438

6539

6640
if __name__ == "__main__":

tests/tests_embeddings/test_ollama_embeddings.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,41 @@
11
import unittest
22
from unittest.mock import MagicMock, patch
3+
34
from raglight.embeddings.ollama_embeddings import OllamaEmbeddingsModel
45
from ..test_config import TestsConfig
56

67

78
class TestOllamaEmbeddings(unittest.TestCase):
89

9-
@patch("raglight.embeddings.ollama_embeddings.Client")
10-
def test_model_load(self, mock_client: MagicMock):
11-
"""Test Ollama client initialization."""
12-
mock_client.return_value = MagicMock()
10+
@patch("raglight.embeddings.ollama_embeddings.OllamaEmbeddings")
11+
def test_model_load(self, MockEmbeddings: MagicMock):
12+
MockEmbeddings.return_value = MagicMock()
1313
embeddings = OllamaEmbeddingsModel(TestsConfig.OLLAMA_EMBEDDING_MODEL)
1414
self.assertIsNotNone(embeddings.model)
15-
mock_client.assert_called_once()
15+
MockEmbeddings.assert_called_once()
1616

17-
@patch("raglight.embeddings.ollama_embeddings.Client")
18-
def test_embed_documents(self, mock_client: MagicMock):
19-
"""Test batch embedding with .embed() method."""
20-
mock_instance = mock_client.return_value
21-
mock_instance.embed.return_value = {"embeddings": [[0.1, 0.2], [0.3, 0.4]]}
17+
@patch("raglight.embeddings.ollama_embeddings.OllamaEmbeddings")
18+
def test_embed_documents(self, MockEmbeddings: MagicMock):
19+
mock_instance = MockEmbeddings.return_value
20+
mock_instance.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
2221

2322
model = OllamaEmbeddingsModel(TestsConfig.OLLAMA_EMBEDDING_MODEL)
2423
texts = ["doc1", "doc2"]
2524
result = model.embed_documents(texts)
2625

2726
self.assertEqual(len(result), 2)
28-
mock_instance.embed.assert_called_with(
29-
model=TestsConfig.OLLAMA_EMBEDDING_MODEL, input=texts, options=model.options
30-
)
27+
mock_instance.embed_documents.assert_called_with(texts)
3128

32-
@patch("raglight.embeddings.ollama_embeddings.Client")
33-
def test_embed_query(self, mock_client: MagicMock):
34-
"""Test single embedding with .embeddings() method."""
35-
mock_instance = mock_client.return_value
36-
mock_instance.embeddings.return_value = {"embedding": [0.9, 0.9]}
29+
@patch("raglight.embeddings.ollama_embeddings.OllamaEmbeddings")
30+
def test_embed_query(self, MockEmbeddings: MagicMock):
31+
mock_instance = MockEmbeddings.return_value
32+
mock_instance.embed_query.return_value = [0.9, 0.9]
3733

3834
model = OllamaEmbeddingsModel(TestsConfig.OLLAMA_EMBEDDING_MODEL)
39-
text = "query text"
40-
result = model.embed_query(text)
35+
result = model.embed_query("query text")
4136

4237
self.assertEqual(result, [0.9, 0.9])
43-
mock_instance.embeddings.assert_called_with(
44-
model=TestsConfig.OLLAMA_EMBEDDING_MODEL, prompt=text, options=model.options
45-
)
38+
mock_instance.embed_query.assert_called_with("query text")
4639

4740

4841
if __name__ == "__main__":

tests/tests_embeddings/test_openai_embeddings.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,41 @@
11
import unittest
22
from unittest.mock import MagicMock, patch
3+
34
from raglight.embeddings.openai_embeddings import OpenAIEmbeddingsModel
45
from ..test_config import TestsConfig
56

67

78
class TestOpenAIEmbeddings(unittest.TestCase):
89

9-
@patch("raglight.embeddings.openai_embeddings.OpenAI")
10-
def test_model_load(self, mock_openai: MagicMock):
11-
"""Test OpenAI client initialization."""
12-
mock_openai.return_value = MagicMock()
10+
@patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings")
11+
def test_model_load(self, MockEmbeddings: MagicMock):
12+
MockEmbeddings.return_value = MagicMock()
1313
model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL)
14-
1514
self.assertIsNotNone(model.model)
16-
mock_openai.assert_called_once()
17-
18-
@patch("raglight.embeddings.openai_embeddings.OpenAI")
19-
def test_embed_documents(self, mock_openai: MagicMock):
20-
"""Test document embedding (batch)."""
21-
mock_client = mock_openai.return_value
22-
23-
mock_data_1 = MagicMock()
24-
mock_data_1.embedding = [0.1, 0.1]
25-
mock_data_2 = MagicMock()
26-
mock_data_2.embedding = [0.2, 0.2]
15+
MockEmbeddings.assert_called_once()
2716

28-
mock_response = MagicMock()
29-
mock_response.data = [mock_data_1, mock_data_2]
30-
31-
mock_client.embeddings.create.return_value = mock_response
17+
@patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings")
18+
def test_embed_documents(self, MockEmbeddings: MagicMock):
19+
mock_instance = MockEmbeddings.return_value
20+
mock_instance.embed_documents.return_value = [[0.1, 0.1], [0.2, 0.2]]
3221

3322
model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL)
3423
result = model.embed_documents(["text1", "text2"])
3524

3625
self.assertEqual(len(result), 2)
3726
self.assertEqual(result[0], [0.1, 0.1])
38-
mock_client.embeddings.create.assert_called_with(
39-
input=["text1", "text2"], model=TestsConfig.OPENAI_EMBEDDING_MODEL
40-
)
41-
42-
@patch("raglight.embeddings.openai_embeddings.OpenAI")
43-
def test_embed_query(self, mock_openai: MagicMock):
44-
"""Test single query embedding."""
45-
mock_client = mock_openai.return_value
46-
47-
mock_data = MagicMock()
48-
mock_data.embedding = [0.5, 0.5]
49-
50-
mock_response = MagicMock()
51-
mock_response.data = [mock_data]
27+
mock_instance.embed_documents.assert_called_with(["text1", "text2"])
5228

53-
mock_client.embeddings.create.return_value = mock_response
29+
@patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings")
30+
def test_embed_query(self, MockEmbeddings: MagicMock):
31+
mock_instance = MockEmbeddings.return_value
32+
mock_instance.embed_query.return_value = [0.5, 0.5]
5433

5534
model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL)
5635
result = model.embed_query("query")
5736

5837
self.assertEqual(result, [0.5, 0.5])
59-
mock_client.embeddings.create.assert_called_with(
60-
input=["query"], model=TestsConfig.OPENAI_EMBEDDING_MODEL
61-
)
38+
mock_instance.embed_query.assert_called_with("query")
6239

6340

6441
if __name__ == "__main__":

tests/tests_llm/test_gemini_model.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
import os
32
from unittest.mock import MagicMock, patch
43

54
from raglight.llm.gemini_model import GeminiModel
@@ -9,28 +8,24 @@
98
class TestGeminiModel(unittest.TestCase):
109
_MOCK_RESPONSE = "Hello! This is a test response."
1110

11+
@patch("raglight.llm.gemini_model.ChatGoogleGenerativeAI")
1212
@patch("raglight.Settings.GEMINI_API_KEY", "DUMMY_KEY")
13-
def setUp(self):
14-
model_name = TestsConfig.GEMINI_LLM_MODEL
15-
self.model = GeminiModel(
16-
model_name=model_name,
17-
)
18-
13+
def setUp(self, MockChatGemini):
14+
mock_lc_client = MagicMock()
1915
mock_response = MagicMock()
20-
mock_response.text = self._MOCK_RESPONSE
21-
mock_response.candidates = [MagicMock()] # Non-empty to pass the check
22-
mock_generate = MagicMock(return_value=mock_response)
16+
mock_response.content = self._MOCK_RESPONSE
17+
mock_lc_client.invoke.return_value = mock_response
18+
MockChatGemini.return_value = mock_lc_client
2319

24-
mock_client = MagicMock()
25-
mock_client.models.generate_content = mock_generate
26-
self.model.model = mock_client
20+
self.model = GeminiModel(model_name=TestsConfig.GEMINI_LLM_MODEL)
2721

2822
def test_generate_response(self):
2923
prompt = "Say hello."
3024
response = self.model.generate({"question": prompt})
3125
self.assertIsInstance(response, str, "Response should be a string.")
3226
self.assertGreater(len(response), 0, "Response should not be empty.")
3327
self.assertEqual(response, self._MOCK_RESPONSE)
28+
self.model.model.invoke.assert_called_once()
3429

3530

3631
if __name__ == "__main__":

0 commit comments

Comments
 (0)