|
1 | 1 | import unittest |
2 | 2 | from unittest.mock import MagicMock, patch |
| 3 | + |
3 | 4 | from raglight.embeddings.openai_embeddings import OpenAIEmbeddingsModel |
4 | 5 | from ..test_config import TestsConfig |
5 | 6 |
|
6 | 7 |
|
7 | 8 | class TestOpenAIEmbeddings(unittest.TestCase): |
8 | 9 |
|
9 | | - @patch("raglight.embeddings.openai_embeddings.OpenAI") |
10 | | - def test_model_load(self, mock_openai: MagicMock): |
11 | | - """Test OpenAI client initialization.""" |
12 | | - mock_openai.return_value = MagicMock() |
| 10 | + @patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings") |
| 11 | + def test_model_load(self, MockEmbeddings: MagicMock): |
| 12 | + MockEmbeddings.return_value = MagicMock() |
13 | 13 | model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL) |
14 | | - |
15 | 14 | self.assertIsNotNone(model.model) |
16 | | - mock_openai.assert_called_once() |
17 | | - |
18 | | - @patch("raglight.embeddings.openai_embeddings.OpenAI") |
19 | | - def test_embed_documents(self, mock_openai: MagicMock): |
20 | | - """Test document embedding (batch).""" |
21 | | - mock_client = mock_openai.return_value |
22 | | - |
23 | | - mock_data_1 = MagicMock() |
24 | | - mock_data_1.embedding = [0.1, 0.1] |
25 | | - mock_data_2 = MagicMock() |
26 | | - mock_data_2.embedding = [0.2, 0.2] |
| 15 | + MockEmbeddings.assert_called_once() |
27 | 16 |
|
28 | | - mock_response = MagicMock() |
29 | | - mock_response.data = [mock_data_1, mock_data_2] |
30 | | - |
31 | | - mock_client.embeddings.create.return_value = mock_response |
| 17 | + @patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings") |
| 18 | + def test_embed_documents(self, MockEmbeddings: MagicMock): |
| 19 | + mock_instance = MockEmbeddings.return_value |
| 20 | + mock_instance.embed_documents.return_value = [[0.1, 0.1], [0.2, 0.2]] |
32 | 21 |
|
33 | 22 | model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL) |
34 | 23 | result = model.embed_documents(["text1", "text2"]) |
35 | 24 |
|
36 | 25 | self.assertEqual(len(result), 2) |
37 | 26 | self.assertEqual(result[0], [0.1, 0.1]) |
38 | | - mock_client.embeddings.create.assert_called_with( |
39 | | - input=["text1", "text2"], model=TestsConfig.OPENAI_EMBEDDING_MODEL |
40 | | - ) |
41 | | - |
42 | | - @patch("raglight.embeddings.openai_embeddings.OpenAI") |
43 | | - def test_embed_query(self, mock_openai: MagicMock): |
44 | | - """Test single query embedding.""" |
45 | | - mock_client = mock_openai.return_value |
46 | | - |
47 | | - mock_data = MagicMock() |
48 | | - mock_data.embedding = [0.5, 0.5] |
49 | | - |
50 | | - mock_response = MagicMock() |
51 | | - mock_response.data = [mock_data] |
| 27 | + mock_instance.embed_documents.assert_called_with(["text1", "text2"]) |
52 | 28 |
|
53 | | - mock_client.embeddings.create.return_value = mock_response |
| 29 | + @patch("raglight.embeddings.openai_embeddings.OpenAIEmbeddings") |
| 30 | + def test_embed_query(self, MockEmbeddings: MagicMock): |
| 31 | + mock_instance = MockEmbeddings.return_value |
| 32 | + mock_instance.embed_query.return_value = [0.5, 0.5] |
54 | 33 |
|
55 | 34 | model = OpenAIEmbeddingsModel(TestsConfig.OPENAI_EMBEDDING_MODEL) |
56 | 35 | result = model.embed_query("query") |
57 | 36 |
|
58 | 37 | self.assertEqual(result, [0.5, 0.5]) |
59 | | - mock_client.embeddings.create.assert_called_with( |
60 | | - input=["query"], model=TestsConfig.OPENAI_EMBEDDING_MODEL |
61 | | - ) |
| 38 | + mock_instance.embed_query.assert_called_with("query") |
62 | 39 |
|
63 | 40 |
|
64 | 41 | if __name__ == "__main__": |
|
0 commit comments