Skip to content

Commit 16cf955

Browse files
authored
feat: adds multitenant tests via pytest (#1923)
<!-- .github/pull_request_template.md --> ## Description This PR changes the permission test in e2e tests to use pytest. Introduces: - fixtures for the environment setup - one eventloop for all pytest tests - mocking for acreate_structured_output answer generation (for search) - Asserts in permission test (before we use the example only) ## Acceptance Criteria <!-- * Key requirements to the new feature or modification; * Proof that the changes work and meet the requirements; * Include instructions on how to verify the changes. Describe how to test it locally; * Proof that it's sufficiently tested. --> ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Entity model now includes description and metadata fields for richer entity information and indexing. * **Tests** * Expanded and restructured permission tests covering multi-tenant and role-based access flows; improved test scaffolding and stability. * E2E test workflow now runs pytest with verbose output and INFO logs. * **Bug Fixes** * Access-tracking updates now commit transactions so access timestamps persist. * **Chores** * General formatting, cleanup, and refactoring across modules and maintenance scripts. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
2 parents eb444ca + 2c4f9b0 commit 16cf955

File tree

12 files changed

+659
-692
lines changed

12 files changed

+659
-692
lines changed

.github/workflows/e2e_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ jobs:
288288
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
289289
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
290290
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
291-
run: uv run python ./cognee/tests/test_permissions.py
291+
run: uv run pytest cognee/tests/test_permissions.py -v --log-level=INFO
292292

293293
test-multi-tenancy:
294294
name: Test multi tenancy with different situations in Cognee
Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,51 @@
1-
"""add_last_accessed_to_data

Revision ID: e1ec1dcb50b6
Revises: 211ab850ef3d
Create Date: 2025-11-04 21:45:52.642322

"""

import os
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e1ec1dcb50b6"
# NOTE(review): the incoming change pointed this at "a1b2c3d4e5f6", which looks
# like a placeholder and contradicts the "Revises: 211ab850ef3d" line in the
# docstring above. Restored to the documented parent so the revision history
# stays linear — confirm against the actual migration chain.
down_revision: Union[str, None] = "211ab850ef3d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _get_column(inspector, table, name, schema=None):
    """Return the reflected column dict for *name* on *table*, or None if absent."""
    for col in inspector.get_columns(table, schema=schema):
        if col["name"] == name:
            return col
    return None


def upgrade() -> None:
    """Add the nullable ``data.last_accessed`` timestamp column (idempotent)."""
    conn = op.get_bind()
    insp = sa.inspect(conn)

    last_accessed_column = _get_column(insp, "data", "last_accessed")
    if not last_accessed_column:
        # Always create the column for schema consistency.
        op.add_column("data", sa.Column("last_accessed", sa.DateTime(timezone=True), nullable=True))

    # Only initialize existing records if the feature is enabled.
    enable_last_accessed = os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true"
    if enable_last_accessed:
        op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP")


def downgrade() -> None:
    """Drop ``data.last_accessed`` if it exists (idempotent)."""
    conn = op.get_bind()
    insp = sa.inspect(conn)

    last_accessed_column = _get_column(insp, "data", "last_accessed")
    if last_accessed_column:
        op.drop_column("data", "last_accessed")

cognee/modules/engine/models/Entity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from cognee.modules.engine.models.EntityType import EntityType
33
from typing import Optional
44

5+
56
class Entity(DataPoint):
67
name: str
78
is_a: Optional[EntityType] = None
Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
21
def get_entity_nodes_from_triplets(triplets):
    """Collect the unique endpoint nodes of a list of triplets.

    Each triplet may expose ``node1`` and/or ``node2`` attributes; every
    distinct node id is emitted once, in first-seen order, as a dict of the
    form ``{"id": "<node-id-as-string>"}``.
    """
    unique_nodes = {}
    for triplet in triplets:
        for endpoint_attr in ("node1", "node2"):
            endpoint = getattr(triplet, endpoint_attr, None)
            # Skip missing/falsy endpoints and ids we have already recorded.
            if endpoint and endpoint.id not in unique_nodes:
                unique_nodes[endpoint.id] = {"id": str(endpoint.id)}

    # dict preserves insertion order, so this matches first-seen ordering.
    return list(unique_nodes.values())

cognee/modules/retrieval/chunks_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cognee.modules.retrieval.base_retriever import BaseRetriever
66
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
77
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
8-
from datetime import datetime, timezone
8+
from datetime import datetime, timezone
99

1010
logger = get_logger("ChunksRetriever")
1111

@@ -28,7 +28,7 @@ def __init__(
2828
):
2929
self.top_k = top_k
3030

31-
async def get_context(self, query: str) -> Any:
31+
async def get_context(self, query: str) -> Any:
3232
"""
3333
Retrieves document chunks context based on the query.
3434
Searches for document chunks relevant to the specified query using a vector engine.

cognee/modules/retrieval/graph_completion_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ async def get_context(self, query: str) -> List[Edge]:
148148
# context = await self.resolve_edges_to_text(triplets)
149149

150150
entity_nodes = get_entity_nodes_from_triplets(triplets)
151-
152-
await update_node_access_timestamps(entity_nodes)
151+
152+
await update_node_access_timestamps(entity_nodes)
153153
return triplets
154154

155155
async def convert_retrieved_objects_to_context(self, triplets: List[Edge]):

cognee/modules/retrieval/summaries_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ async def get_context(self, query: str) -> Any:
5555
"TextSummary_text", query, limit=self.top_k
5656
)
5757
logger.info(f"Found {len(summaries_results)} summaries from vector search")
58-
58+
5959
await update_node_access_timestamps(summaries_results)
60-
60+
6161
except CollectionNotFoundError as error:
6262
logger.error("TextSummary_text collection not found in vector database")
6363
raise NoDataError("No data found in the system, please add data first.") from error
Lines changed: 87 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,88 @@
1-
"""Utilities for tracking data access in retrievers."""

import os
from datetime import datetime, timezone
from typing import Any, List
from uuid import UUID

from sqlalchemy import update

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.shared.logging_utils import get_logger

logger = get_logger(__name__)


async def update_node_access_timestamps(items: List[Any]):
    """Persist a "last accessed" timestamp for the documents behind *items*.

    Feature-gated by the ENABLE_LAST_ACCESSED env var: does nothing unless it
    is set to "true". *items* may be vector-search results (exposing a
    ``payload`` mapping) or plain dicts; entries without an ``id`` are skipped.
    Raises whatever the underlying graph/SQL layer raises, after logging.
    """
    if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true":
        return

    if not items:
        return

    graph_engine = await get_graph_engine()
    timestamp_dt = datetime.now(timezone.utc)

    # Extract node IDs from either search-result payloads or raw dicts.
    node_ids = []
    for item in items:
        item_id = item.payload.get("id") if hasattr(item, "payload") else item.get("id")
        if item_id:
            node_ids.append(str(item_id))

    if not node_ids:
        return

    # Focus on document-level tracking via graph projection.
    try:
        doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids)
        if doc_ids:
            await _update_sql_records(doc_ids, timestamp_dt)
    except Exception as e:
        logger.error(f"Failed to update SQL timestamps: {e}")
        raise


async def _find_origin_documents_via_projection(graph_engine, node_ids):
    """Find origin documents using graph projection instead of DB queries."""
    # Project the graph with only the properties needed for traversal.
    memory_fragment = CogneeGraph()
    await memory_fragment.project_graph_from_db(
        graph_engine,
        node_properties_to_project=["id", "type"],
        edge_properties_to_project=["relationship_name"],
    )

    # Walk each DocumentChunk node's edges to find the documents it belongs to.
    doc_ids = set()
    for node_id in node_ids:
        node = memory_fragment.get_node(node_id)
        if node and node.get_attribute("type") == "DocumentChunk":
            for edge in node.get_skeleton_edges():
                # The neighbor is whichever edge endpoint is not the chunk itself.
                neighbor = (
                    edge.get_destination_node()
                    if edge.get_source_node().id == node_id
                    else edge.get_source_node()
                )
                if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]:
                    doc_ids.add(neighbor.id)

    return list(doc_ids)


async def _update_sql_records(doc_ids, timestamp_dt):
    """Update the relational ``Data`` rows' last_accessed (same for all providers)."""
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        stmt = (
            update(Data)
            .where(Data.id.in_([UUID(doc_id) for doc_id in doc_ids]))
            .values(last_accessed=timestamp_dt)
        )

        await session.execute(stmt)
        # Commit explicitly so the timestamp change persists beyond the session.
        await session.commit()

0 commit comments

Comments
 (0)