Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cognee/api/v1/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def search(
triplet_distance_penalty: Optional[float] = 3.5,
verbose: bool = False,
retriever_specific_config: Optional[dict] = None,
neighborhood_depth: Optional[int] = None,
) -> List[SearchResult]:
Comment on lines +44 to 45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate neighborhood_depth before forwarding it.

Line 233 passes the new public parameter through unchanged. As a result, values such as 0, negative numbers, or non-integers currently reach the adapters' get_neighborhood() queries and build invalid [*1..N] path patterns instead of producing a clear API error.

🛡️ Suggested guard
 async def search(
     query_text: str,
@@
     retriever_specific_config: Optional[dict] = None,
     neighborhood_depth: Optional[int] = None,
 ) -> List[SearchResult]:
+    if neighborhood_depth is not None and (
+        not isinstance(neighborhood_depth, int) or neighborhood_depth < 1
+    ):
+        raise CogneeValidationError(
+            message="neighborhood_depth must be a positive integer.",
+            name="InvalidNeighborhoodDepth",
+        )
+
     """

Also applies to: 217-233

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee/api/v1/search/search.py` around lines 44 - 45, The new public
parameter neighborhood_depth is forwarded unchanged to the adapter
get_neighborhood(), allowing 0, negative or non-int values to create invalid
path patterns; validate neighborhood_depth early in the containing function (the
public search handler that returns List[SearchResult]) by checking it is an
integer > 0 (and within any configured max if applicable), and if not
raise/return a clear API error (e.g., BadRequest/ValueError) before calling
get_neighborhood(); update callers that pass neighborhood_depth through (the
code around where neighborhood_depth is forwarded) to rely on this validated
value.

"""
Search and query the knowledge graph for insights, information, and connections.
Expand Down Expand Up @@ -229,6 +230,7 @@ async def search(
triplet_distance_penalty=triplet_distance_penalty,
verbose=verbose,
retriever_specific_config=retriever_specific_config,
neighborhood_depth=neighborhood_depth,
)

n = len(filtered_search_results) if filtered_search_results else 0
Expand Down
24 changes: 24 additions & 0 deletions cognee/infrastructure/databases/graph/graph_db_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,30 @@ async def get_connections(
"""
raise NotImplementedError

@abstractmethod
async def get_neighborhood(
    self,
    node_ids: List[str],
    depth: int = 1,
    edge_types: Optional[List[str]] = None,
) -> Tuple[List[Node], List[EdgeData]]:
    """
    Get the k-hop neighborhood subgraph around a set of seed nodes.

    Returns all nodes and edges within `depth` hops of any seed node,
    in the same format as get_graph_data().
    Optional edge_type filtering to constrain traversal paths.

    This is an abstract method: the base implementation only raises
    NotImplementedError, and concrete adapters are expected to override it.

    Parameters:
    -----------

        - node_ids (List[str]): Seed node identifiers to start traversal from.
        - depth (int): Number of hops to traverse from each seed node. (default 1)
          NOTE(review): presumably this must be a positive integer; the interface
          itself does not enforce it, so confirm against each adapter.
        - edge_types (Optional[List[str]]): If provided, only traverse edges of these
          relationship types. (default None)

    Returns:
    --------

        - Tuple[List[Node], List[EdgeData]]: A pair of (nodes, edges) describing the
          neighborhood subgraph.
    """
    raise NotImplementedError

@abstractmethod
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
Expand Down
100 changes: 100 additions & 0 deletions cognee/infrastructure/databases/graph/kuzu/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,106 @@ async def get_graph_data(
logger.error(f"Failed to get graph data: {e}")
raise

async def get_neighborhood(
    self,
    node_ids: List[str],
    depth: int = 1,
    edge_types: Optional[List[str]] = None,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
    """
    Get the k-hop neighborhood subgraph around a set of seed nodes.

    Collects the IDs of every node reachable from a seed within `depth` hops
    (relationships are followed in either direction), then returns those nodes plus
    the seeds, together with every directed edge whose two endpoints both belong to
    that set. The result has the same format as get_graph_data().

    Parameters:
    -----------

        - node_ids (List[str]): Seed node identifiers to start traversal from. An
          empty list short-circuits to an empty result.
        - depth (int): Maximum number of hops to traverse from each seed node. Must
          be a positive integer. (default 1)
        - edge_types (Optional[List[str]]): If provided, only paths made up entirely
          of relationships whose `relationship_name` is listed are followed while
          discovering neighbors. The final edge fetch is not filtered by type.
          (default None)

    Returns:
    --------

        - Tuple of (nodes, edges): nodes are `(node_id, properties)` tuples and
          edges are `(source_id, target_id, relationship_name, properties)` tuples.

    Raises:
    -------

        - ValueError: If `depth` is not a positive integer.
    """
    import time

    if not node_ids:
        logger.warning("No node IDs provided for neighborhood retrieval.")
        return [], []

    # `depth` is spliced into the query text below, so anything other than a positive
    # integer would yield a malformed (or unintended) path pattern. bool is a subclass
    # of int and must be rejected explicitly.
    if isinstance(depth, bool) or not isinstance(depth, int) or depth < 1:
        raise ValueError(f"depth must be a positive integer, got {depth!r}")

    # Monotonic clock: only the elapsed interval is reported.
    start_time = time.perf_counter()

    try:
        # Use variable-length path to find all nodes within depth hops
        edge_filter = (
            " AND ALL(rel IN r WHERE rel.relationship_name IN $edge_types)"
            if edge_types
            else ""
        )
        path_query = f"""
        MATCH (seed:Node)-[r*1..{depth}]-(neighbor:Node)
        WHERE seed.id IN $node_ids{edge_filter}
        RETURN DISTINCT neighbor.id
        """
        params = {"node_ids": node_ids}
        if edge_types:
            params["edge_types"] = edge_types

        neighbor_rows = await self.query(path_query, params)
        neighbor_ids = [row[0] for row in neighbor_rows if row[0]]

        # Combine seed nodes and neighbor nodes
        all_ids = list(set(node_ids) | set(neighbor_ids))

        # Fetch all nodes
        nodes_query = """
        MATCH (n:Node)
        WHERE n.id IN $ids
        RETURN n.id, {
            name: n.name,
            type: n.type,
            properties: n.properties
        }
        """
        node_rows = await self.query(nodes_query, {"ids": all_ids})
        formatted_nodes = []
        for n in node_rows:
            if not n[0]:
                continue
            node_id = str(n[0])
            props = n[1]
            if props.get("properties"):
                # The `properties` column holds a JSON document; flatten it into the
                # node's property dict and drop the raw column once parsed.
                try:
                    additional_props = json.loads(props["properties"])
                    props.update(additional_props)
                    del props["properties"]
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse properties JSON for node {node_id}")
            formatted_nodes.append((node_id, props))

        if not formatted_nodes:
            logger.warning("No nodes found in neighborhood.")
            return [], []

        # Fetch all edges between the collected nodes
        edges_query = """
        MATCH (n:Node)-[r]->(m:Node)
        WHERE n.id IN $ids AND m.id IN $ids
        RETURN n.id, m.id, r.relationship_name, r.properties
        """
        edge_rows = await self.query(edges_query, {"ids": all_ids})
        formatted_edges = []
        for e in edge_rows:
            if not e or len(e) < 3:
                continue
            source_id = str(e[0])
            target_id = str(e[1])
            rel_type = str(e[2])
            props = {}
            if len(e) > 3 and e[3]:
                try:
                    props = json.loads(e[3])
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        f"Failed to parse edge properties for {source_id}->{target_id}"
                    )
            formatted_edges.append((source_id, target_id, rel_type, props))

        retrieval_time = time.perf_counter() - start_time
        logger.info(
            f"Neighborhood retrieval ({depth}-hop): {len(formatted_nodes)} nodes and "
            f"{len(formatted_edges)} edges in {retrieval_time:.2f}s"
        )
        return formatted_nodes, formatted_edges

    except Exception as e:
        logger.error(f"Failed to get neighborhood: {e}")
        raise

async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
Expand Down
85 changes: 85 additions & 0 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,91 @@ async def get_graph_data(self):
logger.error(f"Error during graph data retrieval: {str(e)}")
raise

async def get_neighborhood(
    self,
    node_ids: List[str],
    depth: int = 1,
    edge_types: Optional[List[str]] = None,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, str, str, Dict[str, Any]]]]:
    """
    Get the k-hop neighborhood subgraph around a set of seed nodes.

    Collects the IDs of every node reachable from a seed within `depth` hops
    (relationships are followed in either direction), then returns those nodes plus
    the seeds, together with every directed edge whose two endpoints both belong to
    that set. The result has the same format as get_graph_data().

    Parameters:
    -----------

        - node_ids (List[str]): Seed node identifiers to start traversal from. An
          empty list short-circuits to an empty result.
        - depth (int): Maximum number of hops to traverse from each seed node. Must
          be a positive integer. (default 1)
        - edge_types (Optional[List[str]]): If provided, only paths made up entirely
          of relationships of these types are followed while discovering neighbors.
          The final edge fetch is not filtered by type. (default None)

    Returns:
    --------

        - Tuple of (nodes, edges): nodes are `(id, properties)` tuples and edges are
          `(source_node_id, target_node_id, type, properties)` tuples, where the
          endpoint IDs are read from the relationship's own properties.

    Raises:
    -------

        - ValueError: If `depth` is not a positive integer.
    """
    import time

    if not node_ids:
        logger.warning("No node IDs provided for neighborhood retrieval.")
        return [], []

    # `depth` is spliced into the query text below, so anything other than a positive
    # integer would yield a malformed (or unintended) path pattern. bool is a subclass
    # of int and must be rejected explicitly.
    if isinstance(depth, bool) or not isinstance(depth, int) or depth < 1:
        raise ValueError(f"depth must be a positive integer, got {depth!r}")

    # Monotonic clock: only the elapsed interval is reported.
    start_time = time.perf_counter()

    try:
        # Step 1: Collect all node IDs within depth hops
        params = {"node_ids": node_ids}
        if edge_types:
            params["edge_types"] = edge_types
            path_query = f"""
            MATCH path = (seed)-[*1..{depth}]-(neighbor)
            WHERE seed.id IN $node_ids
              AND ALL(r IN relationships(path) WHERE TYPE(r) IN $edge_types)
            RETURN DISTINCT neighbor.id AS nid
            """
        else:
            path_query = f"""
            MATCH (seed)-[*1..{depth}]-(neighbor)
            WHERE seed.id IN $node_ids
            RETURN DISTINCT neighbor.id AS nid
            """

        result = await self.query(path_query, params)
        neighbor_ids = [record["nid"] for record in result if record.get("nid")]

        all_ids = list(set(node_ids) | set(neighbor_ids))

        # Step 2: Fetch all nodes
        nodes_query = """
        MATCH (n)
        WHERE n.id IN $ids
        RETURN n.id AS id, properties(n) AS properties
        """
        nodes_result = await self.query(nodes_query, {"ids": all_ids})
        nodes = [
            (record["properties"]["id"], record["properties"]) for record in nodes_result
        ]

        # Step 3: Fetch all edges between collected nodes
        edges_query = """
        MATCH (n)-[r]->(m)
        WHERE n.id IN $ids AND m.id IN $ids
        RETURN properties(r) AS properties, TYPE(r) AS type
        """
        edges_result = await self.query(edges_query, {"ids": all_ids})
        edges = [
            (
                record["properties"]["source_node_id"],
                record["properties"]["target_node_id"],
                record["type"],
                record["properties"],
            )
            for record in edges_result
        ]

        retrieval_time = time.perf_counter() - start_time
        logger.info(
            f"Neighborhood retrieval ({depth}-hop): {len(nodes)} nodes and "
            f"{len(edges)} edges in {retrieval_time:.2f}s"
        )
        return (nodes, edges)

    except Exception as e:
        logger.error(f"Error during neighborhood retrieval: {str(e)}")
        raise

async def get_id_filtered_graph_data(self, target_ids: list[str]):
"""
Retrieve graph data filtered by specific node IDs, including their direct neighbors
Expand Down
69 changes: 69 additions & 0 deletions cognee/infrastructure/databases/graph/neptune_driver/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,75 @@ async def get_graph_data(self) -> Tuple[List[Node], List[EdgeData]]:
logger.error(f"Failed to get graph data: {error_msg}")
raise Exception(f"Failed to get graph data: {error_msg}") from e

async def get_neighborhood(
self,
node_ids: List[str],
depth: int = 1,
edge_types: Optional[List[str]] = None,
) -> Tuple[List[Node], List[EdgeData]]:
"""
Get the k-hop neighborhood subgraph around a set of seed nodes.

Returns all nodes and edges within `depth` hops of any seed node,
in the same format as get_graph_data().
"""
try:
if not node_ids:
logger.warning("No node IDs provided for neighborhood retrieval.")
return [], []

# Step 1: Find all neighbor node IDs within depth hops
if edge_types:
allowed = "|".join(edge_types)
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[:{allowed}*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT id(neighbor) AS nid
"""
else:
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT id(neighbor) AS nid
"""

result = await self.query(path_query, {"node_ids": node_ids})
neighbor_ids = [record["nid"] for record in result if record.get("nid")]

all_ids = list(set(node_ids) | set(neighbor_ids))

# Step 2: Fetch all nodes
nodes_query = f"""
MATCH (n:{self._GRAPH_NODE_LABEL})
WHERE id(n) IN $ids
RETURN id(n) AS node_id, properties(n) AS properties
"""
nodes_result = await self.query(nodes_query, {"ids": all_ids})
nodes = [(r["node_id"], r["properties"]) for r in nodes_result]

# Step 3: Fetch all edges between collected nodes
edges_query = f"""
MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
WHERE id(source) IN $ids AND id(target) IN $ids
RETURN id(source) AS source_id, id(target) AS target_id,
type(r) AS relationship_name, properties(r) AS properties
"""
edges_result = await self.query(edges_query, {"ids": all_ids})
Comment on lines +690 to +725
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Keep get_neighborhood() on a single ID domain.

Line 694 matches seed nodes by ~id, but Lines 712 and 721 switch to id(n) / id(source). That makes all_ids a mix of external IDs and Neptune internal IDs, so seed nodes and their incident edges can disappear from the returned neighborhood.

🔧 One consistent way to fix it
             if edge_types:
                 allowed = "|".join(edge_types)
                 path_query = f"""
                 MATCH (seed:{self._GRAPH_NODE_LABEL})-[:{allowed}*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
                 WHERE seed.`~id` IN $node_ids
-                RETURN DISTINCT id(neighbor) AS nid
+                RETURN DISTINCT neighbor.`~id` AS nid
                 """
             else:
                 path_query = f"""
                 MATCH (seed:{self._GRAPH_NODE_LABEL})-[*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
                 WHERE seed.`~id` IN $node_ids
-                RETURN DISTINCT id(neighbor) AS nid
+                RETURN DISTINCT neighbor.`~id` AS nid
                 """
@@
             nodes_query = f"""
             MATCH (n:{self._GRAPH_NODE_LABEL})
-            WHERE id(n) IN $ids
-            RETURN id(n) AS node_id, properties(n) AS properties
+            WHERE n.`~id` IN $ids
+            RETURN n.`~id` AS node_id, properties(n) AS properties
             """
@@
             edges_query = f"""
             MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
-            WHERE id(source) IN $ids AND id(target) IN $ids
-            RETURN id(source) AS source_id, id(target) AS target_id,
+            WHERE source.`~id` IN $ids AND target.`~id` IN $ids
+            RETURN source.`~id` AS source_id, target.`~id` AS target_id,
                    type(r) AS relationship_name, properties(r) AS properties
             """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if edge_types:
allowed = "|".join(edge_types)
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[:{allowed}*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT id(neighbor) AS nid
"""
else:
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT id(neighbor) AS nid
"""
result = await self.query(path_query, {"node_ids": node_ids})
neighbor_ids = [record["nid"] for record in result if record.get("nid")]
all_ids = list(set(node_ids) | set(neighbor_ids))
# Step 2: Fetch all nodes
nodes_query = f"""
MATCH (n:{self._GRAPH_NODE_LABEL})
WHERE id(n) IN $ids
RETURN id(n) AS node_id, properties(n) AS properties
"""
nodes_result = await self.query(nodes_query, {"ids": all_ids})
nodes = [(r["node_id"], r["properties"]) for r in nodes_result]
# Step 3: Fetch all edges between collected nodes
edges_query = f"""
MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
WHERE id(source) IN $ids AND id(target) IN $ids
RETURN id(source) AS source_id, id(target) AS target_id,
type(r) AS relationship_name, properties(r) AS properties
"""
edges_result = await self.query(edges_query, {"ids": all_ids})
if edge_types:
allowed = "|".join(edge_types)
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[:{allowed}*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT neighbor.`~id` AS nid
"""
else:
path_query = f"""
MATCH (seed:{self._GRAPH_NODE_LABEL})-[*1..{depth}]-(neighbor:{self._GRAPH_NODE_LABEL})
WHERE seed.`~id` IN $node_ids
RETURN DISTINCT neighbor.`~id` AS nid
"""
result = await self.query(path_query, {"node_ids": node_ids})
neighbor_ids = [record["nid"] for record in result if record.get("nid")]
all_ids = list(set(node_ids) | set(neighbor_ids))
# Step 2: Fetch all nodes
nodes_query = f"""
MATCH (n:{self._GRAPH_NODE_LABEL})
WHERE n.`~id` IN $ids
RETURN n.`~id` AS node_id, properties(n) AS properties
"""
nodes_result = await self.query(nodes_query, {"ids": all_ids})
nodes = [(r["node_id"], r["properties"]) for r in nodes_result]
# Step 3: Fetch all edges between collected nodes
edges_query = f"""
MATCH (source:{self._GRAPH_NODE_LABEL})-[r]->(target:{self._GRAPH_NODE_LABEL})
WHERE source.`~id` IN $ids AND target.`~id` IN $ids
RETURN source.`~id` AS source_id, target.`~id` AS target_id,
type(r) AS relationship_name, properties(r) AS properties
"""
edges_result = await self.query(edges_query, {"ids": all_ids})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cognee/infrastructure/databases/graph/neptune_driver/adapter.py` around lines
690 - 725, get_neighborhood() mixes external node IDs (~id) and internal Neptune
ids (id(n)), causing mismatches; ensure the same ID domain is used throughout by
returning and filtering on the external id property (`~id`). Update the
path_query to RETURN neighbor.`~id` (collect into neighbor_ids), build all_ids
as union of node_ids and those neighbor `~id`s, change nodes_query to WHERE
n.`~id` IN $ids and RETURN n.`~id` AS node_id, and change edges_query to WHERE
source.`~id` IN $ids AND target.`~id` IN $ids and RETURN source.`~id` AS
source_id, target.`~id` AS target_id (keep function name get_neighborhood and
variables path_query, nodes_query, edges_query, all_ids, neighbor_ids, node_ids
to locate changes).

edges = [
(r["source_id"], r["target_id"], r["relationship_name"], r["properties"])
for r in edges_result
]

logger.debug(
f"Neighborhood retrieval ({depth}-hop): {len(nodes)} nodes and {len(edges)} edges"
)
return (nodes, edges)

except Exception as e:
error_msg = format_neptune_error(e)
logger.error(f"Failed to get neighborhood: {error_msg}")
raise Exception(f"Failed to get neighborhood: {error_msg}") from e

async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
"""
Fetch metrics and statistics of the graph, possibly including optional details.
Expand Down
Loading
Loading