Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 49e3f5f

Browse files
committed
Fix unit tests
1 parent e0da3c4 commit 49e3f5f

File tree

2 files changed

+64
-44
lines changed

2 files changed

+64
-44
lines changed

tests/test_cli.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from httpx import AsyncClient
99

1010
from codegate import __version__
11+
from codegate.pipeline.factory import PipelineFactory
1112
from codegate.pipeline.secrets.manager import SecretsManager
1213
from codegate.providers.registry import ProviderRegistry
1314
from codegate.server import init_app
@@ -26,23 +27,35 @@ def mock_provider_registry():
2627

2728

2829
@pytest.fixture
29-
def test_client() -> TestClient:
30+
def mock_pipeline_factory():
31+
"""Create a mock pipeline factory."""
32+
mock_factory = MagicMock(spec=PipelineFactory)
33+
# Mock the methods that are called on the pipeline factory
34+
mock_factory.create_input_pipeline.return_value = MagicMock()
35+
mock_factory.create_fim_pipeline.return_value = MagicMock()
36+
mock_factory.create_output_pipeline.return_value = MagicMock()
37+
mock_factory.create_fim_output_pipeline.return_value = MagicMock()
38+
return mock_factory
39+
40+
41+
@pytest.fixture
42+
def test_client(mock_pipeline_factory) -> TestClient:
3043
"""Create a test client for the FastAPI application."""
31-
app = init_app()
44+
app = init_app(mock_pipeline_factory)
3245
return TestClient(app)
3346

3447

35-
def test_app_initialization() -> None:
48+
def test_app_initialization(mock_pipeline_factory) -> None:
3649
"""Test that the FastAPI application initializes correctly."""
37-
app = init_app()
50+
app = init_app(mock_pipeline_factory)
3851
assert app is not None
3952
assert app.title == "CodeGate"
4053
assert app.version == __version__
4154

4255

43-
def test_cors_middleware() -> None:
56+
def test_cors_middleware(mock_pipeline_factory) -> None:
4457
"""Test that CORS middleware is properly configured."""
45-
app = init_app()
58+
app = init_app(mock_pipeline_factory)
4659
cors_middleware = None
4760
for middleware in app.user_middleware:
4861
if isinstance(middleware.cls, type) and issubclass(middleware.cls, CORSMiddleware):
@@ -62,14 +75,11 @@ def test_health_check(test_client: TestClient) -> None:
6275
assert response.json() == {"status": "healthy"}
6376

6477

78+
@patch("codegate.pipeline.secrets.manager.SecretsManager")
6579
@patch("codegate.server.ProviderRegistry")
66-
@patch("codegate.server.SecretsManager")
67-
def test_provider_registration(mock_secrets_mgr, mock_registry) -> None:
80+
def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None:
6881
"""Test that all providers are registered correctly."""
69-
init_app()
70-
71-
# Verify SecretsManager was initialized
72-
mock_secrets_mgr.assert_called_once()
82+
init_app(mock_pipeline_factory)
7383

7484
# Verify ProviderRegistry was initialized with the app
7585
mock_registry.assert_called_once()
@@ -90,15 +100,15 @@ def test_provider_registration(mock_secrets_mgr, mock_registry) -> None:
90100

91101

92102
@patch("codegate.server.CodegateSignatures")
93-
def test_signatures_initialization(mock_signatures) -> None:
103+
def test_signatures_initialization(mock_signatures, mock_pipeline_factory) -> None:
94104
"""Test that signatures are initialized correctly."""
95-
init_app()
105+
init_app(mock_pipeline_factory)
96106
mock_signatures.initialize.assert_called_once_with("signatures.yaml")
97107

98108

99-
def test_pipeline_initialization() -> None:
109+
def test_pipeline_initialization(mock_pipeline_factory) -> None:
100110
"""Test that pipelines are initialized correctly."""
101-
app = init_app()
111+
app = init_app(mock_pipeline_factory)
102112

103113
# Access the provider registry to check pipeline configuration
104114
registry = next((route for route in app.routes if hasattr(route, "registry")), None)
@@ -111,29 +121,29 @@ def test_pipeline_initialization() -> None:
111121
assert hasattr(provider, "output_pipeline_processor")
112122

113123

114-
def test_dashboard_routes() -> None:
124+
def test_dashboard_routes(mock_pipeline_factory) -> None:
115125
"""Test that dashboard routes are included."""
116-
app = init_app()
126+
app = init_app(mock_pipeline_factory)
117127
routes = [route.path for route in app.routes]
118128

119129
# Verify dashboard endpoints are included
120130
dashboard_routes = [route for route in routes if route.startswith("/dashboard")]
121131
assert len(dashboard_routes) > 0
122132

123133

124-
def test_system_routes() -> None:
134+
def test_system_routes(mock_pipeline_factory) -> None:
125135
"""Test that system routes are included."""
126-
app = init_app()
136+
app = init_app(mock_pipeline_factory)
127137
routes = [route.path for route in app.routes]
128138

129139
# Verify system endpoints are included
130140
assert "/health" in routes
131141

132142

133143
@pytest.mark.asyncio
134-
async def test_async_health_check() -> None:
144+
async def test_async_health_check(mock_pipeline_factory) -> None:
135145
"""Test the health check endpoint with async client."""
136-
app = init_app()
146+
app = init_app(mock_pipeline_factory)
137147
async with AsyncClient(app=app, base_url="http://test") as ac:
138148
response = await ac.get("/health")
139149
assert response.status_code == 200

tests/test_server.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from httpx import AsyncClient
99

1010
from codegate import __version__
11+
from codegate.pipeline.factory import PipelineFactory
1112
from codegate.pipeline.secrets.manager import SecretsManager
1213
from codegate.providers.registry import ProviderRegistry
1314
from codegate.server import init_app
@@ -26,23 +27,35 @@ def mock_provider_registry():
2627

2728

2829
@pytest.fixture
29-
def test_client() -> TestClient:
30+
def mock_pipeline_factory():
31+
"""Create a mock pipeline factory."""
32+
mock_factory = MagicMock(spec=PipelineFactory)
33+
# Mock the methods that are called on the pipeline factory
34+
mock_factory.create_input_pipeline.return_value = MagicMock()
35+
mock_factory.create_fim_pipeline.return_value = MagicMock()
36+
mock_factory.create_output_pipeline.return_value = MagicMock()
37+
mock_factory.create_fim_output_pipeline.return_value = MagicMock()
38+
return mock_factory
39+
40+
41+
@pytest.fixture
42+
def test_client(mock_pipeline_factory) -> TestClient:
3043
"""Create a test client for the FastAPI application."""
31-
app = init_app()
44+
app = init_app(mock_pipeline_factory)
3245
return TestClient(app)
3346

3447

35-
def test_app_initialization() -> None:
48+
def test_app_initialization(mock_pipeline_factory) -> None:
3649
"""Test that the FastAPI application initializes correctly."""
37-
app = init_app()
50+
app = init_app(mock_pipeline_factory)
3851
assert app is not None
3952
assert app.title == "CodeGate"
4053
assert app.version == __version__
4154

4255

43-
def test_cors_middleware() -> None:
56+
def test_cors_middleware(mock_pipeline_factory) -> None:
4457
"""Test that CORS middleware is properly configured."""
45-
app = init_app()
58+
app = init_app(mock_pipeline_factory)
4659
cors_middleware = None
4760
for middleware in app.user_middleware:
4861
if isinstance(middleware.cls, type) and issubclass(middleware.cls, CORSMiddleware):
@@ -62,14 +75,11 @@ def test_health_check(test_client: TestClient) -> None:
6275
assert response.json() == {"status": "healthy"}
6376

6477

78+
@patch("codegate.pipeline.secrets.manager.SecretsManager")
6579
@patch("codegate.server.ProviderRegistry")
66-
@patch("codegate.server.SecretsManager")
67-
def test_provider_registration(mock_secrets_mgr, mock_registry) -> None:
80+
def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_factory) -> None:
6881
"""Test that all providers are registered correctly."""
69-
init_app()
70-
71-
# Verify SecretsManager was initialized
72-
mock_secrets_mgr.assert_called_once()
82+
init_app(mock_pipeline_factory)
7383

7484
# Verify ProviderRegistry was initialized with the app
7585
mock_registry.assert_called_once()
@@ -90,15 +100,15 @@ def test_provider_registration(mock_secrets_mgr, mock_registry) -> None:
90100

91101

92102
@patch("codegate.server.CodegateSignatures")
93-
def test_signatures_initialization(mock_signatures) -> None:
103+
def test_signatures_initialization(mock_signatures, mock_pipeline_factory) -> None:
94104
"""Test that signatures are initialized correctly."""
95-
init_app()
105+
init_app(mock_pipeline_factory)
96106
mock_signatures.initialize.assert_called_once_with("signatures.yaml")
97107

98108

99-
def test_pipeline_initialization() -> None:
109+
def test_pipeline_initialization(mock_pipeline_factory) -> None:
100110
"""Test that pipelines are initialized correctly."""
101-
app = init_app()
111+
app = init_app(mock_pipeline_factory)
102112

103113
# Access the provider registry to check pipeline configuration
104114
registry = next((route for route in app.routes if hasattr(route, "registry")), None)
@@ -111,29 +121,29 @@ def test_pipeline_initialization() -> None:
111121
assert hasattr(provider, "output_pipeline_processor")
112122

113123

114-
def test_dashboard_routes() -> None:
124+
def test_dashboard_routes(mock_pipeline_factory) -> None:
115125
"""Test that dashboard routes are included."""
116-
app = init_app()
126+
app = init_app(mock_pipeline_factory)
117127
routes = [route.path for route in app.routes]
118128

119129
# Verify dashboard endpoints are included
120130
dashboard_routes = [route for route in routes if route.startswith("/dashboard")]
121131
assert len(dashboard_routes) > 0
122132

123133

124-
def test_system_routes() -> None:
134+
def test_system_routes(mock_pipeline_factory) -> None:
125135
"""Test that system routes are included."""
126-
app = init_app()
136+
app = init_app(mock_pipeline_factory)
127137
routes = [route.path for route in app.routes]
128138

129139
# Verify system endpoints are included
130140
assert "/health" in routes
131141

132142

133143
@pytest.mark.asyncio
134-
async def test_async_health_check() -> None:
144+
async def test_async_health_check(mock_pipeline_factory) -> None:
135145
"""Test the health check endpoint with async client."""
136-
app = init_app()
146+
app = init_app(mock_pipeline_factory)
137147
async with AsyncClient(app=app, base_url="http://test") as ac:
138148
response = await ac.get("/health")
139149
assert response.status_code == 200

0 commit comments

Comments
 (0)