Skip to content

Commit 053ce71

Browse files
authored
feat(models): add more tests on service (#2977)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
1 parent b3ea3c2 commit 053ce71

File tree

5 files changed

+143
-12
lines changed

5 files changed

+143
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
9898

9999
# Tox
100100
.tox
101+
Pipfile

Makefile

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
1+
.DEFAULT_TARGET=help
12

3+
## help: Display list of commands
4+
.PHONY: help
5+
help:
6+
@echo "Available commands:"
7+
@sed -n 's|^##||p' $(MAKEFILE_LIST) | column -t -s ':' | sed -e 's|^| |'
8+
9+
10+
## dev: Start development environment
11+
.PHONY: dev
212
dev:
313
DOCKER_BUILDKIT=1 docker compose -f docker-compose.dev.yml up --build
414

15+
## prod: Build and start production environment
16+
.PHONY: prod
517
prod:
618
docker compose build backend-core
719
docker compose -f docker-compose.yml up --build
820

9-
21+
## front: Build and start frontend
22+
.PHONY: front
1023
front:
1124
cd frontend && yarn build && yarn start
1225

26+
## test: Run tests
27+
.PHONY: test
1328
test:
1429
# Ensure dependencies are installed with dev and test extras
1530
# poetry install --with dev,test && brew install tesseract pandoc libmagic
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import asyncio
2+
import os
3+
from typing import Tuple
4+
5+
import pytest
6+
import pytest_asyncio
7+
import sqlalchemy
8+
from sqlalchemy.ext.asyncio import create_async_engine
9+
from sqlmodel import select
10+
from sqlmodel.ext.asyncio.session import AsyncSession
11+
12+
from quivr_api.modules.models.entity.model import Model
13+
from quivr_api.modules.user.entity.user_identity import User
14+
15+
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
16+
17+
TestData = Tuple[Model, Model, User]
18+
19+
20+
@pytest.fixture(scope="session")
21+
def event_loop(request: pytest.FixtureRequest):
22+
loop = asyncio.get_event_loop_policy().new_event_loop()
23+
yield loop
24+
loop.close()
25+
26+
27+
@pytest_asyncio.fixture(scope="session")
28+
async def async_engine():
29+
engine = create_async_engine(
30+
"postgresql+asyncpg://" + pg_database_base_url,
31+
echo=True if os.getenv("ORM_DEBUG") else False,
32+
future=True,
33+
pool_pre_ping=True,
34+
pool_size=10,
35+
pool_recycle=0.1,
36+
)
37+
yield engine
38+
39+
40+
@pytest_asyncio.fixture()
41+
async def session(async_engine):
42+
async with async_engine.connect() as conn:
43+
await conn.begin()
44+
await conn.begin_nested()
45+
async_session = AsyncSession(conn, expire_on_commit=False)
46+
47+
@sqlalchemy.event.listens_for(
48+
async_session.sync_session, "after_transaction_end"
49+
)
50+
def end_savepoint(session, transaction):
51+
if conn.closed:
52+
return
53+
if not conn.in_nested_transaction():
54+
conn.sync_connection.begin_nested()
55+
56+
yield async_session
57+
58+
59+
@pytest_asyncio.fixture()
60+
async def test_data(
61+
session: AsyncSession,
62+
) -> TestData:
63+
# User data
64+
user_1 = (
65+
await session.exec(select(User).where(User.email == "[email protected]"))
66+
).one()
67+
68+
model_1 = Model(
69+
name="this-is-a-fake-model", price=1, max_input=4000, max_output=2000
70+
)
71+
model_2 = Model(
72+
name="this-is-another-fake-model", price=5, max_input=8000, max_output=4000
73+
)
74+
75+
session.add(model_1)
76+
session.add(model_2)
77+
78+
await session.refresh(user_1)
79+
await session.commit()
80+
return model_1, model_2, user_1
81+
82+
83+
@pytest_asyncio.fixture()
84+
async def sample_models():
85+
return [
86+
Model(name="gpt-3.5-turbo", price=1, max_input=4000, max_output=2000),
87+
Model(name="gpt-4", price=5, max_input=8000, max_output=4000),
88+
]

backend/api/quivr_api/modules/models/tests/test_models.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
import pytest
2-
import pytest_asyncio
3-
from quivr_api.modules.models.entity.model import Model
4-
52

6-
@pytest_asyncio.fixture()
7-
async def sample_models():
8-
return [
9-
Model(name="gpt-3.5-turbo", price=1, max_input=4000, max_output=2000),
10-
Model(name="gpt-4", price=5, max_input=8000, max_output=4000),
11-
]
3+
from quivr_api.modules.models.entity.model import Model
124

135

146
@pytest.mark.asyncio
@@ -21,8 +13,8 @@ async def test_model_creation():
2113

2214

2315
@pytest.mark.asyncio
24-
async def test_model_attributes(sample_models):
25-
model = sample_models[0]
16+
async def test_model_attributes(test_data):
17+
model = test_data[0]
2618
assert hasattr(model, "name")
2719
assert hasattr(model, "price")
2820
assert hasattr(model, "max_input")
@@ -66,5 +58,8 @@ async def test_model_dict_representation():
6658
"price": 2,
6759
"max_input": 3000,
6860
"max_output": 1500,
61+
"description": "",
62+
"image_url": "",
63+
"display_name": "",
6964
}
7065
assert model.dict() == expected_dict
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from quivr_api.modules.models.repository.model import ModelRepository
4+
from quivr_api.modules.models.service.model_service import ModelService
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_service_get_chat_models(session):
9+
repo = ModelRepository(session)
10+
service = ModelService(repo)
11+
models = await service.get_models()
12+
assert len(models) >= 1
13+
14+
15+
@pytest.mark.asyncio
16+
async def test_service_get_non_existing_chat_model(session):
17+
repo = ModelRepository(session)
18+
service = ModelService(repo)
19+
model = await service.get_model("gpt-3.5-turbo")
20+
assert model is None
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_service_get_existing_chat_model(session):
25+
repo = ModelRepository(session)
26+
service = ModelService(repo)
27+
models = await service.get_models()
28+
assert len(models) >= 1
29+
model = models[0]
30+
model_get = await service.get_model(model.name)
31+
assert model_get is not None
32+
assert model_get == model

0 commit comments

Comments
 (0)