Skip to content

[GA-153-0] Implement NodeDict update method #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion nx_arangodb/classes/dict/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections import UserDict
from collections.abc import Iterator
from typing import Any, Callable
Expand All @@ -10,10 +11,12 @@
from nx_arangodb.logger import logger

from ..function import (
ArangoDBBatchError,
aql,
aql_doc_get_key,
aql_doc_has_key,
aql_fetch_data,
check_list_for_errors,
doc_delete,
doc_insert,
doc_update,
Expand All @@ -27,6 +30,8 @@
keys_are_not_reserved,
keys_are_strings,
logger_debug,
separate_nodes_by_collections,
upsert_collection_documents,
)

#############
Expand Down Expand Up @@ -368,11 +373,45 @@ def clear(self) -> None:
self.FETCHED_ALL_DATA = False
self.FETCHED_ALL_IDS = False

@keys_are_strings
@logger_debug
def update_local_nodes(self, nodes: Any) -> None:
    """Store the given nodes in the local cache (``self.data``).

    For every ``node id -> attributes`` pair, a new attribute dict is
    obtained from ``self.node_attr_dict_factory``, tagged with the node id,
    populated through ``build_node_attr_dict_data`` and then placed in
    ``self.data`` under that node id, replacing any entry already there.

    :param nodes: Mapping of node id to that node's attributes.
    """
    for key, attributes in nodes.items():
        cached_entry = self.node_attr_dict_factory()
        # The id is set before the data is built, so the builder receives
        # an attribute dict that already knows which node it belongs to.
        cached_entry.node_id = key
        cached_entry.data = build_node_attr_dict_data(cached_entry, attributes)

        self.data[key] = cached_entry

@keys_are_strings
@logger_debug
def update(self, nodes: Any) -> None:
    """Upsert several nodes in the database, then mirror them locally.

    Example::

        g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})

    The nodes are grouped by collection (ids without a ``/`` go to
    ``self.default_node_type``) and written with one batch request per
    collection. The local cache is only updated when no operation of any
    batch reported a failure.

    :param nodes: Mapping of node id to that node's attributes.
    :raises ArangoDBBatchError: If at least one document could not be
        written. The exception carries the flattened per-document results
        of all batches; the local cache is left untouched.
    """
    separated_by_collection = separate_nodes_by_collections(
        nodes, self.default_node_type
    )

    result = upsert_collection_documents(self.db, separated_by_collection)

    if check_list_for_errors(result):
        # No single operation failed, so the local cache can safely follow.
        self.update_local_nodes(nodes)
        return

    # Some or all documents failed. For now the local cache is not updated
    # and an error is raised instead.
    # Reason: the batch cannot be sent with silent=True, because the driver
    # does not report errors in that mode. Once the driver also passes the
    # errors back in silent mode, the behavior here can be adjusted, which
    # would additionally save network traffic and local computation time.
    errors = []
    for collection_results in result:
        if isinstance(collection_results, list):
            errors.extend(collection_results)
        else:
            # check_list_for_errors also accepts non-list batch results
            # (e.g. a bare False); keep them instead of failing to iterate.
            errors.append(collection_results)

    warnings.warn("Failed to insert at least one node. Will not update local cache.")
    raise ArangoDBBatchError(errors)

@logger_debug
def values(self) -> Any:
Expand Down
127 changes: 117 additions & 10 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

from __future__ import annotations

from collections import UserDict
from typing import Any, Callable, Tuple

import networkx as nx
import numpy as np
import numpy.typing as npt
from arango import ArangoError, DocumentInsertError
from arango.collection import StandardCollection
from arango.cursor import Cursor
from arango.database import StandardDatabase
Expand All @@ -29,14 +27,9 @@
SrcIndices,
)

import nx_arangodb as nxadb
from nx_arangodb.logger import logger

from ..exceptions import (
AQLMultipleResultsFound,
GraphDoesNotExist,
InvalidTraversalDirection,
)
from ..exceptions import AQLMultipleResultsFound, InvalidTraversalDirection


def do_load_all_edge_attributes(attributes: set[str]) -> bool:
Expand Down Expand Up @@ -68,7 +61,7 @@ def get_arangodb_graph(
]:
"""Pulls the graph from the database, assuming the graph exists.

Returns the folowing representations:
Returns the following representations:
- Node dictionary (nx.Graph)
- Adjacency dictionary (nx.Graph)
- Source Indices (COO)
Expand Down Expand Up @@ -647,3 +640,117 @@ def get_update_dict(
update_dict = {key: update_dict}

return update_dict


class ArangoDBBatchError(ArangoError):
    """Error carrying the collected results of a failed batch operation.

    The objects passed in are kept on ``errors``; the exception message is
    the ``str`` form of each of them, one per line.
    """

    def __init__(self, errors):
        self.errors = errors
        message = self._format_errors()
        super().__init__(message)

    def _format_errors(self):
        """Return the string forms of ``self.errors`` joined by newlines."""
        return "\n".join(map(str, self.errors))


def check_list_for_errors(lst):
    """Report whether a list of batch results is free of failures.

    Each element of ``lst`` is the result of one batch request and is
    treated as failed when it is

    - the boolean ``False``, or
    - a list containing at least one ``DocumentInsertError`` instance.

    Any other element is considered successful.

    :param lst: Iterable of batch results (booleans and/or lists of
        per-document results).
    :return: ``True`` if no failure was found, ``False`` otherwise.
    """
    for element in lst:
        # Identity check against the False singleton. The previous test
        # ``element is type(bool)`` compared against the ``type`` metaclass
        # and therefore never matched, hiding failed batches.
        if element is False:
            return False

        if isinstance(element, list) and any(
            isinstance(sub_element, DocumentInsertError) for sub_element in element
        ):
            return False

    return True


def extract_arangodb_key(arangodb_id):
    """Return the key part of a ``collection/key`` style ArangoDB id.

    The id must contain a ``/``; the segment directly after the first
    ``/`` is returned.

    :param arangodb_id: Id string such as ``"node/1"``.
    :return: The second ``/``-separated segment of ``arangodb_id``.
    """
    assert "/" in arangodb_id
    segments = arangodb_id.split("/")
    return segments[1]


def extract_arangodb_collection_name(arangodb_id):
    """Return the collection part of a ``collection/key`` style ArangoDB id.

    The id must contain a ``/``; everything before the first ``/`` is
    returned.

    :param arangodb_id: Id string such as ``"node/1"``.
    :return: The text preceding the first ``/`` in ``arangodb_id``.
    """
    assert "/" in arangodb_id
    collection_name, _, _ = arangodb_id.partition("/")
    return collection_name


def is_arangodb_id(key):
    """Return whether ``key`` contains a ``/``.

    A ``/`` is what distinguishes a ``collection/key`` id from a bare key.
    """
    has_separator = "/" in key
    return has_separator


def get_arangodb_collection_key_tuple(key):
    """Split a ``collection/key`` id at its first ``/``.

    :param key: Id string that must satisfy ``is_arangodb_id``.
    :return: Two-element list ``[collection, document_key]``; the split
        happens at the first ``/`` only, so the second element keeps any
        further ``/`` characters.
    """
    assert is_arangodb_id(key)
    if not is_arangodb_id(key):
        # Only reachable when assertions are disabled.
        return None
    return key.split("/", 1)


def separate_nodes_by_collections(nodes: Any, default_collection: str) -> Any:
    """
    Group a node dictionary by the collection each node belongs to.

    A key that contains '/' is split into collection name and document key;
    a key without '/' is assigned to the default collection unchanged.
    :param nodes:
        The input dictionary with keys that may or may not contain '/'.
    :param default_collection:
        The name of the default collection for keys without '/'.
    :return: A dictionary where the keys are collection names and the
        values are dictionaries of key-value pairs belonging to those
        collections.
    """
    by_collection: Any = {}

    for node_id, attributes in nodes.items():
        if is_arangodb_id(node_id):
            collection_name, document_key = get_arangodb_collection_key_tuple(node_id)
        else:
            collection_name, document_key = default_collection, node_id

        # Create the per-collection bucket on first use, then file the node.
        by_collection.setdefault(collection_name, {})[document_key] = attributes

    return by_collection


def transform_local_documents_for_adb(original_documents):
    """
    Convert a key -> attributes mapping into a list of ArangoDB documents.

    Each resulting document starts with the '_key' attribute taken from the
    mapping key, followed by that entry's attributes. An attribute that is
    itself named '_key' replaces the value taken from the mapping key.
    :param original_documents: Original documents in the format
        {'key': {'any-attr-key': 'any-attr-value'}}.
    :return: List of documents with '_key' attribute and additional attributes.
    """
    return [
        {"_key": document_key, **attributes}
        for document_key, attributes in original_documents.items()
    ]


def upsert_collection_documents(db: StandardDatabase, separated: Any) -> Any:
    """
    Write every collection's documents with one ``insert_many`` call each.

    The documents of a collection are first converted with
    ``transform_local_documents_for_adb`` and then inserted with
    ``overwrite_mode="update"`` and ``silent=False``.
    :param db: The ArangoDB database object.
    :param separated: A dictionary where the keys are collection names and the
        values are dictionaries
        of key-value pairs belonging to those collections.
    :return: A list of results from the insert_many operation.
        If inserting a document fails, the exception is not raised but
        returned as an object in the result list.
    """
    return [
        db.collection(collection_name).insert_many(
            transform_local_documents_for_adb(documents),
            silent=False,
            overwrite_mode="update",
        )
        for collection_name, documents in separated.items()
    ]
23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def pytest_configure(config: Any) -> None:


@pytest.fixture(scope="function")
def load_graph() -> None:
def load_karate_graph() -> None:
global db
db.delete_graph("KarateGraph", drop_collections=True, ignore_missing=True)
adapter = ADBNX_Adapter(db)
Expand All @@ -64,3 +64,24 @@ def load_graph() -> None:
}
],
)


@pytest.fixture(scope="function")
def load_two_relation_graph() -> None:
    """Recreate the ``IntegrationTestTwoRelationGraph`` graph for a test.

    An existing graph of that name is deleted together with its
    collections. The graph is then created again with two edge
    definitions that connect the two vertex collections in opposite
    directions (``_e1``: ``_v1`` -> ``_v2``, ``_e2``: ``_v2`` -> ``_v1``).
    """
    global db
    graph_name = "IntegrationTestTwoRelationGraph"
    v1, v2 = f"{graph_name}_v1", f"{graph_name}_v2"
    e1, e2 = f"{graph_name}_e1", f"{graph_name}_e2"

    if db.has_graph(graph_name):
        db.delete_graph(graph_name, drop_collections=True)

    g = db.create_graph(graph_name)
    for edge_collection, source, target in ((e1, v1, v2), (e2, v2, v1)):
        g.create_edge_definition(
            edge_collection,
            from_vertex_collections=[source],
            to_vertex_collections=[target],
        )
Loading